arlbench.core.algorithms.dqn.dqn
DQN algorithm.
Classes

| DQN | JAX-based implementation of Deep Q-Network (DQN). |
| DQNMetrics | DQN metrics returned by train function. |
| DQNRunnerState | DQN runner state. |
| DQNState | DQN algorithm state. |
| DQNTrainState | DQN training state. |
| DQNTrainingResult | DQN training result. |
| Transition | DQN Transition. |
class arlbench.core.algorithms.dqn.dqn.DQN(hpo_config, env, eval_env=None, deterministic_eval=True, eval_eps=0.05, cnn_policy=False, nas_config=None, track_trajectories=False, track_metrics=False)[source]

Bases: Algorithm

JAX-based implementation of Deep Q-Network (DQN).

static get_checkpoint_factory(runner_state, train_result)[source]

Creates a factory dictionary of all possible checkpointing options for DQN.

Parameters:
- runner_state (DQNRunnerState) – Algorithm runner state.
- train_result (DQNTrainingResult | None) – Training result.

Returns:
Dictionary of factory functions containing [opt_state, params, target_params, loss, trajectories].

Return type:
dict[str, Callable]
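A minimal sketch of using the checkpoint factory, assuming `runner_state` and `train_result` come from a completed `train()` call (see below) and that each factory entry is a zero-argument callable; both are assumptions for illustration:

```python
# Illustrative sketch only: `runner_state` and `train_result` are assumed to
# come from a previous DQN.train() call (see train() below).
from arlbench.core.algorithms.dqn.dqn import DQN

factory = DQN.get_checkpoint_factory(runner_state, train_result)
print(factory.keys())  # e.g. opt_state, params, target_params, loss, trajectories

# Assumption: each entry is a zero-argument callable producing the object to checkpoint.
params_to_save = factory["params"]()
```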
 
static get_default_hpo_config()[source]

Returns the default hyperparameter configuration for DQN.

Return type:
Configuration

static get_default_nas_config()[source]

Returns the default NAS configuration for DQN.

Return type:
Configuration

static get_hpo_config_space(seed=None)[source]

Returns the hyperparameter optimization (HPO) configuration space for DQN.

Return type:
ConfigurationSpace

static get_nas_config_space(seed=None)[source]

Returns the neural architecture search (NAS) configuration space for DQN.

Return type:
ConfigurationSpace
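A minimal sketch of sampling a hyperparameter configuration and constructing the algorithm, assuming the standard ConfigSpace API and an ARLBench-compatible environment `env` created elsewhere:

```python
from arlbench.core.algorithms.dqn.dqn import DQN

# Reproducible HPO configuration space and a random sample from it.
config_space = DQN.get_hpo_config_space(seed=42)
hpo_config = config_space.sample_configuration()

# Or simply start from the defaults.
default_hpo_config = DQN.get_default_hpo_config()

# `env` is assumed to be an ARLBench-compatible environment created elsewhere.
agent = DQN(hpo_config, env)
```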
 
init(rng, buffer_state=None, network_params=None, target_params=None, opt_state=None)[source]

Initializes the DQN state. Passed parameters are not re-initialized but are included in the final state as-is.

Parameters:
- rng (chex.PRNGKey) – Random generator key.
- buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None.
- network_params (FrozenDict | dict | None, optional) – Network parameters. Defaults to None.
- target_params (FrozenDict | dict | None, optional) – Target network parameters. Defaults to None.
- opt_state (optax.OptState | None, optional) – Optimizer state. Defaults to None.

Returns:
DQN state.

Return type:
DQNState
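A minimal sketch of initializing the algorithm state from a fresh PRNG key, continuing the example above (`agent` is the DQN instance):

```python
import jax

rng = jax.random.PRNGKey(0)
algorithm_state = agent.init(rng)            # DQNState
runner_state = algorithm_state.runner_state  # DQNRunnerState
buffer_state = algorithm_state.buffer_state  # PrioritisedTrajectoryBufferState
```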
 
predict(runner_state, obs, rng, deterministic=True)[source]

Predict action(s) based on the current observation(s).

Parameters:
- runner_state (DQNRunnerState) – Algorithm runner state.
- obs (jnp.ndarray) – Observation(s).
- rng (chex.PRNGKey | None, optional) – Not used in DQN; random generator key in other algorithms. Defaults to None.
- deterministic (bool) – Return deterministic action. Defaults to True.

Returns:
Action(s).

Return type:
jnp.ndarray
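A minimal sketch of selecting greedy actions for a batch of observations, continuing the example above; the observation shape is a placeholder and must match the environment the agent was built for:

```python
import jax
import jax.numpy as jnp

obs = jnp.zeros((1, 4))  # placeholder observation batch (shape is illustrative)
actions = agent.predict(runner_state, obs, rng=jax.random.PRNGKey(1), deterministic=True)
```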
 
train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]

Performs one full training run.

Parameters:
- runner_state (DQNRunnerState) – DQN runner state.
- buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.
- n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
- n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.
- n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:
Tuple of DQN algorithm state and training result.

Return type:
DQNTrainReturnT
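A minimal sketch of one full training run and reading the evaluation rewards, continuing from `init()` above:

```python
algorithm_state, train_result = agent.train(
    runner_state,
    buffer_state,
    n_total_timesteps=100_000,
    n_eval_steps=10,
    n_eval_episodes=5,
)
print(train_result.eval_rewards.mean())
```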
 
update(train_state, observations, is_weights, actions, next_observations, rewards, dones)[source]

Update the Q-network.

Parameters:
- train_state (DQNTrainState) – DQN training state.
- observations (jnp.ndarray) – Batch of observations.
- is_weights (jnp.ndarray) – Batch of importance sampling weights.
- actions (jnp.ndarray) – Batch of actions.
- next_observations (jnp.ndarray) – Batch of next observations.
- rewards (jnp.ndarray) – Batch of rewards.
- dones (jnp.ndarray) – Batch of dones.

Returns:
Tuple of (train_state, loss, td_error, grads).

Return type:
tuple[DQNTrainState, jnp.ndarray, jnp.ndarray, jnp.ndarray]
 
 
class arlbench.core.algorithms.dqn.dqn.DQNMetrics(loss: jnp.ndarray, grads: jnp.ndarray | tuple, td_error: jnp.ndarray)[source]

Bases: NamedTuple

DQN metrics returned by train function. Consists of (loss, grads, td_error).

grads: Array | tuple
Alias for field number 1

loss: Array
Alias for field number 0

td_error: Array
Alias for field number 2
 
class arlbench.core.algorithms.dqn.dqn.DQNRunnerState(rng: chex.PRNGKey, train_state: DQNTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: jnp.ndarray, global_step: int)[source]

Bases: NamedTuple

DQN runner state. Consists of (rng, train_state, normalizer_state, env_state, obs, global_step).

env_state: Any
Alias for field number 3

global_step: int
Alias for field number 5

normalizer_state: RunningStatisticsState
Alias for field number 2

obs: Array
Alias for field number 4

rng: Array
Alias for field number 0

train_state: DQNTrainState
Alias for field number 1
 
class arlbench.core.algorithms.dqn.dqn.DQNState(runner_state: DQNRunnerState, buffer_state: PrioritisedTrajectoryBufferState)[source]

Bases: NamedTuple

DQN algorithm state. Consists of (runner_state, buffer_state).

buffer_state: PrioritisedTrajectoryBufferState
Alias for field number 1

runner_state: DQNRunnerState
Alias for field number 0
 
class arlbench.core.algorithms.dqn.dqn.DQNTrainState(step, apply_fn, params, tx, opt_state, target_params=None)[source]

Bases: TrainState

DQN training state.

classmethod create_with_opt_state(*, apply_fn, params, target_params, tx, opt_state, **kwargs)[source]

Creates a DQN training state with the given optimizer state.

replace(**updates)

Returns a new object replacing the specified fields with new values.
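A rough sketch of creating a training state and using `replace()` for a hard target-network sync; the toy parameter pytree and placeholder `apply_fn` are assumptions for illustration:

```python
import jax.numpy as jnp
import optax
from arlbench.core.algorithms.dqn.dqn import DQNTrainState

params = {"dense": {"kernel": jnp.zeros((4, 2))}}  # toy parameter pytree
tx = optax.adam(1e-3)

train_state = DQNTrainState.create_with_opt_state(
    apply_fn=lambda p, x: x,  # placeholder apply_fn for illustration
    params=params,
    target_params=params,
    tx=tx,
    opt_state=tx.init(params),
)

# replace() returns a new state; e.g. copying the online params into the target network:
synced_state = train_state.replace(target_params=train_state.params)
```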
 
class arlbench.core.algorithms.dqn.dqn.DQNTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: DQNMetrics | None)[source]

Bases: NamedTuple

DQN training result. Consists of (eval_rewards, trajectories, metrics).

eval_rewards: Array
Alias for field number 0

metrics: DQNMetrics | None
Alias for field number 2

trajectories: Transition | None
Alias for field number 1
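A short sketch of reading a training result as returned by `train()`; trajectories and metrics are presumably only populated when the corresponding tracking flags were enabled on the DQN instance:

```python
eval_rewards = train_result.eval_rewards  # evaluation returns collected during training

if train_result.metrics is not None:
    loss = train_result.metrics.loss          # see DQNMetrics
    td_error = train_result.metrics.td_error

if train_result.trajectories is not None:
    trajectories = train_result.trajectories  # see Transition
```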
 
class arlbench.core.algorithms.dqn.dqn.Transition(done: jnp.ndarray, action: jnp.ndarray, reward: jnp.ndarray, obs: jnp.ndarray, info: dict)[source]

Bases: NamedTuple

DQN Transition. Consists of (done, action, reward, obs, info).

action: Array
Alias for field number 1

done: Array
Alias for field number 0

info: dict
Alias for field number 4

obs: Array
Alias for field number 3

reward: Array
Alias for field number 2
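A minimal sketch of constructing a single transition by hand:

```python
import jax.numpy as jnp
from arlbench.core.algorithms.dqn.dqn import Transition

t = Transition(
    done=jnp.array(False),
    action=jnp.array(1),
    reward=jnp.array(0.5),
    obs=jnp.zeros(4),
    info={},
)
print(t.action, t.reward)
```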
 