arlbench.core.algorithms.dqn package

Submodules

arlbench.core.algorithms.dqn.dqn module

DQN algorithm.

class arlbench.core.algorithms.dqn.dqn.DQN(hpo_config, env, eval_env=None, deterministic_eval=True, eval_eps=0.05, cnn_policy=False, nas_config=None, track_trajectories=False, track_metrics=False)[source]

Bases: Algorithm

JAX-based implementation of Deep Q-Network (DQN).

static get_checkpoint_factory(runner_state, train_result)[source]

Creates a factory dictionary of all possible checkpointing options for DQN.

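Parameters:
  • runner_state (DQNRunnerState) – Algorithm runner state.

  • train_result (DQNTrainingResult) – Training result.

Returns: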

Dictionary of factory functions containing [opt_state, params, target_params, loss, trajectories].

Return type:

dict[str, Callable]
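
The factory-dictionary pattern described here can be sketched as follows. This is a minimal illustration in plain Python, not ARLBench's implementation: make_checkpoint_factory and the dictionary stand-ins for the runner state and training result are hypothetical. The point is that each entry is a zero-argument callable, so a value is only extracted when that checkpoint option is actually requested.

```python
from typing import Any, Callable


def make_checkpoint_factory(
    runner_state: dict[str, Any], train_result: dict[str, Any]
) -> dict[str, Callable[[], Any]]:
    """Hypothetical stand-in: map checkpoint option names to lazy getters."""
    return {
        "opt_state": lambda: runner_state["opt_state"],
        "params": lambda: runner_state["params"],
        "target_params": lambda: runner_state["target_params"],
        "loss": lambda: train_result["loss"],
        "trajectories": lambda: train_result["trajectories"],
    }


# Toy values standing in for real optimizer state, parameters, and results.
runner_state = {"opt_state": "opt", "params": [0.1, 0.2], "target_params": [0.1, 0.2]}
train_result = {"loss": [0.5, 0.4], "trajectories": None}

factory = make_checkpoint_factory(runner_state, train_result)

# Only the requested options are materialized.
requested = ["params", "loss"]
checkpoint = {key: factory[key]() for key in requested}
```

Keeping the getters lazy means unrequested entries (for example, large trajectories) are never touched.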

static get_default_hpo_config()[source]

Returns the default hyperparameter configuration for DQN.

Return type:

Configuration

static get_default_nas_config()[source]

Returns the default NAS configuration for DQN.

Return type:

Configuration

static get_hpo_config_space(seed=None)[source]

Returns the hyperparameter optimization (HPO) configuration space for DQN.

Return type:

ConfigurationSpace

static get_nas_config_space(seed=None)[source]

Returns the neural architecture search (NAS) configuration space for DQN.

Return type:

ConfigurationSpace

init(rng, buffer_state=None, network_params=None, target_params=None, opt_state=None)[source]

Initializes DQN state. Any components passed as arguments are used as given instead of being freshly initialized, and are included in the final state.

Parameters:
  • rng (chex.PRNGKey) – Random generator key.

  • buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None.

  • network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.

  • target_params (FrozenDict | dict | None, optional) – Target network parameters. Defaults to None.

  • opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.

Returns:

DQN state.

Return type:

DQNState
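
The "use what was passed, initialize the rest" behaviour of init can be illustrated with a small plain-Python sketch. The names init_state and fresh_params are hypothetical, and the rule that a fresh target network starts as a copy of the online network is an assumption of this sketch, not something the documentation above states.

```python
from typing import Optional


def fresh_params(seed: int) -> dict:
    """Hypothetical initializer standing in for real network initialization."""
    return {"w": [float(seed)] * 2}


def init_state(
    seed: int,
    network_params: Optional[dict] = None,
    target_params: Optional[dict] = None,
) -> dict:
    # Components that were passed in are kept as given.
    # Only the missing ones (None) are freshly initialized.
    if network_params is None:
        network_params = fresh_params(seed)
    if target_params is None:
        # Assumption for this sketch: the target starts as a copy of the online network.
        target_params = dict(network_params)
    return {"params": network_params, "target_params": target_params}


loaded = {"w": [1.5, -0.5]}  # e.g. restored from a checkpoint
state = init_state(seed=0, network_params=loaded)
```

This is the typical way to resume from a checkpoint: restored components are handed to the initializer, and everything else is created from the random key.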

name: str = 'dqn'

predict(runner_state, obs, rng, deterministic=True)[source]

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (DQNRunnerState) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None) – Random generator key. Not used in DQN; kept for interface consistency with other algorithms.

  • deterministic (bool) – Return deterministic action. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray

train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]

Performs one full training run.

Parameters:
  • runner_state (DQNRunnerState) – DQN runner state.

  • buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.

  • n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

Tuple of DQN algorithm state and training result.

Return type:

DQNTrainReturnT
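
As a worked example of the relation "update steps = n_total_timesteps // n_envs", the arithmetic below uses the default timestep budget and a hypothetical n_envs of 8 (the number of parallel environments is not given in this reference).

```python
n_total_timesteps = 1_000_000  # default from the train() signature
n_envs = 8                     # hypothetical number of parallel environments

# Number of update steps according to the documented formula.
n_update_steps = n_total_timesteps // n_envs

# Floor division drops any remainder, so a budget that is not a multiple
# of n_envs yields the same step count as the next lower multiple.
n_total_odd = 1_000_003
steps_odd = n_total_odd // n_envs
covered_odd = steps_odd * n_envs
```

With these numbers, n_update_steps is 125000, and the budget of 1000003 timesteps is covered only up to 1000000.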

update(train_state, observations, is_weights, actions, next_observations, rewards, dones)[source]

Update the Q-network.

Parameters:
  • train_state (DQNTrainState) – DQN training state.

  • observations (jnp.ndarray) – Batch of observations.

  • is_weights (jnp.ndarray) – Batch of importance sampling weights.

  • actions (jnp.ndarray) – Batch of actions.

  • next_observations (jnp.ndarray) – Batch of next observations.

  • rewards (jnp.ndarray) – Batch of rewards.

  • dones (jnp.ndarray) – Batch of done flags.

Returns:

Tuple of (train_state, loss, td_error, grads).

Return type:

tuple[DQNTrainState, jnp.ndarray, jnp.ndarray, jnp.ndarray]
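
To make the roles of is_weights, rewards, and dones concrete, here is a minimal plain-Python sketch of an importance-weighted TD loss in the textbook DQN form. It is not ARLBench's implementation: dqn_loss is a hypothetical helper, it takes precomputed Q-values instead of observations, and the discount factor gamma and the squared-error loss are assumptions (the library may use a different loss or target).

```python
def dqn_loss(q_values, next_target_q_values, is_weights, actions, rewards, dones, gamma=0.99):
    """Importance-weighted squared TD loss for one batch (illustrative only).

    q_values[i][a]             -- online-network Q-value for sample i, action a
    next_target_q_values[i][a] -- target-network Q-value at the next observation
    """
    td_errors = []
    for q, next_q, action, reward, done in zip(
        q_values, next_target_q_values, actions, rewards, dones
    ):
        # Bootstrapped target; no bootstrapping past a terminal transition.
        target = reward + gamma * (1.0 - done) * max(next_q)
        td_errors.append(q[action] - target)

    # Importance-sampling weights rescale each sample's contribution.
    weighted = [w * e**2 for w, e in zip(is_weights, td_errors)]
    loss = sum(weighted) / len(weighted)
    return loss, td_errors


loss, td_errors = dqn_loss(
    q_values=[[1.0, 2.0], [0.5, 0.0]],
    next_target_q_values=[[0.0, 1.0], [3.0, 2.0]],
    is_weights=[1.0, 0.5],
    actions=[1, 0],
    rewards=[1.0, 0.0],
    dones=[0.0, 1.0],
    gamma=0.5,
)
```

In the second sample the done flag zeroes out the bootstrap term, so its target is just the reward. The per-sample TD errors are also what a prioritised buffer would typically use to update priorities.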

class arlbench.core.algorithms.dqn.dqn.DQNMetrics(loss: jnp.ndarray, grads: jnp.ndarray | tuple, td_error: jnp.ndarray)[source]

Bases: NamedTuple

DQN metrics returned by train function. Consists of (loss, grads, td_error).

grads: Array | tuple

Alias for field number 1

loss: Array

Alias for field number 0

td_error: Array

Alias for field number 2
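
The "Alias for field number N" entries mean that each named field is also reachable by its position in the tuple. The sketch below shows this with a hypothetical Metrics class that copies the field order of DQNMetrics but uses plain Python values instead of arrays.

```python
from typing import NamedTuple


class Metrics(NamedTuple):
    """Hypothetical stand-in with the same field order as DQNMetrics."""
    loss: float
    grads: tuple
    td_error: float


m = Metrics(loss=0.25, grads=(0.1, -0.2), td_error=0.5)

# Access by name and by position are equivalent.
by_name = m.grads
by_index = m[1]

# Positional unpacking follows the declared field order.
loss, grads, td_error = m
```

The same applies to the other NamedTuple classes in this module, so the field order matters whenever a state or result is unpacked positionally.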

class arlbench.core.algorithms.dqn.dqn.DQNRunnerState(rng: chex.PRNGKey, train_state: DQNTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: jnp.ndarray, global_step: int)[source]

Bases: NamedTuple

DQN runner state. Consists of (rng, train_state, normalizer_state, env_state, obs, global_step).

env_state: Any

Alias for field number 3

global_step: int

Alias for field number 5

normalizer_state: RunningStatisticsState

Alias for field number 2

obs: Array

Alias for field number 4

rng: Array

Alias for field number 0

train_state: DQNTrainState

Alias for field number 1

class arlbench.core.algorithms.dqn.dqn.DQNState(runner_state: DQNRunnerState, buffer_state: PrioritisedTrajectoryBufferState)[source]

Bases: NamedTuple

DQN algorithm state. Consists of (runner_state, buffer_state).

buffer_state: PrioritisedTrajectoryBufferState

Alias for field number 1

runner_state: DQNRunnerState

Alias for field number 0

class arlbench.core.algorithms.dqn.dqn.DQNTrainState(step, apply_fn, params, tx, opt_state, target_params=None)[source]

Bases: TrainState

DQN training state.

classmethod create_with_opt_state(*, apply_fn, params, target_params, tx, opt_state, **kwargs)[source]

Creates a DQN training state with the given optimizer state.

opt_state: optax.OptState

replace(**updates)

Returns a new object replacing the specified fields with new values.
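
The copy-on-update semantics of replace can be sketched with a frozen standard-library dataclass. TrainState here is a hypothetical stand-in, not the Flax class, and the target-network sync shown at the end is only an example of how such a method is commonly used.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class TrainState:
    """Hypothetical immutable stand-in for a training state."""
    step: int
    params: tuple
    target_params: Optional[tuple] = None

    def replace(self, **updates) -> "TrainState":
        # Build a new instance; the original is left untouched.
        return dataclasses.replace(self, **updates)


old = TrainState(step=10, params=(0.3, 0.7), target_params=(0.0, 0.0))

# Example use: copy the online parameters into the target network.
new = old.replace(target_params=old.params)
```

Because the state is immutable, the old object remains valid after the call, which fits JAX's functional style of passing state through pure functions.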

target_params: None | chex.Array | dict | FrozenDict = None

class arlbench.core.algorithms.dqn.dqn.DQNTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: DQNMetrics | None)[source]

Bases: NamedTuple

DQN training result. Consists of (eval_rewards, trajectories, metrics).

eval_rewards: Array

Alias for field number 0

metrics: DQNMetrics | None

Alias for field number 2

trajectories: Transition | None

Alias for field number 1

class arlbench.core.algorithms.dqn.dqn.Transition(done: jnp.ndarray, action: jnp.ndarray, reward: jnp.ndarray, obs: jnp.ndarray, info: dict)[source]

Bases: NamedTuple

DQN Transition. Consists of (done, action, reward, obs, info).

action: Array

Alias for field number 1

done: Array

Alias for field number 0

info: dict

Alias for field number 4

obs: Array

Alias for field number 3

reward: Array

Alias for field number 2

arlbench.core.algorithms.dqn.models module

Q-Networks for DQN.

class arlbench.core.algorithms.dqn.models.CNNQ(action_dim, activation='tanh', hidden_size=512, discrete=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: Module

A CNN-based Q-Network for DQN.

__call__(x)[source]

Applies the CNN to the input.

action_dim: int
activation: str = 'tanh'
discrete: bool = True
hidden_size: int = 512
name: Optional[str] = None
parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
scope: Optional[Scope] = None

setup()[source]

Initializes the CNN Q-Network.

class arlbench.core.algorithms.dqn.models.MLPQ(action_dim, activation='tanh', hidden_size=64, discrete=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: Module

An MLP-based Q-Network for DQN.

__call__(x)[source]

Applies the MLP to the input.

action_dim: int
activation: str = 'tanh'
discrete: bool = True
hidden_size: int = 64
name: Optional[str] = None
parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
scope: Optional[Scope] = None

setup()[source]

Initializes the MLP Q-Network.

Module contents

class arlbench.core.algorithms.dqn.DQN(hpo_config, env, eval_env=None, deterministic_eval=True, eval_eps=0.05, cnn_policy=False, nas_config=None, track_trajectories=False, track_metrics=False)[source]

Bases: Algorithm

JAX-based implementation of Deep Q-Network (DQN).

static get_checkpoint_factory(runner_state, train_result)[source]

Creates a factory dictionary of all possible checkpointing options for DQN.

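Parameters:
  • runner_state (DQNRunnerState) – Algorithm runner state.

  • train_result (DQNTrainingResult) – Training result.

Returns: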

Dictionary of factory functions containing [opt_state, params, target_params, loss, trajectories].

Return type:

dict[str, Callable]

static get_default_hpo_config()[source]

Returns the default hyperparameter configuration for DQN.

Return type:

Configuration

static get_default_nas_config()[source]

Returns the default NAS configuration for DQN.

Return type:

Configuration

static get_hpo_config_space(seed=None)[source]

Returns the hyperparameter optimization (HPO) configuration space for DQN.

Return type:

ConfigurationSpace

static get_nas_config_space(seed=None)[source]

Returns the neural architecture search (NAS) configuration space for DQN.

Return type:

ConfigurationSpace

init(rng, buffer_state=None, network_params=None, target_params=None, opt_state=None)[source]

Initializes DQN state. Any components passed as arguments are used as given instead of being freshly initialized, and are included in the final state.

Parameters:
  • rng (chex.PRNGKey) – Random generator key.

  • buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None.

  • network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.

  • target_params (FrozenDict | dict | None, optional) – Target network parameters. Defaults to None.

  • opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.

Returns:

DQN state.

Return type:

DQNState

name: str = 'dqn'

predict(runner_state, obs, rng, deterministic=True)[source]

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (DQNRunnerState) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None) – Random generator key. Not used in DQN; kept for interface consistency with other algorithms.

  • deterministic (bool) – Return deterministic action. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray

train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]

Performs one full training run.

Parameters:
  • runner_state (DQNRunnerState) – DQN runner state.

  • buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.

  • n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

Tuple of DQN algorithm state and training result.

Return type:

DQNTrainReturnT

update(train_state, observations, is_weights, actions, next_observations, rewards, dones)[source]

Update the Q-network.

Parameters:
  • train_state (DQNTrainState) – DQN training state.

  • observations (jnp.ndarray) – Batch of observations.

  • is_weights (jnp.ndarray) – Batch of importance sampling weights.

  • actions (jnp.ndarray) – Batch of actions.

  • next_observations (jnp.ndarray) – Batch of next observations.

  • rewards (jnp.ndarray) – Batch of rewards.

  • dones (jnp.ndarray) – Batch of done flags.

Returns:

Tuple of (train_state, loss, td_error, grads).

Return type:

tuple[DQNTrainState, jnp.ndarray, jnp.ndarray, jnp.ndarray]

class arlbench.core.algorithms.dqn.DQNMetrics(loss: jnp.ndarray, grads: jnp.ndarray | tuple, td_error: jnp.ndarray)[source]

Bases: NamedTuple

DQN metrics returned by train function. Consists of (loss, grads, td_error).

grads: Array | tuple

Alias for field number 1

loss: Array

Alias for field number 0

td_error: Array

Alias for field number 2

class arlbench.core.algorithms.dqn.DQNRunnerState(rng: chex.PRNGKey, train_state: DQNTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: jnp.ndarray, global_step: int)[source]

Bases: NamedTuple

DQN runner state. Consists of (rng, train_state, normalizer_state, env_state, obs, global_step).

env_state: Any

Alias for field number 3

global_step: int

Alias for field number 5

normalizer_state: RunningStatisticsState

Alias for field number 2

obs: Array

Alias for field number 4

rng: Array

Alias for field number 0

train_state: DQNTrainState

Alias for field number 1

class arlbench.core.algorithms.dqn.DQNState(runner_state: DQNRunnerState, buffer_state: PrioritisedTrajectoryBufferState)[source]

Bases: NamedTuple

DQN algorithm state. Consists of (runner_state, buffer_state).

buffer_state: PrioritisedTrajectoryBufferState

Alias for field number 1

runner_state: DQNRunnerState

Alias for field number 0

arlbench.core.algorithms.dqn.DQNTrainReturnT

alias of tuple[DQNState, DQNTrainingResult]

class arlbench.core.algorithms.dqn.DQNTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: DQNMetrics | None)[source]

Bases: NamedTuple

DQN training result. Consists of (eval_rewards, trajectories, metrics).

eval_rewards: Array

Alias for field number 0

metrics: DQNMetrics | None

Alias for field number 2

trajectories: Transition | None

Alias for field number 1