arlbench.core.algorithms.sac package¶
Submodules¶
arlbench.core.algorithms.sac.models module¶
SAC models for the actor and critic networks.
- class arlbench.core.algorithms.sac.models.AlphaCoef(alpha_init=1.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
- Bases: Module

  Alpha coefficient for SAC.

 - alpha_init: float = 1.0¶
 - name: Optional[str] = None¶
 - parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None¶
 - scope: Optional[Scope] = None¶

- class arlbench.core.algorithms.sac.models.SACCNNActor(action_dim, activation, hidden_size=64, log_std_min=-20, log_std_max=2, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
- Bases: Module

  A CNN-based actor network for SAC. Based on NatureCNN: https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/torch_layers.py#L48

 - action_dim: int¶
 - activation: int¶
 - log_std_max: float = 2¶
 - log_std_min: float = -20¶
 - name: Optional[str] = None¶
 - parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None¶
 - scope: Optional[Scope] = None¶

- class arlbench.core.algorithms.sac.models.SACCNNCritic(action_dim, activation, hidden_size=512, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
- Bases: Module

  A CNN-based critic network for SAC. Based on NatureCNN: https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/torch_layers.py#L48

 - action_dim: int¶
 - activation: int¶
 - name: Optional[str] = None¶
 - parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None¶
 - scope: Optional[Scope] = None¶

- class arlbench.core.algorithms.sac.models.SACMLPActor(action_dim, activation, hidden_size=64, log_std_min=-20, log_std_max=2, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
- Bases: Module

  An MLP-based actor network for SAC.

 - action_dim: int¶
 - activation: int¶
 - log_std_max: float = 2¶
 - log_std_min: float = -20¶
 - name: Optional[str] = None¶
 - parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None¶
 - scope: Optional[Scope] = None¶

- class arlbench.core.algorithms.sac.models.SACMLPCritic(action_dim, activation, hidden_size=64, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
- Bases: Module

  An MLP-based critic network for SAC.

 - action_dim: int¶
 - activation: int¶
 - name: Optional[str] = None¶
 - parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None¶
 - scope: Optional[Scope] = None¶

- class arlbench.core.algorithms.sac.models.SACVectorCritic(critic, action_dim, activation, hidden_size=64, n_critics=2, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
- Bases: Module

  A vectorized critic network for SAC.

 - action_dim: int¶
 - activation: int¶
 - critic: type[SACMLPCritic] | type[SACCNNCritic]¶
 - n_critics: int = 2¶
 - name: Optional[str] = None¶
 - parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None¶
 - scope: Optional[Scope] = None¶
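
Example usage of these modules follows the standard Flax init/apply workflow. This is a minimal sketch: the observation/action shapes, the value passed for activation, and the critic call signature (obs, action) are assumptions for illustration only.

    import jax
    import jax.numpy as jnp

    from arlbench.core.algorithms.sac.models import SACMLPActor, SACMLPCritic, SACVectorCritic

    rng = jax.random.PRNGKey(0)
    dummy_obs = jnp.zeros((1, 8))      # assumed observation batch
    dummy_action = jnp.zeros((1, 2))   # assumed action batch

    # NOTE: the concrete value expected for `activation` is an assumption here;
    # adjust it to whatever your configuration uses.
    actor = SACMLPActor(action_dim=2, activation="tanh", hidden_size=64)
    actor_params = actor.init(rng, dummy_obs)
    actor_out = actor.apply(actor_params, dummy_obs)   # e.g. parameters of the action distribution

    # The vectorized critic stacks n_critics copies of the given critic class.
    critic = SACVectorCritic(critic=SACMLPCritic, action_dim=2, activation="tanh", n_critics=2)
    critic_params = critic.init(rng, dummy_obs, dummy_action)
    q_values = critic.apply(critic_params, dummy_obs, dummy_action)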
 
arlbench.core.algorithms.sac.sac module¶
SAC algorithm.
- class arlbench.core.algorithms.sac.sac.SAC(hpo_config, env, eval_env=None, deterministic_eval=True, cnn_policy=False, nas_config=None, track_metrics=False, track_trajectories=False)[source]¶
- Bases: Algorithm

  JAX-based implementation of Soft Actor-Critic (SAC).

 - static get_checkpoint_factory(runner_state, train_result)[source]¶
- Creates a factory dictionary of all possible checkpointing options for SAC. 
- Parameters:
- runner_state (SACRunnerState) – Algorithm runner state. 
- train_result (SACTrainingResult | None) – Training result. 
 
- Returns:
- Dictionary of factory functions containing:
- actor_opt_state 
- critic_opt_state 
- alpha_opt_state 
- actor_network_params 
- critic_network_params 
- critic_target_params 
- alpha_network_params 
- actor_loss 
- critic_loss 
- alpha_loss 
- trajectories 
 
 
- Return type:
- dict[str, Callable] 
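
A hedged sketch of how the returned factory dictionary might be consumed; the keys are those listed above, while calling each factory without arguments is an assumption:

    from arlbench.core.algorithms.sac import SAC

    # runner_state and train_result would come from a previous training run.
    factory = SAC.get_checkpoint_factory(runner_state, train_result)

    # Each entry is a callable producing the corresponding object to checkpoint.
    checkpoint = {
        "actor_network_params": factory["actor_network_params"](),
        "critic_network_params": factory["critic_network_params"](),
        "alpha_network_params": factory["alpha_network_params"](),
    }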
 
 - static get_default_hpo_config()[source]¶
- Returns the default HPO configuration for SAC. 
- Return type:
- Configuration
 
 - static get_default_nas_config()[source]¶
- Returns the default NAS configuration for SAC. 
- Return type:
- Configuration
 
 - static get_hpo_config_space(seed=None)[source]¶
- Returns the hyperparameter configuration space for SAC. 
- Return type:
- ConfigurationSpace
 
 - static get_nas_config_space(seed=None)[source]¶
- Returns the neural architecture search (NAS) configuration space for SAC. 
- Return type:
- ConfigurationSpace
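
For example, the default configuration and the searchable spaces can be inspected as follows; sampling a random candidate uses the standard ConfigSpace API (a sketch, not arlbench-specific):

    from arlbench.core.algorithms.sac import SAC

    default_hpo = SAC.get_default_hpo_config()       # Configuration with default hyperparameters
    hpo_space = SAC.get_hpo_config_space(seed=42)    # ConfigurationSpace of tunable hyperparameters

    # Sample a random candidate configuration from the space.
    candidate = hpo_space.sample_configuration()

    # The NAS space works the same way.
    nas_space = SAC.get_nas_config_space(seed=42)
    print(default_hpo, candidate, nas_space, sep="\n")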
 
 - init(rng, buffer_state=None, actor_network_params=None, critic_network_params=None, critic_target_params=None, alpha_network_params=None, actor_opt_state=None, critic_opt_state=None, alpha_opt_state=None)[source]¶
- Initializes the SAC state. Parameters that are passed are not re-initialized but are included as-is in the returned state. 
- Parameters:
- rng (chex.PRNGKey) – Random number generator key. 
- buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None. 
- actor_network_params (FrozenDict | dict | None, optional) – Actor network parameters. Defaults to None. 
- critic_network_params (FrozenDict | dict | None, optional) – Critic network parameters. Defaults to None. 
- critic_target_params (FrozenDict | dict | None, optional) – Critic target network parameters. Defaults to None. 
- alpha_network_params (FrozenDict | dict | None, optional) – Alpha network parameters. Defaults to None. 
- actor_opt_state (optax.OptState | None, optional) – Actor optimizer state. Defaults to None. 
- critic_opt_state (optax.OptState | None, optional) – Critic optimizer state. Defaults to None. 
- alpha_opt_state (optax.OptState | None, optional) – Alpha optimizer state. Defaults to None. 
 
- Returns:
- SAC state. 
- Return type:
- SACState
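
A minimal sketch of constructing the algorithm and initializing its state. The environment creation helper and its arguments are assumptions and need to be adapted to your setup:

    import jax

    from arlbench.core.algorithms.sac import SAC
    from arlbench.core.environments import make_env   # assumed helper; adapt to your setup

    env = make_env("gymnax", "Pendulum-v1", n_envs=1, seed=0)   # hypothetical arguments
    hpo_config = SAC.get_default_hpo_config()

    agent = SAC(hpo_config, env)

    # With no parameters passed, networks and optimizer states are freshly initialized.
    # The returned SAC state unpacks into (runner_state, buffer_state).
    runner_state, buffer_state = agent.init(jax.random.PRNGKey(0))
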
 - name: str = 'sac'¶
 - predict(runner_state, obs, rng, deterministic=True)[source]¶
- Predicts action(s) based on the current observation(s). 
- Parameters:
- runner_state (SACRunnerState) – Algorithm runner state. 
- obs (jnp.ndarray) – Observation(s). 
- rng (chex.PRNGKey | None, optional) – Random number generator key. Defaults to None. 
- deterministic (bool) – Whether to return the deterministic action. Defaults to True. 
 
- Returns:
- Action(s). 
- Return type:
- jnp.ndarray 
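
Continuing the init() sketch above, predicting an action for a batch of observations might look like this (the observation shape is environment-specific and assumed here):

    import jax
    import jax.numpy as jnp

    obs = jnp.zeros((1, 3))   # dummy observation batch
    action = agent.predict(runner_state, obs, rng=jax.random.PRNGKey(1), deterministic=True)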
 
 - train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]¶
- Performs one full training run. 
- Parameters:
- runner_state (SACRunnerState) – SAC runner state. 
- buffer_state (PrioritisedTrajectoryBufferState) – Replay buffer state. 
- n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000. 
- n_eval_steps (int, optional) – Number of evaluation steps during training. 
- n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. 
 
- Returns:
- Tuple of SAC algorithm state and training result. 
- Return type:
- SACTrainReturnT 
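
Putting the pieces together, a full training call following the signature above could look like the following sketch (building on the init() example; the timestep counts are arbitrary):

    algorithm_state, training_result = agent.train(
        runner_state,
        buffer_state,
        n_total_timesteps=100_000,
        n_eval_steps=10,
        n_eval_episodes=5,
    )

    # train() returns SACTrainReturnT = tuple[SACState, SACTrainingResult].
    print(training_result.eval_rewards.shape)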
 
 - update_actor(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]¶
- Updates the actor network parameters. 
- Parameters:
- actor_train_state (SACTrainState) – Actor train state. 
- critic_train_state (SACTrainState) – Critic train state. 
- alpha_train_state (SACTrainState) – Alpha train state. 
- experience (TimeStep) – Experience (batch of TimeSteps). 
- is_weights (jnp.ndarray) – Importance sampling weights for prioritized experience replay (PER). 
- rng (chex.PRNGKey) – Random number generator key. 
 
- Returns:
- Updated training state and metrics. 
- Return type:
- tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey] 
 
 - update_alpha(alpha_train_state, entropy)[source]¶
- Updates the alpha network parameters. 
- Parameters:
- alpha_train_state (SACTrainState) – Alpha training state. 
- entropy (jnp.ndarray) – Entropy values. 
 
- Returns:
- Updated training state and metrics. 
- Return type:
- tuple[SACTrainState, jnp.ndarray] 
 
 - update_critic(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]¶
- Updates the critic network parameters. 
- Parameters:
- actor_train_state (SACTrainState) – Actor train state. 
- critic_train_state (SACTrainState) – Critic train state. 
- alpha_train_state (SACTrainState) – Alpha train state. 
- experience (Transition) – Experience (batch of transitions). 
- is_weights (jnp.ndarray) – Importance sampling weights for prioritized experience replay (PER). 
- rng (chex.PRNGKey) – Random number generator key. 
 
- Returns:
- Updated training state and metrics. 
- Return type:
- tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey] 
 
 
- class arlbench.core.algorithms.sac.sac.SACMetrics(actor_loss: jnp.ndarray, critic_loss: jnp.ndarray, alpha_loss: jnp.ndarray, td_error: jnp.ndarray, actor_grads: FrozenDict, critic_grads: FrozenDict)[source]¶
- Bases: NamedTuple

  SAC metrics returned by the train function. Consists of (actor_loss, critic_loss, alpha_loss, td_error, actor_grads, critic_grads).

 - actor_grads: FrozenDict¶
   Alias for field number 4
 - actor_loss: Array¶
   Alias for field number 0
 - alpha_loss: Array¶
   Alias for field number 2
 - critic_grads: FrozenDict¶
   Alias for field number 5
 - critic_loss: Array¶
   Alias for field number 1
 - td_error: Array¶
   Alias for field number 3
 
- class arlbench.core.algorithms.sac.sac.SACRunnerState(rng: chex.PRNGKey, actor_train_state: SACTrainState, critic_train_state: SACTrainState, alpha_train_state: SACTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: chex.Array, global_step: int)[source]¶
- Bases: NamedTuple

  SAC runner state. Consists of (rng, actor_train_state, critic_train_state, alpha_train_state, normalizer_state, env_state, obs, global_step).

 - actor_train_state: SACTrainState¶
   Alias for field number 1
 - alpha_train_state: SACTrainState¶
   Alias for field number 3
 - critic_train_state: SACTrainState¶
   Alias for field number 2
 - env_state: Any¶
   Alias for field number 5
 - global_step: int¶
   Alias for field number 7
 - normalizer_state: RunningStatisticsState¶
   Alias for field number 4
 - obs: Union[Array, ndarray, bool_, number]¶
   Alias for field number 6
 - rng: Array¶
   Alias for field number 0
 
- class arlbench.core.algorithms.sac.sac.SACState(runner_state: SACRunnerState, buffer_state: PrioritisedTrajectoryBufferState)[source]¶
- Bases: NamedTuple

  SAC algorithm state. Consists of (runner_state, buffer_state).

 - buffer_state: PrioritisedTrajectoryBufferState¶
   Alias for field number 1
 - runner_state: SACRunnerState¶
   Alias for field number 0
 
- class arlbench.core.algorithms.sac.sac.SACTrainState(step, apply_fn, params, tx, opt_state, target_params=None)[source]¶
- Bases: TrainState

  SAC training state.

 - classmethod create_with_opt_state(*, apply_fn, params, target_params, tx, opt_state, **kwargs)[source]¶
   Instantiates with optimizer state.
 - network_state = None¶
 - replace(**updates)¶
   Returns a new object replacing the specified fields with new values.
 - target_params: None | chex.Array | dict = None¶
 
- class arlbench.core.algorithms.sac.sac.SACTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: SACMetrics | None)[source]¶
- Bases: NamedTuple

  SAC training result. Consists of (eval_rewards, trajectories, metrics).

 - eval_rewards: Array¶
   Alias for field number 0
 - metrics: SACMetrics | None¶
   Alias for field number 2
 - trajectories: Transition | None¶
   Alias for field number 1
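
For reference, the fields of a training result can be read directly from the named tuple; metrics and trajectories are only populated when the corresponding tracking flags were enabled at construction time (track_metrics, track_trajectories):

    # training_result as returned by SAC.train(...)
    eval_rewards = training_result.eval_rewards          # evaluation returns per evaluation step

    if training_result.metrics is not None:              # populated when track_metrics=True
        actor_loss = training_result.metrics.actor_loss
        critic_loss = training_result.metrics.critic_loss

    if training_result.trajectories is not None:         # populated when track_trajectories=True
        observations = training_result.trajectories.obs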
 
- class arlbench.core.algorithms.sac.sac.Transition(done: jnp.ndarray, action: jnp.ndarray, value: jnp.ndarray, reward: jnp.ndarray, obs: jnp.ndarray, info: jnp.ndarray)[source]¶
- Bases: NamedTuple

  SAC transition. Consists of (done, action, value, reward, obs, info).

 - action: Array¶
   Alias for field number 1
 - done: Array¶
   Alias for field number 0
 - info: Array¶
   Alias for field number 5
 - obs: Array¶
   Alias for field number 4
 - reward: Array¶
   Alias for field number 3
 - value: Array¶
   Alias for field number 2
 
Module contents¶
- class arlbench.core.algorithms.sac.SAC(hpo_config, env, eval_env=None, deterministic_eval=True, cnn_policy=False, nas_config=None, track_metrics=False, track_trajectories=False)[source]¶
- Bases: Algorithm

  JAX-based implementation of Soft Actor-Critic (SAC).

 - static get_checkpoint_factory(runner_state, train_result)[source]¶
- Creates a factory dictionary of all possible checkpointing options for SAC. 
- Parameters:
- runner_state (SACRunnerState) – Algorithm runner state. 
- train_result (SACTrainingResult | None) – Training result. 
 
- Returns:
- Dictionary of factory functions containing:
- actor_opt_state 
- critic_opt_state 
- alpha_opt_state 
- actor_network_params 
- critic_network_params 
- critic_target_params 
- alpha_network_params 
- actor_loss 
- critic_loss 
- alpha_loss 
- trajectories 
 
 
- Return type:
- dict[str, Callable] 
 
 - static get_default_hpo_config()[source]¶
- Returns the default HPO configuration for SAC. 
- Return type:
- Configuration
 
 - static get_default_nas_config()[source]¶
- Returns the default NAS configuration for SAC. 
- Return type:
- Configuration
 
 - static get_hpo_config_space(seed=None)[source]¶
- Returns the hyperparameter configuration space for SAC. 
- Return type:
- ConfigurationSpace
 
 - static get_nas_config_space(seed=None)[source]¶
- Returns the neural architecture search (NAS) configuration space for SAC. 
- Return type:
- ConfigurationSpace
 
 - init(rng, buffer_state=None, actor_network_params=None, critic_network_params=None, critic_target_params=None, alpha_network_params=None, actor_opt_state=None, critic_opt_state=None, alpha_opt_state=None)[source]¶
- Initializes the SAC state. Parameters that are passed are not re-initialized but are included as-is in the returned state. 
- Parameters:
- rng (chex.PRNGKey) – Random number generator key. 
- buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None. 
- actor_network_params (FrozenDict | dict | None, optional) – Actor network parameters. Defaults to None. 
- critic_network_params (FrozenDict | dict | None, optional) – Critic network parameters. Defaults to None. 
- critic_target_params (FrozenDict | dict | None, optional) – Critic target network parameters. Defaults to None. 
- alpha_network_params (FrozenDict | dict | None, optional) – Alpha network parameters. Defaults to None. 
- actor_opt_state (optax.OptState | None, optional) – Actor optimizer state. Defaults to None. 
- critic_opt_state (optax.OptState | None, optional) – Critic optimizer state. Defaults to None. 
- alpha_opt_state (optax.OptState | None, optional) – Alpha optimizer state. Defaults to None. 
 
- Returns:
- SAC state. 
- Return type:
- SACState

 - name: str = 'sac'¶
 - predict(runner_state, obs, rng, deterministic=True)[source]¶
- Predicts action(s) based on the current observation(s). 
- Parameters:
- runner_state (SACRunnerState) – Algorithm runner state. 
- obs (jnp.ndarray) – Observation(s). 
- rng (chex.PRNGKey | None, optional) – Random number generator key. Defaults to None. 
- deterministic (bool) – Whether to return the deterministic action. Defaults to True. 
 
- Returns:
- Action(s). 
- Return type:
- jnp.ndarray 
 
 - train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]¶
- Performs one full training run. 
- Parameters:
- runner_state (SACRunnerState) – SAC runner state. 
- buffer_state (PrioritisedTrajectoryBufferState) – Replay buffer state. 
- n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000. 
- n_eval_steps (int, optional) – Number of evaluation steps during training. 
- n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. 
 
- Returns:
- Tuple of SAC algorithm state and training result. 
- Return type:
- SACTrainReturnT 
 
 - update_actor(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]¶
- Updates the actor network parameters. 
- Parameters:
- actor_train_state (SACTrainState) – Actor train state. 
- critic_train_state (SACTrainState) – Critic train state. 
- alpha_train_state (SACTrainState) – Alpha train state. 
- experience (TimeStep) – Experience (batch of TimeSteps). 
- is_weights (jnp.ndarray) – Importance sampling weights for prioritized experience replay (PER). 
- rng (chex.PRNGKey) – Random number generator key. 
 
- Returns:
- Updated training state and metrics. 
- Return type:
- tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey] 
 
 - update_alpha(alpha_train_state, entropy)[source]¶
- Updates the alpha network parameters. 
- Parameters:
- alpha_train_state (SACTrainState) – Alpha training state. 
- entropy (jnp.ndarray) – Entropy values. 
 
- Returns:
- Updated training state and metrics. 
- Return type:
- tuple[SACTrainState, jnp.ndarray] 
 
 - update_critic(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]¶
- Updates the critic network parameters. 
- Parameters:
- actor_train_state (SACTrainState) – Actor train state. 
- critic_train_state (SACTrainState) – Critic train state. 
- alpha_train_state (SACTrainState) – Alpha train state. 
- experience (Transition) – Experience (batch of transitions). 
- is_weights (jnp.ndarray) – Importance sampling weights for prioritized experience replay (PER). 
- rng (chex.PRNGKey) – Random number generator key. 
 
- Returns:
- Updated training state and metrics. 
- Return type:
- tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey] 
 
 
- class arlbench.core.algorithms.sac.SACMetrics(actor_loss: jnp.ndarray, critic_loss: jnp.ndarray, alpha_loss: jnp.ndarray, td_error: jnp.ndarray, actor_grads: FrozenDict, critic_grads: FrozenDict)[source]¶
- Bases: NamedTuple

  SAC metrics returned by the train function. Consists of (actor_loss, critic_loss, alpha_loss, td_error, actor_grads, critic_grads).

 - actor_grads: FrozenDict¶
   Alias for field number 4
 - actor_loss: Array¶
   Alias for field number 0
 - alpha_loss: Array¶
   Alias for field number 2
 - critic_grads: FrozenDict¶
   Alias for field number 5
 - critic_loss: Array¶
   Alias for field number 1
 - td_error: Array¶
   Alias for field number 3
 
- class arlbench.core.algorithms.sac.SACRunnerState(rng: chex.PRNGKey, actor_train_state: SACTrainState, critic_train_state: SACTrainState, alpha_train_state: SACTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: chex.Array, global_step: int)[source]¶
- Bases: NamedTuple

  SAC runner state. Consists of (rng, actor_train_state, critic_train_state, alpha_train_state, normalizer_state, env_state, obs, global_step).

 - actor_train_state: SACTrainState¶
   Alias for field number 1
 - alpha_train_state: SACTrainState¶
   Alias for field number 3
 - critic_train_state: SACTrainState¶
   Alias for field number 2
 - env_state: Any¶
   Alias for field number 5
 - global_step: int¶
   Alias for field number 7
 - normalizer_state: RunningStatisticsState¶
   Alias for field number 4
 - obs: Union[Array, ndarray, bool_, number]¶
   Alias for field number 6
 - rng: Array¶
   Alias for field number 0
 
- class arlbench.core.algorithms.sac.SACState(runner_state: SACRunnerState, buffer_state: PrioritisedTrajectoryBufferState)[source]¶
- Bases: NamedTuple

  SAC algorithm state. Consists of (runner_state, buffer_state).

 - buffer_state: PrioritisedTrajectoryBufferState¶
   Alias for field number 1
 - runner_state: SACRunnerState¶
   Alias for field number 0
 
- arlbench.core.algorithms.sac.SACTrainReturnT¶
- alias of tuple[SACState, SACTrainingResult]
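
Since the alias is an ordinary tuple type, it can be used directly in annotations, for example in a thin wrapper around train() (a sketch):

    from arlbench.core.algorithms.sac import SAC, SACTrainReturnT

    def run_sac(agent: SAC, runner_state, buffer_state) -> SACTrainReturnT:
        # Returns (SACState, SACTrainingResult), matching the alias above.
        return agent.train(runner_state, buffer_state)
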
- class arlbench.core.algorithms.sac.SACTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: SACMetrics | None)[source]¶
- Bases: NamedTuple

  SAC training result. Consists of (eval_rewards, trajectories, metrics).

 - eval_rewards: Array¶
   Alias for field number 0
 - metrics: SACMetrics | None¶
   Alias for field number 2
 - trajectories: Transition | None¶
   Alias for field number 1
 