arlbench.core.algorithms¶
RL algorithms.
- class arlbench.core.algorithms.Algorithm(hpo_config, nas_config, env, eval_env=None, deterministic_eval=True, track_trajectories=False, track_metrics=False)[source]¶
Bases:
ABC
Abstract base class for a reinforcement learning algorithm. Contains basic functionality that is shared among different algorithm implementations.
- property action_type: tuple[int, bool]¶
The size and type of actions of the algorithm/environment.
- Returns:
Tuple of (action_size, discrete). action_size is the number of possible actions and discrete defines whether the action space is discrete or not.
- Return type:
tuple[int, bool]
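As a minimal sketch (assuming an already constructed algorithm instance named algo, a hypothetical name), the property can be unpacked directly:

    # Unpack the action description of the wrapped environment.
    action_size, discrete = algo.action_type
    if discrete:
        print(f"discrete action space with {action_size} actions")
    else:
        print(f"continuous action space with dimension {action_size}")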
- eval(runner_state, num_eval_episodes)[source]¶
Evaluate the algorithm.
- Parameters:
runner_state (Any) – Algorithm runner state.
num_eval_episodes (int) – Number of evaluation episodes.
- Returns:
Cumulative reward per evaluation episode. Shape: (n_eval_episodes,).
- Return type:
jnp.ndarray
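For illustration, the per-episode returns can be summarized as follows (a sketch assuming algo and runner_state exist from a previous training call):

    import numpy as np

    # eval() returns one cumulative reward per episode, shape (num_eval_episodes,).
    returns = algo.eval(runner_state, num_eval_episodes=10)
    print(f"mean return: {np.mean(returns):.2f} +/- {np.std(returns):.2f}")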
- abstract static get_checkpoint_factory(runner_state, train_result)[source]¶
Creates a factory dictionary of all possible checkpointing options for Algorithm.
- Parameters:
runner_state (Any) – Algorithm runner state.
train_result (Any) – Training result.
- Returns:
Dictionary of factory functions.
- Return type:
dict[str, Callable]
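A hedged sketch of consuming the factory dictionary; the available keys depend on the concrete algorithm (see the subclasses below), and the factory functions are assumed here to be zero-argument callables:

    factory = algo.get_checkpoint_factory(runner_state, train_result)

    # Materialize a subset of checkpointable objects by name (assumed calling convention).
    checkpoint = {name: make() for name, make in factory.items() if name in ("params", "opt_state")}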
- abstract static get_default_hpo_config()[source]¶
Returns the default hyperparameter configuration of the agent.
- Returns:
Default hyperparameter configuration.
- Return type:
Configuration
- abstract static get_default_nas_config()[source]¶
Returns the default neural architecture configuration of the agent.
- Returns:
Default neural architecture configuration.
- Return type:
Configuration
- abstract static get_hpo_config_space(seed=None)[source]¶
Returns the hyperparameter configuration space of the algorithm.
- Parameters:
seed (int | None, optional) – Random generator seed that is used to sample configurations. Defaults to None.
- Returns:
Hyperparameter configuration space of the algorithm.
- Return type:
ConfigurationSpace
- abstract static get_nas_config_space(seed=None)[source]¶
Returns the neural architecture configuration space of the algorithm.
- Parameters:
seed (int | None, optional) – Random generator seed that is used to sample configurations. Defaults to None.
- Returns:
Neural architecture configuration space of the algorithm.
- Return type:
ConfigurationSpace
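Both spaces are ConfigSpace ConfigurationSpace objects, so configurations can be sampled reproducibly. A short sketch using the DQN subclass documented below:

    from arlbench.core.algorithms import DQN

    # Seeding the spaces makes sampled configurations reproducible.
    hpo_space = DQN.get_hpo_config_space(seed=42)
    nas_space = DQN.get_nas_config_space(seed=42)

    hpo_config = hpo_space.sample_configuration()
    nas_config = nas_space.sample_configuration()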
- abstract init(rng)[source]¶
Initializes the algorithm state. Parameters that are passed in are not re-initialized; they are included in the returned state as-is.
- Parameters:
rng (chex.PRNGKey) – Random generator key.
- Returns:
Algorithm state.
- Return type:
Any
- abstract predict(runner_state, obs, rng=None, deterministic=True)[source]¶
Predict action(s) based on the current observation(s).
- Parameters:
runner_state (Any) – Algorithm runner state.
obs (jnp.ndarray) – Observation(s).
rng (chex.PRNGKey | None, optional) – Random generator key. Defaults to None.
deterministic (bool) – Return deterministic action. Defaults to True.
- Returns:
Action(s).
- Return type:
jnp.ndarray
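A sketch of batched action selection, assuming algo, runner_state, and a batched observation array obs already exist:

    import jax

    rng = jax.random.PRNGKey(0)

    # Deterministic actions, e.g. for evaluation.
    eval_actions = algo.predict(runner_state, obs, rng=rng, deterministic=True)
    # Stochastic actions, e.g. for exploration (how rng is used depends on the algorithm).
    sampled_actions = algo.predict(runner_state, obs, rng=rng, deterministic=False)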
- abstract train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]¶
Performs one iteration of training.
- Parameters:
runner_state (Any) – Algorithm runner state.
buffer_state (Any) – Algorithm buffer state.
n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.
n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
- Returns:
(algorithm_state, training_result).
- Return type:
tuple[Any, Any]
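Taken together, the abstract interface supports the following workflow. This is a hedged sketch: the concrete structure of the state returned by init() and train() depends on the algorithm and is assumed here to unpack into a runner state and a buffer state:

    import jax

    rng = jax.random.PRNGKey(0)

    # Assumption: the algorithm state unpacks into (runner_state, buffer_state).
    runner_state, buffer_state = algo.init(rng)

    # One training call; train() returns the updated algorithm state and a training result.
    (runner_state, buffer_state), train_result = algo.train(
        runner_state, buffer_state, n_total_timesteps=100_000
    )

    # Evaluate the trained policy.
    returns = algo.eval(runner_state, num_eval_episodes=10)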
- class arlbench.core.algorithms.DQN(hpo_config, env, eval_env=None, deterministic_eval=True, eval_eps=0.05, cnn_policy=False, nas_config=None, track_trajectories=False, track_metrics=False)[source]¶
Bases:
Algorithm
JAX-based implementation of Deep Q-Network (DQN).
- static get_checkpoint_factory(runner_state, train_result)[source]¶
Creates a factory dictionary of all possible checkpointing options for DQN.
- Parameters:
runner_state (DQNRunnerState) – Algorithm runner state.
train_result (DQNTrainingResult | None) – Training result.
- Returns:
Dictionary of factory functions containing [opt_state, params, target_params, loss, trajectories].
- Return type:
dict[str, Callable]
- static get_default_hpo_config()[source]¶
Returns the default hyperparameter configuration for DQN.
- Return type:
Configuration
- static get_default_nas_config()[source]¶
Returns the default NAS configuration for DQN.
- Return type:
Configuration
- static get_hpo_config_space(seed=None)[source]¶
Returns the hyperparameter optimization (HPO) configuration space for DQN.
- Return type:
ConfigurationSpace
- static get_nas_config_space(seed=None)[source]¶
Returns the neural architecture search (NAS) configuration space for DQN.
- Return type:
ConfigurationSpace
- init(rng, buffer_state=None, network_params=None, target_params=None, opt_state=None)[source]¶
Initializes the DQN state. Parameters that are passed in are not re-initialized; they are included in the returned state as-is.
- Parameters:
rng (chex.PRNGKey) – Random generator key.
buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None.
network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.
target_params (FrozenDict | dict | None, optional) – Target network parameters. Defaults to None.
opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.
- Returns:
DQN state.
- Return type:
DQNState
- predict(runner_state, obs, rng, deterministic=True)[source]¶
Predict action(s) based on the current observation(s).
- Parameters:
runner_state (DQNRunnerState) – Algorithm runner state.
obs (jnp.ndarray) – Observation(s).
rng (chex.PRNGKey | None, optional) – Not used in DQN; serves as the random generator key in other algorithms. Defaults to None.
deterministic (bool) – Return deterministic action. Defaults to True.
- Returns:
Action(s).
- Return type:
jnp.ndarray
- train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]¶
Performs one full training run.
- Parameters:
runner_state (DQNRunnerState) – DQN runner state.
buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.
n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.
n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
- Returns:
Tuple of DQN algorithm state and training result.
- Return type:
DQNTrainReturnT
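An end-to-end DQN sketch under stated assumptions: env is an ARLBench-compatible (JAX) environment constructed elsewhere, and the DQN state is assumed to unpack into runner and buffer state as in the generic workflow above:

    import jax
    from arlbench.core.algorithms import DQN

    # `env` is assumed to be an ARLBench environment instance created elsewhere.
    dqn = DQN(DQN.get_default_hpo_config(), env)

    rng = jax.random.PRNGKey(42)
    runner_state, buffer_state = dqn.init(rng)  # assumed unpacking

    (runner_state, buffer_state), result = dqn.train(
        runner_state,
        buffer_state,
        n_total_timesteps=200_000,
        n_eval_steps=10,
        n_eval_episodes=10,
    )

    returns = dqn.eval(runner_state, num_eval_episodes=10)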
- update(train_state, observations, is_weights, actions, next_observations, rewards, dones)[source]¶
Update the Q-network.
- Parameters:
train_state (DQNTrainState) – DQN training state.
observations (jnp.ndarray) – Batch of observations.
is_weights (jnp.ndarray) – Importance sampling weights for prioritised experience replay (PER).
actions (jnp.ndarray) – Batch of actions.
next_observations (jnp.ndarray) – Batch of next observations.
rewards (jnp.ndarray) – Batch of rewards.
dones (jnp.ndarray) – Batch of dones.
- Returns:
Tuple of (train_state, loss, td_error, grads).
- Return type:
tuple[DQNTrainState, jnp.ndarray, jnp.ndarray, jnp.ndarray]
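For intuition only, the kind of quantity this update regresses toward can be written in plain JAX. The snippet below is an illustrative one-step TD target, not ARLBench's actual implementation (the function name and the discount gamma are assumptions):

    import jax.numpy as jnp

    def td_targets(q_next, rewards, dones, gamma=0.99):
        """One-step TD target: r + gamma * max_a' Q_target(s', a'), masked at episode ends."""
        return rewards + gamma * (1.0 - dones) * jnp.max(q_next, axis=-1)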
- class arlbench.core.algorithms.PPO(hpo_config, env, eval_env=None, cnn_policy=False, nas_config=None, deterministic_eval=True, track_trajectories=False, track_metrics=False)[source]¶
Bases:
Algorithm
JAX-based implementation of Proximal Policy Optimization (PPO).
- static get_checkpoint_factory(runner_state, train_result)[source]¶
Creates a factory dictionary of all possible checkpointing options for PPO.
- Parameters:
runner_state (PPORunnerState) – Algorithm runner state.
train_result (PPOTrainingResult | None) – Training result.
- Returns:
Dictionary of factory functions containing [opt_state, params, loss, trajectories].
- Return type:
dict[str, Callable]
- static get_default_hpo_config()[source]¶
Returns the default hyperparameter configuration for PPO.
- Return type:
Configuration
- static get_default_nas_config()[source]¶
Returns the default neural architecture search configuration for PPO.
- Return type:
Configuration
- static get_hpo_config_space(seed=None)[source]¶
Returns the hyperparameter configuration space for PPO.
- Return type:
ConfigurationSpace
- static get_nas_config_space(seed=None)[source]¶
Returns the neural architecture search configuration space for PPO.
- Return type:
ConfigurationSpace
- init(rng, network_params=None, opt_state=None)[source]¶
Initializes the PPO state. Parameters that are passed in are not re-initialized; they are included in the returned state as-is.
- Parameters:
rng (chex.PRNGKey) – Random generator key.
network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.
opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.
- Returns:
PPO state.
- Return type:
PPOState
- predict(runner_state, obs, rng, deterministic=True)[source]¶
Predict action(s) based on the current observation(s).
- Parameters:
runner_state (PPORunnerState) – Algorithm runner state.
obs (jnp.ndarray) – Observation(s).
rng (chex.PRNGKey | None, optional) – Random generator key. Defaults to None.
deterministic (bool) – Return deterministic action. Defaults to True.
- Returns:
Action(s).
- Return type:
jnp.ndarray
- train(runner_state, _, n_total_timesteps=1000000, n_eval_steps=10, n_eval_episodes=10)[source]¶
Performs one iteration of training.
- Parameters:
runner_state (PPORunnerState) – PPO runner state.
_ (None) – Unused parameter (buffer_state in other algorithms).
n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 10.
n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
- Returns:
Tuple of PPO algorithm state and training result.
- Return type:
PPOTrainReturnT
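A PPO sketch analogous to the DQN example, here driven by a sampled hyperparameter configuration. PPO is on-policy, so the buffer-state argument of train() is unused and None can be passed; env and the unpacking of the returned state are again assumptions:

    import jax
    from arlbench.core.algorithms import PPO

    hpo_config = PPO.get_hpo_config_space(seed=0).sample_configuration()
    ppo = PPO(hpo_config, env)  # `env` created elsewhere

    rng = jax.random.PRNGKey(0)
    runner_state, _buffer = ppo.init(rng)  # assumed unpacking; the buffer state is unused by PPO

    (runner_state, _buffer), result = ppo.train(runner_state, None, n_total_timesteps=100_000)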
- class arlbench.core.algorithms.SAC(hpo_config, env, eval_env=None, deterministic_eval=True, cnn_policy=False, nas_config=None, track_metrics=False, track_trajectories=False)[source]¶
Bases:
Algorithm
JAX-based implementation of Soft-Actor-Critic (SAC).
- static get_checkpoint_factory(runner_state, train_result)[source]¶
Creates a factory dictionary of all possible checkpointing options for SAC.
- Parameters:
runner_state (SACRunnerState) – Algorithm runner state.
train_result (SACTrainingResult | None) – Training result.
- Returns:
Dictionary of factory functions containing [actor_opt_state, critic_opt_state, alpha_opt_state, actor_network_params, critic_network_params, critic_target_params, alpha_network_params, actor_loss, critic_loss, alpha_loss, trajectories].
- Return type:
dict[str, Callable]
- static get_default_hpo_config()[source]¶
Returns the default HPO configuration for SAC.
- Return type:
Configuration
- static get_default_nas_config()[source]¶
Returns the default NAS configuration for SAC.
- Return type:
Configuration
- static get_hpo_config_space(seed=None)[source]¶
Returns the hyperparameter configuration space for SAC.
- Return type:
ConfigurationSpace
- static get_nas_config_space(seed=None)[source]¶
Returns the neural architecture search (NAS) configuration space for SAC.
- Return type:
ConfigurationSpace
- init(rng, buffer_state=None, actor_network_params=None, critic_network_params=None, critic_target_params=None, alpha_network_params=None, actor_opt_state=None, critic_opt_state=None, alpha_opt_state=None)[source]¶
Initializes the SAC state. Parameters that are passed in are not re-initialized; they are included in the returned state as-is.
- Parameters:
rng (chex.PRNGKey) – Random generator key.
buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None.
actor_network_params (FrozenDict | dict | None, optional) – Actor network parameters. Defaults to None.
critic_network_params (FrozenDict | dict | None, optional) – Critic network parameters. Defaults to None.
critic_target_params (FrozenDict | dict | None, optional) – Critic target network parameters. Defaults to None.
alpha_network_params (FrozenDict | dict | None, optional) – Alpha network parameters. Defaults to None.
actor_opt_state (optax.OptState | None, optional) – Actor optimizer state. Defaults to None.
critic_opt_state (optax.OptState | None, optional) – Critic optimizer state. Defaults to None.
alpha_opt_state (optax.OptState | None, optional) – Alpha optimizer state. Defaults to None.
- Returns:
SAC state.
- Return type:
SACState
- predict(runner_state, obs, rng, deterministic=True)[source]¶
Predict action(s) based on the current observation(s).
- Parameters:
runner_state (SACRunnerState) – Algorithm runner state.
obs (jnp.ndarray) – Observation(s).
rng (chex.PRNGKey | None, optional) – Random generator key. Defaults to None.
deterministic (bool) – Return deterministic action. Defaults to True.
- Returns:
Action(s).
- Return type:
jnp.ndarray
- train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]¶
Performs one full training run.
- Parameters:
runner_state (SACRunnerState) – SAC runner state.
buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.
n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.
n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
- Returns:
Tuple of SAC algorithm state and training result.
- Return type:
SACTrainReturnT
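A SAC sketch mirroring the examples above, with the checkpoint factory used after training to pull out network parameters (env, the state unpacking, and the zero-argument factory convention are assumptions):

    import jax
    from arlbench.core.algorithms import SAC

    sac = SAC(SAC.get_default_hpo_config(), env)  # `env` created elsewhere

    rng = jax.random.PRNGKey(1)
    runner_state, buffer_state = sac.init(rng)  # assumed unpacking

    (runner_state, buffer_state), result = sac.train(
        runner_state, buffer_state, n_total_timesteps=100_000
    )

    # Pull checkpointable objects out of the factory (assumed calling convention).
    factory = SAC.get_checkpoint_factory(runner_state, result)
    actor_params = factory["actor_network_params"]()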
- update_actor(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]¶
Updates the actor network parameters.
- Parameters:
actor_train_state (SACTrainState) – Actor train state.
critic_train_state (SACTrainState) – Critic train state.
alpha_train_state (SACTrainState) – Alpha train state.
experience (TimeStep) – Experience (batch of TimeSteps).
is_weights (jnp.ndarray) – Importance sampling weights for prioritised experience replay (PER).
rng (chex.PRNGKey) – Random number generator key.
- Returns:
Updated actor training state and metrics.
- Return type:
tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]
- update_alpha(alpha_train_state, entropy)[source]¶
Update alpha network parameters.
- Parameters:
alpha_train_state (SACTrainState) – Alpha training state.
entropy (jnp.ndarray) – Entropy values.
- Returns:
Updated training state and metrics.
- Return type:
tuple[SACTrainState, jnp.ndarray]
- update_critic(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]¶
Updates the critic network parameters.
- Parameters:
actor_train_state (SACTrainState) – Actor train state.
critic_train_state (SACTrainState) – Critic train state.
alpha_train_state (SACTrainState) – Alpha train state.
experience (Transition) – Experience (batch of transitions).
is_weights (jnp.ndarray) – Importance sampling weights for prioritised experience replay (PER).
rng (chex.PRNGKey) – Random number generator key.
- Returns:
Updated training state and metrics.
- Return type:
tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]
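For intuition, the target such a critic update regresses toward is the soft Bellman target. The snippet below is a generic illustration in plain JAX, not ARLBench's exact code (names and the discount gamma are assumptions):

    import jax.numpy as jnp

    def soft_bellman_target(q1_next, q2_next, next_log_prob, rewards, dones, alpha, gamma=0.99):
        """Clipped double-Q minus the entropy bonus, discounted and masked at episode ends."""
        soft_q_next = jnp.minimum(q1_next, q2_next) - alpha * next_log_prob
        return rewards + gamma * (1.0 - dones) * soft_q_next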
Modules
- Abstract base class for a reinforcement learning algorithm.
- Replay buffers.
- Common data structures for algorithms.
- Prioritised replay buffer.