arlbench.core.algorithms.ppo
- class arlbench.core.algorithms.ppo.PPO(hpo_config, env, eval_env=None, cnn_policy=False, nas_config=None, deterministic_eval=True, track_trajectories=False, track_metrics=False)
Bases: Algorithm
JAX-based implementation of Proximal Policy Optimization (PPO).
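For orientation, the clipped surrogate objective that PPO optimizes (as formulated in the original PPO paper) can be sketched per sample in plain Python. This is a minimal illustration of the standard formulation, not the arlbench code path; the function name and the clip value of 0.2 are assumptions chosen for the example.

```python
import math


def clipped_surrogate_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    """Negative mean of the PPO clipped surrogate objective.

    logp_new / logp_old: log-probabilities of the taken actions under the
    current policy and the policy that collected the data.
    advantages: advantage estimates for the same samples.
    clip_eps is an assumed value, not necessarily arlbench's default.
    """
    terms = []
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(lp_new - lp_old)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic bound: keep the smaller of the unclipped and clipped terms.
        terms.append(min(ratio * adv, clipped * adv))
    # Minimizing this loss maximizes the surrogate objective.
    return -sum(terms) / len(terms)
```

With identical old and new log-probabilities the ratio is 1 and the loss reduces to the negative mean advantage; once the ratio leaves the interval [1 - clip_eps, 1 + clip_eps], a positive advantage no longer increases the objective.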
- static get_checkpoint_factory(runner_state, train_result)
Creates a factory dictionary of all possible checkpointing options for PPO.
- Parameters:
runner_state (PPORunnerState) – Algorithm runner state.
train_result (PPOTrainingResult | None) – Training result.
- Returns:
Dictionary of factory functions containing [opt_state, params, loss, trajectories].
- Return type:
dict[str, Callable]
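A factory dictionary maps each checkpointing option to a callable that produces the corresponding object only when it is requested. The sketch below illustrates that pattern with stand-in data; the classes FakeTrainState and FakeTrainResult, the helper build_factory, and the choice of zero-argument callables are hypothetical and only mirror the four keys listed above.

```python
from typing import Any, Callable, NamedTuple, Optional


class FakeTrainState(NamedTuple):
    # Hypothetical stand-in for the train state held by PPORunnerState.
    params: dict
    opt_state: dict


class FakeTrainResult(NamedTuple):
    # Hypothetical stand-in for PPOTrainingResult.
    loss: Optional[float]
    trajectories: Optional[list]


def build_factory(train_state: FakeTrainState,
                  train_result: Optional[FakeTrainResult]) -> dict[str, Callable[[], Any]]:
    """Map each checkpoint option to a zero-argument callable (evaluated lazily)."""
    return {
        "opt_state": lambda: train_state.opt_state,
        "params": lambda: train_state.params,
        # Training outputs are only available when a result was passed in.
        "loss": lambda: None if train_result is None else train_result.loss,
        "trajectories": lambda: None if train_result is None else train_result.trajectories,
    }
```

A caller can then materialize only the options it wants to save, e.g. `{key: factory[key]() for key in ["params", "opt_state"]}`, without touching the others.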
- static get_default_hpo_config()
Returns the default hyperparameter configuration for PPO.
- Return type:
Configuration
- static get_default_nas_config()
Returns the default neural architecture search configuration for PPO.
- Return type:
Configuration
- static get_hpo_config_space(seed=None)
Returns the hyperparameter configuration space for PPO.
- Return type:
ConfigurationSpace
- static get_nas_config_space(seed=None)
Returns the neural architecture search configuration space for PPO.
- Return type:
ConfigurationSpace
- init(rng, network_params=None, opt_state=None)
Initializes the PPO state. Parameters that are passed in are not initialized again; they are included in the final state as-is.
- Parameters:
rng (chex.PRNGKey) – Random generator key.
network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.
opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.
- Returns:
PPO state.
- Return type:
PPOState
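The rule that passed-in components are kept rather than re-created can be sketched as follows. The helpers make_params and make_opt_state are hypothetical stand-ins for network and optimizer initialization, and a plain integer seed replaces the JAX PRNG key for simplicity.

```python
import random
from typing import Optional


def make_params(seed: int) -> dict:
    # Hypothetical stand-in for network initialization.
    rng = random.Random(seed)
    return {"w": [rng.uniform(-1.0, 1.0) for _ in range(3)]}


def make_opt_state(params: dict) -> dict:
    # Hypothetical stand-in for optimizer initialization (e.g. zeroed moments).
    return {"mu": [0.0 for _ in params["w"]], "count": 0}


def init_state(seed: int, network_params: Optional[dict] = None,
               opt_state: Optional[dict] = None) -> dict:
    """Initialize only the components that were not passed in."""
    params = network_params if network_params is not None else make_params(seed)
    opt = opt_state if opt_state is not None else make_opt_state(params)
    return {"params": params, "opt_state": opt}
```

Passing previously trained parameters this way lets a run resume from them while the optimizer state is still created fresh.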
- predict(runner_state, obs, rng, deterministic=True)
Predict action(s) based on the current observation(s).
- Parameters:
runner_state (PPORunnerState) – Algorithm runner state.
obs (jnp.ndarray) – Observation(s).
rng (chex.PRNGKey | None, optional) – Random generator key. Defaults to None.
deterministic (bool) – Return deterministic action. Defaults to True.
- Returns:
Action(s).
- Return type:
jnp.ndarray
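For a discrete policy, a deterministic prediction typically returns the mode of the action distribution, while a stochastic one samples from it, which is what the random key is for. The sketch below assumes a categorical policy given as logits and uses Python's random module in place of a JAX PRNG key; select_action is a hypothetical name.

```python
import math
import random
from typing import Optional, Sequence


def select_action(logits: Sequence[float], deterministic: bool = True,
                  rng: Optional[random.Random] = None) -> int:
    """Pick a discrete action from unnormalized log-probabilities."""
    if deterministic:
        # Mode of the categorical distribution: the highest-scoring action.
        return max(range(len(logits)), key=lambda i: logits[i])
    if rng is None:
        raise ValueError("stochastic action selection needs a random generator")
    # Softmax weights with max-subtraction for numerical stability, then sample.
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]
```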
- train(runner_state, _, n_total_timesteps=1000000, n_eval_steps=10, n_eval_episodes=10)
Performs one iteration of training.
- Parameters:
runner_state (PPORunnerState) – PPO runner state.
_ (None) – Unused parameter (buffer_state in other algorithms).
n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 10.
n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
- Returns:
Tuple of PPO algorithm state and training result.
- Return type:
PPOTrainReturnT
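The budget arithmetic stated for n_total_timesteps can be sketched as below. The update-step formula follows the parameter description above; splitting those updates evenly between the n_eval_steps evaluations (rounding up) is an illustrative assumption, and training_schedule is a hypothetical helper.

```python
import math


def training_schedule(n_total_timesteps: int, n_envs: int,
                      n_eval_steps: int) -> tuple[int, int]:
    """Return (update_steps, updates_per_eval) for a training budget.

    update_steps follows the docstring: n_total_timesteps // n_envs.
    updates_per_eval assumes an even split across evaluations, rounded up.
    """
    update_steps = n_total_timesteps // n_envs
    updates_per_eval = math.ceil(update_steps / n_eval_steps)
    return update_steps, updates_per_eval
```

With the defaults of 1,000,000 timesteps and 10 evaluations on 8 parallel environments, this gives 125,000 update steps and 12,500 updates between evaluations.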
- class arlbench.core.algorithms.ppo.PPOMetrics(loss: jnp.ndarray, grads: jnp.ndarray | dict, advantages: jnp.ndarray)
Bases: NamedTuple
PPO metrics returned by the train function. Consists of (loss, grads, advantages).
- advantages: Array
Alias for field number 2
- grads: Array | dict
Alias for field number 1
- loss: Array
Alias for field number 0
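Because PPOMetrics is a NamedTuple, "Alias for field number N" means each named attribute is also reachable by position. A self-contained mirror shows this; MetricsMirror is hypothetical, and plain floats stand in for JAX arrays.

```python
from typing import NamedTuple, Union


class MetricsMirror(NamedTuple):
    # Hypothetical mirror of PPOMetrics; floats stand in for jnp.ndarray.
    loss: float
    grads: Union[float, dict]
    advantages: float


m = MetricsMirror(loss=0.25, grads={"w": 0.1}, advantages=1.5)
loss, grads, advantages = m   # positional unpacking follows field numbers 0, 1, 2
print(m[0] == m.loss)         # True: field number 0 is the alias for "loss"
print(MetricsMirror._fields)  # ('loss', 'grads', 'advantages')
```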
- class arlbench.core.algorithms.ppo.PPORunnerState(rng: chex.PRNGKey, train_state: PPOTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: chex.Array, global_step: int)
Bases: NamedTuple
PPO runner state. Consists of (rng, train_state, normalizer_state, env_state, obs, global_step).
- env_state: Any
Alias for field number 3
- global_step: int
Alias for field number 5
- normalizer_state: RunningStatisticsState
Alias for field number 2
- obs: Union[Array, ndarray, bool_, number]
Alias for field number 4
- rng: Array
Alias for field number 0
- train_state: PPOTrainState
Alias for field number 1
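NamedTuple states are immutable, so a training step produces a new runner state instead of modifying the old one. The sketch uses _replace to advance global_step on a simplified mirror; RunnerStateMirror and advance are hypothetical, and counting one step per parallel environment is an assumption for illustration.

```python
from typing import Any, NamedTuple


class RunnerStateMirror(NamedTuple):
    # Hypothetical mirror of PPORunnerState, in field order 0..5.
    rng: Any
    train_state: Any
    normalizer_state: Any
    env_state: Any
    obs: Any
    global_step: int


def advance(state: RunnerStateMirror, n_envs: int) -> RunnerStateMirror:
    """Return a new state with global_step advanced by one vectorized env step."""
    return state._replace(global_step=state.global_step + n_envs)


s0 = RunnerStateMirror(rng=0, train_state=None, normalizer_state=None,
                       env_state=None, obs=[0.0], global_step=0)
s1 = advance(s0, n_envs=4)
print(s0.global_step, s1.global_step)  # 0 4
```

All fields not named in _replace are carried over unchanged, which keeps state threading explicit in the style JAX transformations such as jax.lax.scan expect.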
- class arlbench.core.algorithms.ppo.PPOState(runner_state: PPORunnerState, buffer_state: None = None)
Bases: NamedTuple
PPO algorithm state. Consists of (runner_state, buffer_state).
Note: As PPO does not use a buffer, buffer_state is always None; it is only kept for consistency across algorithms.
- buffer_state: None
Alias for field number 1
- runner_state: PPORunnerState
Alias for field number 0
- arlbench.core.algorithms.ppo.PPOTrainReturnT
alias of tuple[PPOState, PPOTrainingResult]
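Since PPOTrainReturnT is a plain 2-tuple, the value returned by train can be unpacked directly. The mirror classes and fake_train below are hypothetical stand-ins; they also show buffer_state defaulting to None, as documented for PPOState.

```python
from typing import Any, NamedTuple, Optional


class StateMirror(NamedTuple):
    # Hypothetical mirror of PPOState; buffer_state defaults to None.
    runner_state: Any
    buffer_state: None = None


class ResultMirror(NamedTuple):
    # Hypothetical mirror of PPOTrainingResult.
    eval_rewards: list
    trajectories: Optional[Any]
    metrics: Optional[Any]


def fake_train() -> tuple[StateMirror, ResultMirror]:
    # Stand-in for PPO.train(...): returns (algorithm state, training result).
    return StateMirror(runner_state="runner"), ResultMirror([1.0, 2.0], None, None)


algorithm_state, result = fake_train()
print(algorithm_state.buffer_state)  # None
```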
- class arlbench.core.algorithms.ppo.PPOTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: PPOMetrics | None)
Bases: NamedTuple
PPO training result. Consists of (eval_rewards, trajectories, metrics).
- eval_rewards: Array
Alias for field number 0
- metrics: PPOMetrics | None
Alias for field number 2
- trajectories: Transition | None
Alias for field number 1
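trajectories and metrics are typed as optional. Given the constructor's track_trajectories and track_metrics flags (both False by default), they are presumably None unless tracking is enabled, so downstream code should check for None. The sketch below uses a hypothetical mirror and helper, and additionally assumes eval_rewards holds one row of episode returns per evaluation.

```python
from statistics import mean
from typing import Any, NamedTuple, Optional


class TrainingResultMirror(NamedTuple):
    # Hypothetical mirror of PPOTrainingResult; lists stand in for arrays.
    eval_rewards: list
    trajectories: Optional[Any]
    metrics: Optional[Any]


def summarize(result: TrainingResultMirror) -> dict:
    """Average each evaluation's episode returns and flag the optional fields.

    Assumes eval_rewards has one row of episode returns per evaluation.
    """
    return {
        "mean_reward_per_eval": [mean(row) for row in result.eval_rewards],
        "has_trajectories": result.trajectories is not None,
        "has_metrics": result.metrics is not None,
    }


r = TrainingResultMirror(eval_rewards=[[1.0, 3.0], [4.0, 6.0]],
                         trajectories=None, metrics=None)
print(summarize(r))
```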
Modules
- Models for the PPO algorithm.
- Proximal Policy Optimization (PPO) algorithm.