arlbench.core.algorithms.dqn package¶
Submodules¶
arlbench.core.algorithms.dqn.dqn module¶
DQN algorithm.
- class arlbench.core.algorithms.dqn.dqn.DQN(hpo_config, env, eval_env=None, deterministic_eval=True, eval_eps=0.05, cnn_policy=False, nas_config=None, track_trajectories=False, track_metrics=False)[source]¶
Bases:
Algorithm
JAX-based implementation of Deep Q-Network (DQN).
- static get_checkpoint_factory(runner_state, train_result)[source]¶
Creates a factory dictionary of all possible checkpointing options for DQN.
- Parameters:
runner_state (DQNRunnerState) – Algorithm runner state.
train_result (DQNTrainingResult | None) – Training result.
- Returns:
Dictionary of factory functions containing [opt_state, params, target_params, loss, trajectories].
- Return type:
dict[str, Callable]
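A minimal, hedged sketch of building the checkpointing factory; `algorithm_state` and `result` are assumed to come from a completed call to train() (see the training example further below).

    # Build the dictionary of checkpointing callables from the final runner state
    # and training result.
    factory = DQN.get_checkpoint_factory(algorithm_state.runner_state, result)
    print(list(factory))  # e.g. ['opt_state', 'params', 'target_params', 'loss', 'trajectories']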
- static get_default_hpo_config()[source]¶
Returns the default hyperparameter configuration for DQN.
- Return type:
Configuration
- static get_default_nas_config()[source]¶
Returns the default NAS configuration for DQN.
- Return type:
Configuration
- static get_hpo_config_space(seed=None)[source]¶
Returns the hyperparameter optimization (HPO) configuration space for DQN.
- Return type:
ConfigurationSpace
- static get_nas_config_space(seed=None)[source]¶
Returns the neural architecture search (NAS) configuration space for DQN.
- Return type:
ConfigurationSpace
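A short sketch of querying the search spaces and default configurations, using only the static methods documented above:

    from arlbench.core.algorithms.dqn import DQN

    # Default hyperparameter and architecture configurations.
    hpo_config = DQN.get_default_hpo_config()
    nas_config = DQN.get_default_nas_config()

    # Seedable search spaces, e.g. for sampling random configurations.
    hpo_space = DQN.get_hpo_config_space(seed=42)
    print(hpo_space.sample_configuration())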
- init(rng, buffer_state=None, network_params=None, target_params=None, opt_state=None)[source]¶
Initializes the DQN state. Parameters that are passed in are not (re-)initialized; they are included in the final state as provided.
- Parameters:
rng (chex.PRNGKey) – Random generator key.
buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None.
network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.
target_params (FrozenDict | dict | None, optional) – Target network parameters. Defaults to None.
opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.
- Returns:
DQN state.
- Return type:
DQNState
- name: str = 'dqn'¶
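A minimal initialization sketch; `env` is a placeholder for an ARLBench-wrapped environment created elsewhere and is an assumption of this example.

    import jax

    from arlbench.core.algorithms.dqn import DQN

    # `env` is assumed to be an ARLBench environment wrapper instantiated elsewhere.
    hpo_config = DQN.get_default_hpo_config()
    agent = DQN(hpo_config, env)

    rng = jax.random.PRNGKey(0)
    algorithm_state = agent.init(rng)             # DQNState
    runner_state, buffer_state = algorithm_state  # NamedTuple unpacking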
- predict(runner_state, obs, rng, deterministic=True)[source]¶
Predict action(s) based on the current observation(s).
- Parameters:
runner_state (DQNRunnerState) – Algorithm runner state.
obs (jnp.ndarray) – Observation(s).
rng (chex.PRNGKey | None, optional) – Random generator key; not used by DQN but required by other algorithms. Defaults to None.
deterministic (bool) – Return deterministic action. Defaults to True.
- Returns:
Action(s).
- Return type:
jnp.ndarray
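A small follow-up to the initialization sketch above: greedy action selection, using the observation(s) stored in the runner state purely for illustration.

    # Greedy Q-value action(s) for the current observation(s).
    action = agent.predict(runner_state, runner_state.obs, rng, deterministic=True)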
- train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]¶
Performs one full training run.
- Parameters:
runner_state (DQNRunnerState) – DQN runner state.
buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.
n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.
n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
- Returns:
Tuple of DQN algorithm state and training result.
- Return type:
DQNTrainReturnT
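Continuing the sketch above, a hedged example of a single training call; the timestep and evaluation counts are arbitrary illustration values.

    algorithm_state, result = agent.train(
        runner_state,
        buffer_state,
        n_total_timesteps=100_000,
        n_eval_steps=10,
        n_eval_episodes=5,
    )
    print(result.eval_rewards.mean(axis=-1))  # mean evaluation return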
- update(train_state, observations, is_weights, actions, next_observations, rewards, dones)[source]¶
Update the Q-network.
- Parameters:
train_state (DQNTrainState) – DQN training state.
observations (jnp.ndarray) – Batch of observations.
is_weights (jnp.ndarray) – Batch of importance sampling weights.
actions (jnp.ndarray) – Batch of actions.
next_observations (jnp.ndarray) – Batch of next observations.
rewards (jnp.ndarray) – Batch of rewards.
dones (jnp.ndarray) – Batch of dones.
- Returns:
Tuple of (train_state, loss, td_error, grads).
- Return type:
tuple[DQNTrainState, jnp.ndarray, jnp.ndarray, jnp.ndarray]
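A low-level, hypothetical sketch of a single Q-network update on a dummy batch; the batch size and observation dimensionality are made up and must match the environment the agent was initialized for.

    import jax.numpy as jnp

    batch_size, obs_dim = 32, 4  # hypothetical values
    train_state = algorithm_state.runner_state.train_state
    train_state, loss, td_error, grads = agent.update(
        train_state,
        observations=jnp.zeros((batch_size, obs_dim)),
        is_weights=jnp.ones((batch_size,)),  # importance sampling weights
        actions=jnp.zeros((batch_size,), dtype=jnp.int32),
        next_observations=jnp.zeros((batch_size, obs_dim)),
        rewards=jnp.zeros((batch_size,)),
        dones=jnp.zeros((batch_size,)),
    )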
- class arlbench.core.algorithms.dqn.dqn.DQNMetrics(loss: jnp.ndarray, grads: jnp.ndarray | tuple, td_error: jnp.ndarray)[source]¶
Bases:
NamedTuple
DQN metrics returned by train function. Consists of (loss, grads, td_error).
- grads: Array | tuple¶
Alias for field number 1
- loss: Array¶
Alias for field number 0
- td_error: Array¶
Alias for field number 2
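The metrics field of a training result is None unless metrics tracking is enabled (presumably via the track_metrics flag of DQN); a short sketch of reading it:

    # `result` is a DQNTrainingResult as returned by train().
    if result.metrics is not None:
        loss, grads, td_error = result.metrics  # NamedTuple field order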
- class arlbench.core.algorithms.dqn.dqn.DQNRunnerState(rng: chex.PRNGKey, train_state: DQNTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: jnp.ndarray, global_step: int)[source]¶
Bases:
NamedTuple
DQN runner state. Consists of (rng, train_state, normalizer_state, env_state, obs, global_step).
- env_state: Any¶
Alias for field number 3
- global_step: int¶
Alias for field number 5
- normalizer_state: RunningStatisticsState¶
Alias for field number 2
- obs: Array¶
Alias for field number 4
- rng: Array¶
Alias for field number 0
- train_state: DQNTrainState¶
Alias for field number 1
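A short sketch of inspecting the runner state after training, e.g. the global step counter:

    runner_state = algorithm_state.runner_state
    print(runner_state.global_step)  # global step counter of the run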
- class arlbench.core.algorithms.dqn.dqn.DQNState(runner_state: DQNRunnerState, buffer_state: PrioritisedTrajectoryBufferState)[source]¶
Bases:
NamedTuple
DQN algorithm state. Consists of (runner_state, buffer_state).
- buffer_state: PrioritisedTrajectoryBufferState¶
Alias for field number 1
- runner_state: DQNRunnerState¶
Alias for field number 0
- class arlbench.core.algorithms.dqn.dqn.DQNTrainState(step, apply_fn, params, tx, opt_state, target_params=None)[source]¶
Bases:
TrainState
DQN training state.
- classmethod create_with_opt_state(*, apply_fn, params, target_params, tx, opt_state, **kwargs)[source]¶
Creates a DQN training state with the given optimizer state.
- opt_state: optax.OptState¶
- replace(**updates)¶
Returns a new object replacing the specified fields with new values.
- target_params: None | chex.Array | dict | FrozenDict = None¶
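A hedged sketch of using replace() to hard-sync the target network with the online parameters, a common DQN operation; the schedule actually used inside train() may differ.

    train_state = train_state.replace(target_params=train_state.params)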
- class arlbench.core.algorithms.dqn.dqn.DQNTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: DQNMetrics | None)[source]¶
Bases:
NamedTuple
DQN training result. Consists of (eval_rewards, trajectories, metrics).
- eval_rewards: Array¶
Alias for field number 0
- metrics: DQNMetrics | None¶
Alias for field number 2
- trajectories: Transition | None¶
Alias for field number 1
- class arlbench.core.algorithms.dqn.dqn.Transition(done: jnp.ndarray, action: jnp.ndarray, reward: jnp.ndarray, obs: jnp.ndarray, info: dict)[source]¶
Bases:
NamedTuple
DQN Transition. Consists of (done, action, reward, obs, info).
- action: Array¶
Alias for field number 1
- done: Array¶
Alias for field number 0
- info: dict¶
Alias for field number 4
- obs: Array¶
Alias for field number 3
- reward: Array¶
Alias for field number 2
arlbench.core.algorithms.dqn.models module¶
Q-Networks for DQN.
- class arlbench.core.algorithms.dqn.models.CNNQ(action_dim, activation='tanh', hidden_size=512, discrete=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
Bases:
Module
A CNN-based Q-Network for DQN.
- action_dim: int¶
- activation: str = 'tanh'¶
- discrete: bool = True¶
- hidden_size: int = 512¶
- name: Optional[str] = None¶
- parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None¶
- scope: Optional[Scope] = None¶
- class arlbench.core.algorithms.dqn.models.MLPQ(action_dim, activation='tanh', hidden_size=64, discrete=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
Bases:
Module
An MLP-based Q-Network for DQN.
- action_dim: int¶
- activation: str = 'tanh'¶
- discrete: bool = True¶
- hidden_size: int = 64¶
- name: Optional[str] = None¶
- parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None¶
- scope: Optional[Scope] = None¶
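A standalone sketch of instantiating the MLP Q-network via the usual flax.linen init/apply workflow; the observation dimensionality (4) and action count (2) are hypothetical.

    import jax
    import jax.numpy as jnp

    from arlbench.core.algorithms.dqn.models import MLPQ

    q_network = MLPQ(action_dim=2, activation="relu", hidden_size=64)
    rng = jax.random.PRNGKey(0)
    dummy_obs = jnp.zeros((1, 4))                  # batch of one observation
    params = q_network.init(rng, dummy_obs)        # flax parameter initialization
    q_values = q_network.apply(params, dummy_obs)  # expected shape (1, action_dim)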
Module contents¶
- class arlbench.core.algorithms.dqn.DQN(hpo_config, env, eval_env=None, deterministic_eval=True, eval_eps=0.05, cnn_policy=False, nas_config=None, track_trajectories=False, track_metrics=False)[source]¶
Bases:
Algorithm
JAX-based implementation of Deep Q-Network (DQN).
- static get_checkpoint_factory(runner_state, train_result)[source]¶
Creates a factory dictionary of all possible checkpointing options for DQN.
- Parameters:
runner_state (DQNRunnerState) – Algorithm runner state.
train_result (DQNTrainingResult | None) – Training result.
- Returns:
Dictionary of factory functions containing [opt_state, params, target_params, loss, trajectories].
- Return type:
dict[str, Callable]
- static get_default_hpo_config()[source]¶
Returns the default hyperparameter configuration for DQN.
- Return type:
Configuration
- static get_default_nas_config()[source]¶
Returns the default NAS configuration for DQN.
- Return type:
Configuration
- static get_hpo_config_space(seed=None)[source]¶
Returns the hyperparameter optimization (HPO) configuration space for DQN.
- Return type:
ConfigurationSpace
- static get_nas_config_space(seed=None)[source]¶
Returns the neural architecture search (NAS) configuration space for DQN.
- Return type:
ConfigurationSpace
- init(rng, buffer_state=None, network_params=None, target_params=None, opt_state=None)[source]¶
Initializes the DQN state. Parameters that are passed in are not (re-)initialized; they are included in the final state as provided.
- Parameters:
rng (chex.PRNGKey) – Random generator key.
buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None.
network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.
target_params (FrozenDict | dict | None, optional) – Target network parameters. Defaults to None.
opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.
- Returns:
DQN state.
- Return type:
DQNState
- name: str = 'dqn'¶
- predict(runner_state, obs, rng, deterministic=True)[source]¶
Predict action(s) based on the current observation(s).
- Parameters:
runner_state (DQNRunnerState) – Algorithm runner state.
obs (jnp.ndarray) – Observation(s).
rng (chex.PRNGKey | None, optional) – Random generator key; not used by DQN but required by other algorithms. Defaults to None.
deterministic (bool) – Return deterministic action. Defaults to True.
- Returns:
Action(s).
- Return type:
jnp.ndarray
- train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]¶
Performs one full training run.
- Parameters:
runner_state (DQNRunnerState) – DQN runner state.
buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.
n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.
n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
- Returns:
Tuple of DQN algorithm state and training result.
- Return type:
DQNTrainReturnT
- update(train_state, observations, is_weights, actions, next_observations, rewards, dones)[source]¶
Update the Q-network.
- Parameters:
train_state (DQNTrainState) – DQN training state.
observations (jnp.ndarray) – Batch of observations.
is_weights (jnp.ndarray) – Batch of importance sampling weights.
actions (jnp.ndarray) – Batch of actions.
next_observations (jnp.ndarray) – Batch of next observations.
rewards (jnp.ndarray) – Batch of rewards.
dones (jnp.ndarray) – Batch of dones.
- Returns:
Tuple of (train_state, loss, td_error, grads).
- Return type:
tuple[DQNTrainState, jnp.ndarray, jnp.ndarray, jnp.ndarray]
- class arlbench.core.algorithms.dqn.DQNMetrics(loss: jnp.ndarray, grads: jnp.ndarray | tuple, td_error: jnp.ndarray)[source]¶
Bases:
NamedTuple
DQN metrics returned by train function. Consists of (loss, grads, td_error).
- grads: Array | tuple¶
Alias for field number 1
- loss: Array¶
Alias for field number 0
- td_error: Array¶
Alias for field number 2
- class arlbench.core.algorithms.dqn.DQNRunnerState(rng: chex.PRNGKey, train_state: DQNTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: jnp.ndarray, global_step: int)[source]¶
Bases:
NamedTuple
DQN runner state. Consists of (rng, train_state, normalizer_state, env_state, obs, global_step).
- env_state: Any¶
Alias for field number 3
- global_step: int¶
Alias for field number 5
- normalizer_state: RunningStatisticsState¶
Alias for field number 2
- obs: Array¶
Alias for field number 4
- rng: Array¶
Alias for field number 0
- train_state: DQNTrainState¶
Alias for field number 1
- class arlbench.core.algorithms.dqn.DQNState(runner_state: DQNRunnerState, buffer_state: PrioritisedTrajectoryBufferState)[source]¶
Bases:
NamedTuple
DQN algorithm state. Consists of (runner_state, buffer_state).
- buffer_state: PrioritisedTrajectoryBufferState¶
Alias for field number 1
- runner_state: DQNRunnerState¶
Alias for field number 0
- arlbench.core.algorithms.dqn.DQNTrainReturnT¶
alias of tuple[DQNState, DQNTrainingResult]
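Since DQNTrainReturnT is a plain tuple alias, results from train() unpack directly:

    # State first, training result second (see the train() example above).
    algorithm_state, result = agent.train(runner_state, buffer_state)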
- class arlbench.core.algorithms.dqn.DQNTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: DQNMetrics | None)[source]¶
Bases:
NamedTuple
DQN training result. Consists of (eval_rewards, trajectories, metrics).
- eval_rewards: Array¶
Alias for field number 0
- metrics: DQNMetrics | None¶
Alias for field number 2
- trajectories: Transition | None¶
Alias for field number 1