arlbench.core.algorithms package

Subpackages

Submodules

arlbench.core.algorithms.algorithm module

Abstract base class for a reinforcement learning algorithm. Contains basic functionality that is shared among different algorithm implementations.

class arlbench.core.algorithms.algorithm.Algorithm(hpo_config, nas_config, env, eval_env=None, deterministic_eval=True, track_trajectories=False, track_metrics=False)[source]

Bases: ABC

Abstract base class for a reinforcement learning algorithm. Contains basic functionality that is shared among different algorithm implementations.

property action_type: tuple[int, bool]

The size and type of actions of the algorithm/environment.

Returns:

Tuple of (action_size, discrete). action_size is the number of possible actions and discrete defines whether the action space is discrete or not.

Return type:

tuple[int, bool]

eval(runner_state, num_eval_episodes)[source]

Evaluate the algorithm.

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • num_eval_episodes (int) – Number of evaluation episodes.

Returns:

Cumulative reward per evaluation episode. Shape: (n_eval_episodes,).

Return type:

jnp.ndarray

abstract static get_checkpoint_factory(runner_state, train_result)[source]

Creates a factory dictionary of all possible checkpointing options for Algorithm.

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • train_result (Any) – Training result.

Returns:

Dictionary of factory functions.

Return type:

dict[str, Callable]
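
For example, a checkpointing routine can materialize every option at once. A minimal sketch, assuming the factory functions take no arguments and that runner_state and train_result come from a previous call to train():

    # Sketch: collect all checkpointable quantities of a trained algorithm.
    # `algorithm`, `runner_state` and `train_result` are assumed to exist already.
    factory = algorithm.get_checkpoint_factory(runner_state, train_result)
    checkpoint = {name: make() for name, make in factory.items()}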

abstract static get_default_hpo_config()[source]

Returns the default hyperparameter configuration of the agent.

Returns:

Default hyperparameter configuration.

Return type:

Configuration

abstract static get_default_nas_config()[source]

Returns the default neural architecture configuration of the agent.

Returns:

Default neural architecture configuration.

Return type:

Configuration

abstract static get_hpo_config_space(seed=None)[source]

Returns the hyperparameter configuration space of the algorithm.

Parameters:

seed (int | None, optional) – Random generator seed that is used to sample configurations. Defaults to None.

Returns:

Hyperparameter configuration space of the algorithm.

Return type:

ConfigurationSpace

abstract static get_nas_config_space(seed=None)[source]

Returns the neural architecture configuration space of the algorithm.

Parameters:

seed (int | None, optional) – Random generator seed that is used to sample configurations. Defaults to None.

Returns:

Neural architecture configuration space of the algorithm.

Return type:

ConfigurationSpace

abstract init(rng)[source]

Initializes the algorithm state. Parameters that are passed in are not re-initialized but are included in the final state as given.

Parameters:

rng (chex.PRNGKey) – Random generator key.

Returns:

Algorithm state.

Return type:

Any

name: str
abstract predict(runner_state, obs, rng=None, deterministic=True)[source]

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None, optional) – Random generator key. Defaults to None.

  • deterministic (bool) – Return deterministic action. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray

abstract train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]

Performs one iteration of training.

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • buffer_state (Any) – Algorithm buffer state.

  • n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

(algorithm_state, training_result).

Return type:

tuple[Any, Any]

update_hpo_config(hpo_config)[source]

Update the hyperparameter configuration of the algorithm.

Parameters:

hpo_config (Configuration) – Hyperparameter configuration.
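
The static configuration methods can be used without instantiating an algorithm; update_hpo_config() then applies a sampled configuration to an existing instance. A minimal sketch using the DQN subclass documented below (sample_configuration() is the standard ConfigSpace API):

    from arlbench.core.algorithms import DQN

    # Default hyperparameters and the full search space of the algorithm.
    default_hpo_config = DQN.get_default_hpo_config()
    hpo_space = DQN.get_hpo_config_space(seed=42)

    # Draw a new configuration and apply it to an existing instance.
    new_config = hpo_space.sample_configuration()
    # `dqn` is assumed to be an already constructed DQN instance:
    # dqn.update_hpo_config(new_config)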

arlbench.core.algorithms.buffers module

Replay buffers.

arlbench.core.algorithms.buffers.uniform_sample(state, rng_key, batch_size, sequence_length, period)[source]

Adapted sample function to support uniform sampling from prioritised buffers.

Parameters:
  • state (PrioritisedTrajectoryBufferState[Experience]) – Buffer state.

  • rng_key (chex.PRNGKey) – Random generator key.

  • batch_size (int) – Sample batch size.

  • sequence_length (int) – Length of trajectory to sample.

  • period (int) – Interval between sampled sequences.

Returns:

Batch of experience.

Return type:

TransitionSample

arlbench.core.algorithms.common module

Common data structures for algorithms.

class arlbench.core.algorithms.common.TimeStep(last_obs, obs, action, reward, done)[source]

Bases: Mapping

A timestep capturing an environment interaction.

action: Union[Array, ndarray, bool_, number]
done: Union[Array, ndarray, bool_, number]
from_tuple()
items() – a set-like object providing a view on D's items
keys() – a set-like object providing a view on D's keys
last_obs: Union[Array, ndarray, bool_, number]
obs: Union[Array, ndarray, bool_, number]
replace(**kwargs)
reward: Union[Array, ndarray, bool_, number]
to_tuple()
values() – an object providing a view on D's values
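
A TimeStep behaves both as a mapping and as an immutable record with functional updates. A minimal sketch (the keyword constructor mirrors the field order above and is an assumption):

    import jax.numpy as jnp
    from arlbench.core.algorithms.common import TimeStep

    step = TimeStep(
        last_obs=jnp.zeros(4),
        obs=jnp.ones(4),
        action=jnp.array(1),
        reward=jnp.array(1.0),
        done=jnp.array(False),
    )
    print(list(step.keys()))                        # mapping-style access to the fields
    next_step = step.replace(done=jnp.array(True))  # functional update, `step` unchanged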

arlbench.core.algorithms.prioritised_item_buffer module

Prioritised replay buffer.

arlbench.core.algorithms.prioritised_item_buffer.create_prioritised_item_buffer(max_length, min_length, sample_batch_size, add_sequences, add_batches, priority_exponent, device)[source]

Creates a prioritised trajectory buffer that acts as an independent item buffer.

Parameters:
  • max_length (int) – The maximum length of the buffer.

  • min_length (int) – The minimum length of the buffer.

  • sample_batch_size (int) – The batch size of the samples.

  • add_sequences (Optional[bool], optional) – Whether data is added to the buffer in sequences. If False, single items are added each time add is called. Defaults to False.

  • add_batches (bool, optional) – Whether data is added to the buffer in batches. If False, single items (or single sequences of items) are added each time add is called. Defaults to False.

  • priority_exponent (float) – Priority exponent for sampling. Equivalent to alpha in the PER paper.

  • device (str) – “tpu”, “gpu” or “cpu”. Depending on chosen device, more optimal functions will be used to perform the buffer operations.

Return type:

PrioritisedTrajectoryBuffer

Returns: The buffer.

arlbench.core.algorithms.prioritised_item_buffer.make_prioritised_item_buffer(max_length, min_length, sample_batch_size, add_sequences=False, add_batches=False, priority_exponent=0.6, device='cpu')[source]

Makes a prioritised trajectory buffer act as an independent item buffer.

Parameters:
  • max_length (int) – The maximum length of the buffer.

  • min_length (int) – The minimum length of the buffer.

  • sample_batch_size (int) – The batch size of the samples.

  • add_sequences (Optional[bool], optional) – Whether data is added to the buffer in sequences. If False, single items are added each time add is called. Defaults to False.

  • add_batches (bool, optional) – Whether data is added to the buffer in batches. If False, single transitions (or single sequences) are added each time add is called. Defaults to False.

  • priority_exponent (float) – Priority exponent for sampling. Equivalent to alpha in the PER paper.

  • device (str) – “tpu”, “gpu” or “cpu”. Depending on chosen device, more optimal functions will be used to perform the buffer operations.

Return type:

PrioritisedTrajectoryBuffer

Returns: The buffer.
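
A minimal usage sketch (the attributes of the returned buffer follow the flashbax PrioritisedTrajectoryBuffer interface and are not documented on this page):

    from arlbench.core.algorithms.prioritised_item_buffer import make_prioritised_item_buffer

    buffer = make_prioritised_item_buffer(
        max_length=100_000,     # buffer capacity
        min_length=1_000,       # minimum fill level before sampling is allowed
        sample_batch_size=64,
        add_batches=True,       # add whole batches of transitions per call to add
        priority_exponent=0.6,  # alpha in the PER paper
        device="cpu",
    )
    # The returned PrioritisedTrajectoryBuffer bundles pure init/add/sample/set_priorities
    # functions that operate on an explicit buffer state.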

Module contents

RL algorithms.

class arlbench.core.algorithms.Algorithm(hpo_config, nas_config, env, eval_env=None, deterministic_eval=True, track_trajectories=False, track_metrics=False)[source]

Bases: ABC

Abstract base class for a reinforcement learning algorithm. Contains basic functionality that is shared among different algorithm implementations.

property action_type: tuple[int, bool]

The size and type of actions of the algorithm/environment.

Returns:

Tuple of (action_size, discrete). action_size is the number of possible actions and discrete defines whether the action space is discrete or not.

Return type:

tuple[int, bool]

eval(runner_state, num_eval_episodes)[source]

Evaluate the algorithm.

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • num_eval_episodes (int) – Number of evaluation episodes.

Returns:

Cumulative reward per evaluation episode. Shape: (n_eval_episodes,).

Return type:

jnp.ndarray

abstract static get_checkpoint_factory(runner_state, train_result)[source]

Creates a factory dictionary of all possible checkpointing options for Algorithm.

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • train_result (Any) – Training result.

Returns:

Dictionary of factory functions.

Return type:

dict[str, Callable]

abstract static get_default_hpo_config()[source]

Returns the default hyperparameter configuration of the agent.

Returns:

Default hyperparameter configuration.

Return type:

Configuration

abstract static get_default_nas_config()[source]

Returns the default neural architecture configuration of the agent.

Returns:

Default neural architecture configuration.

Return type:

Configuration

abstract static get_hpo_config_space(seed=None)[source]

Returns the hyperparameter configuration space of the algorithm.

Parameters:

seed (int | None, optional) – Random generator seed that is used to sample configurations. Defaults to None.

Returns:

Hyperparameter configuration space of the algorithm.

Return type:

ConfigurationSpace

abstract static get_nas_config_space(seed=None)[source]

Returns the neural architecture configuration space of the algorithm.

Parameters:

seed (int | None, optional) – Random generator seed that is used to sample configurations. Defaults to None.

Returns:

Neural architecture configuration space of the algorithm.

Return type:

ConfigurationSpace

abstract init(rng)[source]

Initializes the algorithm state. Parameters that are passed in are not re-initialized but are included in the final state as given.

Parameters:

rng (chex.PRNGKey) – Random generator key.

Returns:

Algorithm state.

Return type:

Any

name: str
abstract predict(runner_state, obs, rng=None, deterministic=True)[source]

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None, optional) – Random generator key. Defaults to None.

  • deterministic (bool) – Return deterministic action. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray

abstract train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]

Performs one iteration of training.

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • buffer_state (Any) – Algorithm buffer state.

  • n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

(algorithm_state, training_result).

Return type:

tuple[Any, Any]

update_hpo_config(hpo_config)[source]

Update the hyperparameter configuration of the algorithm.

Parameters:

hpo_config (Configuration) – Hyperparameter configuration.

class arlbench.core.algorithms.DQN(hpo_config, env, eval_env=None, deterministic_eval=True, eval_eps=0.05, cnn_policy=False, nas_config=None, track_trajectories=False, track_metrics=False)[source]

Bases: Algorithm

JAX-based implementation of Deep Q-Network (DQN).

static get_checkpoint_factory(runner_state, train_result)[source]

Creates a factory dictionary of all possible checkpointing options for DQN.

Parameters:
  • runner_state – DQN runner state.

  • train_result – DQN training result.

Returns:

Dictionary of factory functions containing [opt_state, params, target_params, loss, trajectories].

Return type:

dict[str, Callable]

static get_default_hpo_config()[source]

Returns the default hyperparameter configuration for DQN.

Return type:

Configuration

static get_default_nas_config()[source]

Returns the default NAS configuration for DQN.

Return type:

Configuration

static get_hpo_config_space(seed=None)[source]

Returns the hyperparameter optimization (HPO) configuration space for DQN.

Return type:

ConfigurationSpace

static get_nas_config_space(seed=None)[source]

Returns the neural architecture search (NAS) configuration space for DQN.

Return type:

ConfigurationSpace

init(rng, buffer_state=None, network_params=None, target_params=None, opt_state=None)[source]

Initializes DQN state. Parameters that are passed in are not re-initialized but are included in the final state as given.

Parameters:
  • rng (chex.PRNGKey) – Random generator key.

  • buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None.

  • network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.

  • target_params (FrozenDict | dict | None, optional) – Target network parameters. Defaults to None.

  • opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.

Returns:

DQN state.

Return type:

DQNState

name: str = 'dqn'
predict(runner_state, obs, rng, deterministic=True)[source]

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (DQNRunnerState) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None, optional) – Not used in DQN; serves as the random generator key in other algorithms. Defaults to None.

  • deterministic (bool) – Return deterministic action. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray

train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]

Performs one full training run.

Parameters:
  • runner_state (DQNRunnerState) – DQN runner state.

  • buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.

  • n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

Tuple of DQN algorithm state and training result.

Return type:

DQNTrainReturnT

update(train_state, observations, is_weights, actions, next_observations, rewards, dones)[source]

Update the Q-network.

Parameters:
  • train_state (DQNTrainState) – DQN training state.

  • observations (jnp.ndarray) – Batch of observations.

  • is_weights (jnp.ndarray) – Importance sampling weights (used for prioritised experience replay).

  • actions (jnp.ndarray) – Batch of actions.

  • next_observations (jnp.ndarray) – Batch of next observations.

  • rewards (jnp.ndarray) – Batch of rewards.

  • dones (jnp.ndarray) – Batch of dones.

Returns:

Tuple of (train_state, loss, td_error, grads).

Return type:

tuple[DQNTrainState, jnp.ndarray, jnp.ndarray, jnp.ndarray]
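
Put together, the DQN API is used roughly as follows. A sketch under assumptions: the make_env factory and the unpacking of the returned DQN state into runner and buffer state are not guaranteed by this page.

    import jax
    from arlbench.core.algorithms import DQN
    from arlbench.core.environments import make_env  # assumed environment factory

    env = make_env("gymnax", "CartPole-v1")          # assumed framework/environment names
    dqn = DQN(DQN.get_default_hpo_config(), env)

    # init() returns the DQN state; it is assumed to bundle runner and buffer state.
    runner_state, buffer_state = dqn.init(jax.random.PRNGKey(0))

    # One call to train() runs the full training loop.
    (runner_state, buffer_state), train_result = dqn.train(
        runner_state, buffer_state, n_total_timesteps=100_000
    )

    # Cumulative reward per evaluation episode, shape (n_eval_episodes,).
    episode_returns = dqn.eval(runner_state, num_eval_episodes=10)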

class arlbench.core.algorithms.PPO(hpo_config, env, eval_env=None, cnn_policy=False, nas_config=None, deterministic_eval=True, track_trajectories=False, track_metrics=False)[source]

Bases: Algorithm

JAX-based implementation of Proximal Policy Optimization (PPO).

static get_checkpoint_factory(runner_state, train_result)[source]

Creates a factory dictionary of all possible checkpointing options for PPO.

Parameters:
  • runner_state – PPO runner state.

  • train_result – PPO training result.

Returns:

Dictionary of factory functions containing [opt_state, params, loss, trajectories].

Return type:

dict[str, Callable]

static get_default_hpo_config()[source]

Returns the default hyperparameter configuration for PPO.

Return type:

Configuration

static get_default_nas_config()[source]

Returns the default neural architecture search configuration for PPO.

Return type:

Configuration

static get_hpo_config_space(seed=None)[source]

Returns the hyperparameter configuration space for PPO.

Return type:

ConfigurationSpace

static get_nas_config_space(seed=None)[source]

Returns the neural architecture search configuration space for PPO.

Return type:

ConfigurationSpace

init(rng, network_params=None, opt_state=None)[source]

Initializes PPO state. Parameters that are passed in are not re-initialized but are included in the final state as given.

Parameters:
  • rng (chex.PRNGKey) – Random generator key.

  • network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.

  • opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.

Returns:

PPO state.

Return type:

PPOState

name: str = 'ppo'
predict(runner_state, obs, rng, deterministic=True)[source]

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (PPORunnerState) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None, optional) – Random generator key. Defaults to None.

  • deterministic (bool) – Return deterministic action. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray

train(runner_state, _, n_total_timesteps=1000000, n_eval_steps=10, n_eval_episodes=10)[source]

Performs one iteration of training.

Parameters:
  • runner_state (PPORunnerState) – PPO runner state.

  • _ (None) – Unused parameter (buffer_state in other algorithms).

  • n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 10.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

Tuple of PPO algorithm state and training result.

Return type:

PPOTrainReturnT
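
Since PPO is on-policy, the buffer state argument of train() is unused and None is passed. A short sketch (env and the structure of the returned PPO state follow the same assumptions as the DQN example above):

    import jax
    from arlbench.core.algorithms import PPO

    ppo = PPO(PPO.get_default_hpo_config(), env)   # `env` as in the DQN sketch above
    runner_state, _ = ppo.init(jax.random.PRNGKey(0))

    # PPO ignores the buffer state, hence the explicit None.
    (runner_state, _), train_result = ppo.train(
        runner_state, None, n_total_timesteps=100_000
    )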

class arlbench.core.algorithms.SAC(hpo_config, env, eval_env=None, deterministic_eval=True, cnn_policy=False, nas_config=None, track_metrics=False, track_trajectories=False)[source]

Bases: Algorithm

JAX-based implementation of Soft-Actor-Critic (SAC).

static get_checkpoint_factory(runner_state, train_result)[source]

Creates a factory dictionary of all possible checkpointing options for SAC.

Parameters:
  • runner_state – SAC runner state.

  • train_result – SAC training result.

Returns:

Dictionary of factory functions containing:
  • actor_opt_state

  • critic_opt_state

  • alpha_opt_state

  • actor_network_params

  • critic_network_params

  • critic_target_params

  • alpha_network_params

  • actor_loss

  • critic_loss

  • alpha_loss

  • trajectories

Return type:

dict[str, Callable]

static get_default_hpo_config()[source]

Returns the default HPO configuration for SAC.

Return type:

Configuration

static get_default_nas_config()[source]

Returns the default NAS configuration for SAC.

Return type:

Configuration

static get_hpo_config_space(seed=None)[source]

Returns the hyperparameter configuration space for SAC.

Return type:

ConfigurationSpace

static get_nas_config_space(seed=None)[source]

Returns the neural architecture search (NAS) configuration space for SAC.

Return type:

ConfigurationSpace

init(rng, buffer_state=None, actor_network_params=None, critic_network_params=None, critic_target_params=None, alpha_network_params=None, actor_opt_state=None, critic_opt_state=None, alpha_opt_state=None)[source]

Initializes SAC state. Parameters that are passed in are not re-initialized but are included in the final state as given.

Parameters:
  • rng (chex.PRNGKey) – Random generator key.

  • buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None.

  • actor_network_params (FrozenDict | dict | None, optional) – Actor network parameters. Defaults to None.

  • critic_network_params (FrozenDict | dict | None, optional) – Critic network parameters. Defaults to None.

  • critic_target_params (FrozenDict | dict | None, optional) – Critic target network parameters. Defaults to None.

  • alpha_network_params (FrozenDict | dict | None, optional) – Alpha network parameters. Defaults to None.

  • actor_opt_state (optax.OptState | None, optional) – Actor optimizer state. Defaults to None.

  • critic_opt_state (optax.OptState | None, optional) – Critic optimizer state. Defaults to None.

  • alpha_opt_state (optax.OptState | None, optional) – Alpha optimizer state. Defaults to None.

Returns:

SAC state.

Return type:

SACState

name: str = 'sac'
predict(runner_state, obs, rng, deterministic=True)[source]

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (SACRunnerState) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None, optional) – Random generator key. Defaults to None.

  • deterministic (bool) – Return deterministic action. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray

train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]

Performs one full training run.

Parameters:
  • runner_state (SACRunnerState) – SAC runner state.

  • buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.

  • n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

Tuple of SAC algorithm state and training result.

Return type:

SACTrainReturnT

update_actor(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]

Updates the actor network parameters.

Parameters:
  • actor_train_state (SACTrainState) – Actor train state.

  • critic_train_state (SACTrainState) – Critic train state.

  • alpha_train_state (SACTrainState) – Alpha train state.

  • experience (TimeStep) – Experience (batch of TimeSteps).

  • is_weights (jnp.ndarray) – Importance sampling weights used for prioritised experience replay (PER).

  • rng (chex.PRNGKey) – Random number generator key.

Returns:

Updated training state and metrics.

Return type:

tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]

update_alpha(alpha_train_state, entropy)[source]

Update alpha network parameters.

Parameters:
  • alpha_train_state (SACTrainState) – Alpha training state.

  • entropy (jnp.ndarray) – Entropy values.

Returns:

Updated training state and metrics.

Return type:

tuple[SACTrainState, jnp.ndarray]

update_critic(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]

Updates the critic network parameters.

Parameters:
  • actor_train_state (SACTrainState) – Actor train state.

  • critic_train_state (SACTrainState) – Critic train state.

  • alpha_train_state (SACTrainState) – Alpha train state.

  • experience (Transition) – Experience (batch of transitions).

  • is_weights (jnp.ndarray) – Importance sampling weights used for prioritised experience replay (PER).

  • rng (chex.PRNGKey) – Random number generator key.

Returns:

Updated training state and metrics.

Return type:

tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]