arlbench.core.algorithms.ppo

class arlbench.core.algorithms.ppo.PPO(hpo_config, env, eval_env=None, cnn_policy=False, nas_config=None, deterministic_eval=True, track_trajectories=False, track_metrics=False)

Bases: Algorithm

JAX-based implementation of Proximal Policy Optimization (PPO).

static get_checkpoint_factory(runner_state, train_result)

Creates a factory dictionary of all possible checkpointing options for PPO.

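Parameters:
  • runner_state (PPORunnerState) – Algorithm runner state.

  • train_result (PPOTrainingResult | None) – Training result.

Returns: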

Dictionary mapping each checkpointing option (opt_state, params, loss, trajectories) to a factory function.

Return type:

dict[str, Callable]
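The idea of a checkpoint factory can be sketched in plain Python. This is a minimal illustration under stated assumptions, not ARLBench's implementation: FakeTrainState, FakeTrainResult and make_checkpoint_factory are hypothetical stand-ins, and only the four option names come from the documentation above. Each value is a zero-argument callable, so a checkpointer evaluates only the options it was asked to save.

```python
from typing import Any, Callable, NamedTuple, Optional


class FakeTrainState(NamedTuple):
    """Hypothetical stand-in for a train state holding params and optimizer state."""
    params: dict
    opt_state: Any


class FakeTrainResult(NamedTuple):
    """Hypothetical stand-in for a training result with optional loss/trajectories."""
    loss: Optional[float]
    trajectories: Optional[list]


def make_checkpoint_factory(
    train_state: FakeTrainState, train_result: Optional[FakeTrainResult]
) -> dict[str, Callable[[], Any]]:
    """Map each checkpoint option to a callable that produces its payload lazily."""
    return {
        "opt_state": lambda: train_state.opt_state,
        "params": lambda: train_state.params,
        # Loss and trajectories only exist once training has produced a result.
        "loss": lambda: None if train_result is None else train_result.loss,
        "trajectories": lambda: None if train_result is None else train_result.trajectories,
    }


factory = make_checkpoint_factory(
    FakeTrainState(params={"w": [0.1, 0.2]}, opt_state={"step": 3}),
    FakeTrainResult(loss=0.5, trajectories=None),
)
# Only the requested options are evaluated.
checkpoint = {name: factory[name]() for name in ("params", "loss")}
print(checkpoint)  # {'params': {'w': [0.1, 0.2]}, 'loss': 0.5}
```

Deferring evaluation behind callables keeps the factory cheap to build even when some options (such as trajectories) would be large.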

static get_default_hpo_config()

Returns the default hyperparameter configuration for PPO.

Return type:

Configuration

static get_default_nas_config()

Returns the default neural architecture search configuration for PPO.

Return type:

Configuration

static get_hpo_config_space(seed=None)

Returns the hyperparameter configuration space for PPO.

Return type:

ConfigurationSpace

static get_nas_config_space(seed=None)

Returns the neural architecture search configuration space for PPO.

Return type:

ConfigurationSpace

init(rng, network_params=None, opt_state=None)

Initializes PPO state. Parameters that are passed in (network_params, opt_state) are not re-initialized; they are included in the returned state as given.

Parameters:
  • rng (chex.PRNGKey) – Random generator key.

  • network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.

  • opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.

Returns:

PPO state.

Return type:

PPOState
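The documented init behavior (supplied components are kept, missing ones are created) can be sketched as follows. The sketch uses Python's random module in place of a chex.PRNGKey, and init_sketch and SketchState are hypothetical names; the real method returns a PPOState built from JAX arrays.

```python
import random
from typing import NamedTuple, Optional


class SketchState(NamedTuple):
    """Hypothetical stand-in for the parameter-related contents of PPOState."""
    network_params: dict
    opt_state: dict


def init_sketch(
    seed: int,
    network_params: Optional[dict] = None,
    opt_state: Optional[dict] = None,
) -> SketchState:
    """Initialize only the components the caller did not supply."""
    rng = random.Random(seed)
    if network_params is None:
        # Fresh parameters are drawn only when none were passed in.
        network_params = {"w": [rng.gauss(0.0, 1.0) for _ in range(2)]}
    if opt_state is None:
        opt_state = {"step": 0}
    # Supplied values are placed in the state unchanged.
    return SketchState(network_params=network_params, opt_state=opt_state)


pretrained = {"w": [1.0, -1.0]}
state = init_sketch(seed=0, network_params=pretrained)
print(state.network_params is pretrained)  # True
print(state.opt_state)  # {'step': 0}
```

Passing pretrained network_params this way keeps them as-is while a fresh optimizer state is created, which is a typical warm-start pattern.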

predict(runner_state, obs, rng, deterministic=True)

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (PPORunnerState) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None) – Random generator key. Required positional argument; the signature gives it no default.

  • deterministic (bool, optional) – If True, return the deterministic action instead of sampling from the policy. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray
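The effect of the deterministic flag can be illustrated for a discrete action space with a categorical policy. This is a sketch under assumptions, using plain Python lists rather than jnp arrays; select_action is a hypothetical helper, and for continuous (for example Gaussian) policies the deterministic action would typically be the distribution mean rather than an argmax.

```python
import math
import random


def select_action(logits: list[float], rng: random.Random, deterministic: bool = True) -> int:
    """Pick an action index from categorical policy logits."""
    if deterministic:
        # Deterministic mode: the most probable action (argmax of the logits).
        return max(range(len(logits)), key=lambda i: logits[i])
    # Stochastic mode: sample from softmax(logits); subtracting the max keeps exp() stable.
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [0.1, 2.0, -1.0]
print(select_action(logits, random.Random(0)))  # 1
print(select_action(logits, random.Random(0), deterministic=False))  # a sampled index in {0, 1, 2}
```

Deterministic selection is the usual choice for evaluation, while sampling preserves exploration during data collection.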

train(runner_state, _, n_total_timesteps=1000000, n_eval_steps=10, n_eval_episodes=10)

Performs one iteration of training.

Parameters:
  • runner_state (PPORunnerState) – PPO runner state.

  • _ (None) – Unused parameter (buffer_state in other algorithms).

  • n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 10.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

Tuple of PPO algorithm state and training result.

Return type:

PPOTrainReturnT
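A worked example of the update-step formula given for n_total_timesteps. The first value follows the docstring directly; dividing the update steps evenly across n_eval_steps evaluations is an assumption of this sketch, and training_schedule is a hypothetical helper.

```python
def training_schedule(n_total_timesteps: int, n_envs: int, n_eval_steps: int) -> tuple[int, int]:
    """Return (update_steps, update_steps_per_eval) for one train() call.

    update_steps follows the docstring formula n_total_timesteps // n_envs.
    Spreading evaluations evenly across training is an assumption of this sketch.
    """
    update_steps = n_total_timesteps // n_envs
    update_steps_per_eval = update_steps // n_eval_steps
    return update_steps, update_steps_per_eval


print(training_schedule(1_000_000, n_envs=8, n_eval_steps=10))  # (125000, 12500)
```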

class arlbench.core.algorithms.ppo.PPOMetrics(loss: jnp.ndarray, grads: jnp.ndarray | dict, advantages: jnp.ndarray)

Bases: NamedTuple

PPO metrics returned by the train function. Consists of (loss, grads, advantages).

advantages: Array

Alias for field number 2

grads: Array | dict

Alias for field number 1

loss: Array

Alias for field number 0
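For NamedTuple classes, the generated text "Alias for field number N" means that the named attribute and positional index N refer to the same value. A minimal sketch with a hypothetical mirror class (plain lists instead of jnp arrays):

```python
from typing import NamedTuple, Union


class MetricsSketch(NamedTuple):
    """Hypothetical mirror of PPOMetrics' field order with plain Python types."""
    loss: list
    grads: Union[list, dict]
    advantages: list


m = MetricsSketch(loss=[0.3, 0.2], grads={"w": [0.01]}, advantages=[1.0, -0.5])
# Attribute access and positional indexing return the same objects.
assert m.loss is m[0] and m.grads is m[1] and m.advantages is m[2]
loss, grads, advantages = m  # tuple unpacking follows the same field order
print(MetricsSketch._fields)  # ('loss', 'grads', 'advantages')
```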

class arlbench.core.algorithms.ppo.PPORunnerState(rng: chex.PRNGKey, train_state: PPOTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: chex.Array, global_step: int)

Bases: NamedTuple

PPO runner state. Consists of (rng, train_state, normalizer_state, env_state, obs, global_step).

env_state: Any

Alias for field number 3

global_step: int

Alias for field number 5

normalizer_state: RunningStatisticsState

Alias for field number 2

obs: Union[Array, ndarray, bool_, number]

Alias for field number 4

rng: Array

Alias for field number 0

train_state: PPOTrainState

Alias for field number 1
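The fields above are listed alphabetically by the documentation generator; the field numbers give the positional order used by the constructor and by tuple unpacking. Because NamedTuples are immutable, updated states are produced with _replace. The sketch below uses a hypothetical mirror class with placeholder values, and the rollout size used to advance global_step is purely illustrative.

```python
from typing import Any, NamedTuple


class RunnerStateSketch(NamedTuple):
    """Hypothetical mirror of PPORunnerState's field order (placeholders instead of JAX types)."""
    rng: Any
    train_state: Any
    normalizer_state: Any
    env_state: Any
    obs: Any
    global_step: int


state = RunnerStateSketch(rng=0, train_state=None, normalizer_state=None,
                          env_state=None, obs=[0.0], global_step=0)
# NamedTuples are immutable; _replace returns an updated copy instead.
n_envs, n_steps = 4, 128  # hypothetical rollout size, for illustration only
next_state = state._replace(global_step=state.global_step + n_envs * n_steps)
print(RunnerStateSketch._fields.index("normalizer_state"))  # 2
print(state.global_step, next_state.global_step)  # 0 512
```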

class arlbench.core.algorithms.ppo.PPOState(runner_state: PPORunnerState, buffer_state: None = None)

Bases: NamedTuple

PPO algorithm state. Consists of (runner_state, buffer_state).

Note: As PPO does not use a buffer, buffer_state is always None and is only kept for consistency across algorithms.

buffer_state: None

Alias for field number 1

runner_state: PPORunnerState

Alias for field number 0

arlbench.core.algorithms.ppo.PPOTrainReturnT

alias of tuple[PPOState, PPOTrainingResult]

class arlbench.core.algorithms.ppo.PPOTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: PPOMetrics | None)

Bases: NamedTuple

PPO training result. Consists of (eval_rewards, trajectories, metrics).

eval_rewards: Array

Alias for field number 0

metrics: PPOMetrics | None

Alias for field number 2

trajectories: Transition | None

Alias for field number 1
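Because PPOTrainReturnT is a plain tuple of (PPOState, PPOTrainingResult), the return value of train can be unpacked directly. The sketch below uses hypothetical mirror classes and a fake_train stand-in rather than the real API; it assumes, without confirmation from this page, that trajectories and metrics are None when the corresponding tracking options (track_trajectories, track_metrics) are disabled.

```python
from typing import Any, NamedTuple, Optional


class StateSketch(NamedTuple):
    """Hypothetical mirror of PPOState: buffer_state defaults to None because PPO has no buffer."""
    runner_state: Any
    buffer_state: None = None


class ResultSketch(NamedTuple):
    """Hypothetical mirror of PPOTrainingResult's field order with placeholder types."""
    eval_rewards: list
    trajectories: Optional[Any]
    metrics: Optional[Any]


def fake_train(state: StateSketch) -> tuple[StateSketch, ResultSketch]:
    """Hypothetical stand-in for PPO.train returning a PPOTrainReturnT-shaped tuple."""
    return state, ResultSketch(eval_rewards=[10.0, 12.0], trajectories=None, metrics=None)


algorithm_state, result = fake_train(StateSketch(runner_state="runner"))
assert algorithm_state.buffer_state is None
mean_reward = sum(result.eval_rewards) / len(result.eval_rewards)
# Optional fields may be None (assumed: when tracking is disabled), so guard before use.
if result.metrics is not None:
    print(result.metrics.loss)
print(mean_reward)  # 11.0
```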

Modules

models

Models for the PPO algorithm.

ppo

Proximal Policy Optimization (PPO) algorithm.