arlbench.core.algorithms.algorithm

Abstract base class for a reinforcement learning algorithm. Contains basic functionality that is shared among different algorithm implementations.

Classes

Algorithm(hpo_config, nas_config, env[, ...])

Abstract base class for a reinforcement learning algorithm.

class arlbench.core.algorithms.algorithm.Algorithm(hpo_config, nas_config, env, eval_env=None, deterministic_eval=True, track_trajectories=False, track_metrics=False)

Bases: ABC

Abstract base class for a reinforcement learning algorithm. Contains basic functionality that is shared among different algorithm implementations.
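The pattern this base class follows (abstract static methods for configuration, abstract instance methods for the training loop, and concrete shared methods) can be sketched with Python's abc module. AlgorithmSketch and RandomAgent are hypothetical stand-ins that mirror only a small part of the interface; this is an illustration of the pattern, not arlbench's implementation.

```python
from abc import ABC, abstractmethod
from typing import Any


class AlgorithmSketch(ABC):
    """Hypothetical stand-in mirroring part of the Algorithm interface."""

    def __init__(self, hpo_config: dict, nas_config: dict) -> None:
        self.hpo_config = hpo_config
        self.nas_config = nas_config

    @staticmethod
    @abstractmethod
    def get_default_hpo_config() -> dict:
        """Default hyperparameters; every subclass must provide them."""

    @abstractmethod
    def init(self, rng: Any) -> Any:
        """Build the initial algorithm state."""

    def update_hpo_config(self, hpo_config: dict) -> None:
        # Concrete functionality shared by all subclasses lives on the base class.
        self.hpo_config = hpo_config


class RandomAgent(AlgorithmSketch):
    """Hypothetical subclass that implements every abstract method."""

    @staticmethod
    def get_default_hpo_config() -> dict:
        return {"learning_rate": 3e-4}

    def init(self, rng: Any) -> dict:
        return {"step": 0}


# The abstract static method can be called on the class before instantiation.
agent = RandomAgent(RandomAgent.get_default_hpo_config(), {})
```

Instantiating AlgorithmSketch directly raises TypeError because its abstract methods are unimplemented; a subclass becomes instantiable only once it overrides all of them.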

property action_type: tuple[int, bool]

The size and type of actions of the algorithm/environment.

Returns:

Tuple of (action_size, discrete), where action_size is the number of possible actions and discrete indicates whether the action space is discrete.

Return type:

tuple[int, bool]
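A minimal sketch of how an (action_size, discrete) pair could be derived from an action space. Discrete and Box here are hypothetical stand-ins that only borrow the attribute names n and shape used by gymnasium-style spaces; treating action_size as the length of the action vector for continuous spaces is an assumption, since the description above only defines it for the discrete case.

```python
from dataclasses import dataclass


@dataclass
class Discrete:  # hypothetical stand-in for a discrete action space
    n: int


@dataclass
class Box:  # hypothetical stand-in for a continuous action space
    shape: tuple


def action_type(space) -> tuple:
    """Return (action_size, discrete) for the given action space."""
    if isinstance(space, Discrete):
        # Discrete: action_size counts the possible actions.
        return space.n, True
    if isinstance(space, Box):
        # Continuous (assumption): action_size is the action vector length.
        return space.shape[0], False
    raise ValueError(f"Unsupported action space: {space!r}")


print(action_type(Discrete(4)))  # (4, True)
print(action_type(Box((6,))))    # (6, False)
```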

eval(runner_state, num_eval_episodes)

Evaluate the algorithm.

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • num_eval_episodes (int) – Number of evaluation episodes.

Returns:

Cumulative reward per evaluation episode. Shape: (num_eval_episodes,).

Return type:

jnp.ndarray
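The return value has one entry per episode. One way to compute it, assuming rollouts are stored as padded arrays of shape (num_eval_episodes, max_steps) with a done flag marking each episode's terminal step, is to sum each row up to and including its first done flag. This sketch uses NumPy; the same operations exist in jax.numpy, and the array layout is an assumption rather than arlbench's internal format.

```python
import numpy as np


def cumulative_rewards(rewards: np.ndarray, dones: np.ndarray) -> np.ndarray:
    """Sum each episode's rewards up to and including its first terminal step.

    rewards, dones: arrays of shape (num_eval_episodes, max_steps).
    Returns an array of shape (num_eval_episodes,).
    """
    d = dones.astype(np.int32)
    # Number of terminal flags seen strictly before each step.
    done_before = np.cumsum(d, axis=1) - d
    # A step still belongs to the episode if no terminal flag preceded it.
    alive = done_before == 0
    return np.sum(rewards * alive, axis=1)


rewards = np.array([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]])
dones = np.array([[0, 1, 0, 0], [0, 0, 0, 1]])
print(cumulative_rewards(rewards, dones))  # [2. 8.]
```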

abstract static get_checkpoint_factory(runner_state, train_result)

Creates a dictionary of factory functions for all possible checkpointing options of the algorithm.

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • train_result (Any) – Training result.

Returns:

Dictionary of factory functions.

Return type:

dict[str, Callable]
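A dictionary of factory functions lets a caller pick which parts of the state to checkpoint without materializing the rest. The sketch below assumes a hypothetical RunnerState with params and opt_state fields; the option names are illustrative and are not the keys arlbench actually uses.

```python
from typing import Any, Callable, Dict, List, NamedTuple


class RunnerState(NamedTuple):  # hypothetical runner state
    params: dict
    opt_state: dict


def get_checkpoint_factory(runner_state: RunnerState, train_result: Any) -> Dict[str, Callable[[], Any]]:
    # Each entry lazily produces one checkpointable piece of state.
    return {
        "params": lambda: runner_state.params,
        "opt_state": lambda: runner_state.opt_state,
        "train_result": lambda: train_result,
    }


def collect_checkpoint(factory: Dict[str, Callable[[], Any]], options: List[str]) -> Dict[str, Any]:
    # Only the requested options are evaluated.
    return {name: factory[name]() for name in options}


state = RunnerState(params={"w": 1.0}, opt_state={"step": 3})
factory = get_checkpoint_factory(state, {"loss": 0.5})
print(collect_checkpoint(factory, ["params"]))  # {'params': {'w': 1.0}}
```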

abstract static get_default_hpo_config()

Returns the default hyperparameter configuration of the agent.

Returns:

Default hyperparameter configuration.

Return type:

Configuration

abstract static get_default_nas_config()

Returns the default neural architecture configuration of the agent.

Returns:

Default neural architecture configuration.

Return type:

Configuration

abstract static get_hpo_config_space(seed=None)

Returns the hyperparameter configuration space of the algorithm.

Parameters:

seed (int | None, optional) – Random generator seed that is used to sample configurations. Defaults to None.

Returns:

Hyperparameter configuration space of the algorithm.

Return type:

ConfigurationSpace

abstract static get_nas_config_space(seed=None)

Returns the neural architecture configuration space of the algorithm.

Parameters:

seed (int | None, optional) – Random generator seed that is used to sample configurations. Defaults to None.

Returns:

Neural architecture configuration space of the algorithm.

Return type:

ConfigurationSpace

abstract init(rng)

Initializes the algorithm state. Any parameters passed in are not re-initialized; they are included in the returned state as-is.

Parameters:

rng (chex.PRNGKey) – Random generator key.

Returns:

Algorithm state.

Return type:

Any
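The "passed parameters are kept" behaviour can be sketched as: initialize a component only when the caller did not supply it. This sketch uses a NumPy Generator in place of a chex.PRNGKey, and the network_params argument and state layout are hypothetical.

```python
from typing import Optional

import numpy as np


def init(rng: np.random.Generator, network_params: Optional[np.ndarray] = None) -> dict:
    """Build a hypothetical algorithm state.

    network_params is a hypothetical extra argument: if it is passed, it is
    placed in the state unchanged instead of being freshly initialized.
    """
    if network_params is None:
        # Only missing parameters are drawn from the random generator.
        network_params = rng.normal(size=(4, 2))
    return {"network_params": network_params, "step": 0}


given = np.ones((4, 2))
state = init(np.random.default_rng(0), network_params=given)
print(state["network_params"] is given)  # True
```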

abstract predict(runner_state, obs, rng=None, deterministic=True)

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None, optional) – Random generator key. Defaults to None.

  • deterministic (bool, optional) – Whether to return a deterministic action. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray
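For a discrete policy, the deterministic flag typically switches between greedy selection and sampling, and only the sampling path needs a random key. This sketch assumes the policy output is a batch of logits and uses a NumPy Generator in place of a chex.PRNGKey; it illustrates the idea rather than arlbench's predict.

```python
from typing import Optional

import numpy as np


def predict(logits: np.ndarray, rng: Optional[np.random.Generator] = None, deterministic: bool = True) -> np.ndarray:
    """Select one action per row of logits (shape: (batch, n_actions))."""
    if deterministic:
        # Greedy: pick the highest-scoring action.
        return np.argmax(logits, axis=-1)
    if rng is None:
        raise ValueError("A random generator is required for stochastic actions.")
    # Stochastic: sample from the softmax distribution over actions.
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(len(p), p=p) for p in probs])


logits = np.array([[0.1, 2.0, -1.0], [3.0, 0.0, 0.0]])
print(predict(logits))  # [1 0]
```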

abstract train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)

Performs one iteration of training.

Parameters:
  • runner_state (Any) – Algorithm runner state.

  • buffer_state (Any) – Algorithm buffer state.

  • n_total_timesteps (int, optional) – Total number of training timesteps. The number of update steps is n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

(algorithm_state, training_result).

Return type:

tuple[Any, Any]
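The timestep arithmetic for train can be checked with a small helper. The first line follows the documented rule (update steps = n_total_timesteps // n_envs); splitting those steps evenly between the n_eval_steps evaluations is an assumption made for illustration, and n_envs = 8 is a hypothetical value.

```python
def update_schedule(n_total_timesteps: int, n_envs: int, n_eval_steps: int) -> tuple:
    """Return (update_steps, update_steps_per_eval)."""
    update_steps = n_total_timesteps // n_envs  # documented rule
    # Assumption: update steps are divided evenly between evaluations.
    return update_steps, update_steps // n_eval_steps


# Defaults of train() with a hypothetical n_envs = 8.
print(update_schedule(1_000_000, 8, 100))  # (125000, 1250)
```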

update_hpo_config(hpo_config)

Update the hyperparameter configuration of the algorithm.

Parameters:

hpo_config (Configuration) – Hyperparameter configuration.