arlbench.core.algorithms.sac.sac¶
SAC algorithm.
Classes

SAC – JAX-based implementation of Soft-Actor-Critic (SAC).
SACMetrics – SAC metrics returned by the train function.
SACRunnerState – SAC runner state.
SACState – SAC algorithm state.
SACTrainState – SAC training state.
SACTrainingResult – SAC training result.
Transition – SAC Transition.
- class arlbench.core.algorithms.sac.sac.SAC(hpo_config, env, eval_env=None, deterministic_eval=True, cnn_policy=False, nas_config=None, track_metrics=False, track_trajectories=False)[source]¶
Bases:
Algorithm
JAX-based implementation of Soft-Actor-Critic (SAC).
- static get_checkpoint_factory(runner_state, train_result)[source]¶
Creates a factory dictionary of all possible checkpointing options for SAC.
- Parameters:
runner_state (SACRunnerState) – Algorithm runner state.
train_result (SACTrainingResult | None) – Training result.
- Returns:
- Dictionary of factory functions containing:
actor_opt_state
critic_opt_state
alpha_opt_state
actor_network_params
critic_network_params
critic_target_params
alpha_network_params
actor_loss
critic_loss
alpha_loss
trajectories
- Return type:
dict[str, Callable]
- static get_default_hpo_config()[source]¶
Returns the default HPO configuration for SAC.
- Return type:
Configuration
- static get_default_nas_config()[source]¶
Returns the default NAS configuration for SAC.
- Return type:
Configuration
- static get_hpo_config_space(seed=None)[source]¶
Returns the hyperparameter configuration space for SAC.
- Return type:
ConfigurationSpace
- static get_nas_config_space(seed=None)[source]¶
Returns the neural architecture search (NAS) configuration space for SAC.
- Return type:
ConfigurationSpace
- init(rng, buffer_state=None, actor_network_params=None, critic_network_params=None, critic_target_params=None, alpha_network_params=None, actor_opt_state=None, critic_opt_state=None, alpha_opt_state=None)[source]¶
Initializes the SAC state. Parameters that are passed in are not initialized but are included in the final state as provided.
- Parameters:
rng (chex.PRNGKey) – Random number generator key.
buffer_state (PrioritisedTrajectoryBufferState | None, optional) – Buffer state. Defaults to None.
actor_network_params (FrozenDict | dict | None, optional) – Actor network parameters. Defaults to None.
critic_network_params (FrozenDict | dict | None, optional) – Critic network parameters. Defaults to None.
critic_target_params (FrozenDict | dict | None, optional) – Critic target network parameters. Defaults to None.
alpha_network_params (FrozenDict | dict | None, optional) – Alpha network parameters. Defaults to None.
actor_opt_state (optax.OptState | None, optional) – Actor optimizer state. Defaults to None.
critic_opt_state (optax.OptState | None, optional) – Critic optimizer state. Defaults to None.
alpha_opt_state (optax.OptState | None, optional) – Alpha optimizer state. Defaults to None.
- Returns:
SAC state.
- Return type:
SACState
- predict(runner_state, obs, rng, deterministic=True)[source]¶
Predict action(s) based on the current observation(s).
- Parameters:
runner_state (SACRunnerState) – Algorithm runner state.
obs (jnp.ndarray) – Observation(s).
rng (chex.PRNGKey | None, optional) – Random number generator key. Defaults to None.
deterministic (bool) – Whether to return a deterministic action. Defaults to True.
- Returns:
Action(s).
- Return type:
jnp.ndarray
- train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]¶
Performs one full training run.
- Parameters:
runner_state (SACRunnerState) – SAC runner state.
buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.
n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.
n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
- Returns:
Tuple of SAC algorithm state and training result.
- Return type:
SACTrainReturnT
- update_actor(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]¶
Updates the actor network parameters.
- Parameters:
actor_train_state (SACTrainState) – Actor train state.
critic_train_state (SACTrainState) – Critic train state.
alpha_train_state (SACTrainState) – Alpha train state.
experience (TimeStep) – Experience (batch of TimeSteps).
is_weights (jnp.ndarray) – Importance sampling weights for prioritized experience replay (PER).
rng (chex.PRNGKey) – Random number generator key.
- Returns:
Updated actor training state and metrics.
- Return type:
tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]
- update_alpha(alpha_train_state, entropy)[source]¶
Updates the alpha network parameters.
- Parameters:
alpha_train_state (SACTrainState) – Alpha training state.
entropy (jnp.ndarray) – Entropy values.
- Returns:
Updated training state and metrics.
- Return type:
tuple[SACTrainState, jnp.ndarray]
- update_critic(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]¶
Updates the critic network parameters.
- Parameters:
actor_train_state (SACTrainState) – Actor train state.
critic_train_state (SACTrainState) – Critic train state.
alpha_train_state (SACTrainState) – Alpha train state.
experience (Transition) – Experience (batch of transitions).
is_weights (jnp.ndarray) – Importance sampling weights for prioritized experience replay (PER).
rng (chex.PRNGKey) – Random number generator key.
- Returns:
Updated training state and metrics.
- Return type:
tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]
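The listing below is a minimal usage sketch tying the methods of the SAC class together. The environment constructor make_env is a placeholder assumption for however an ARLBench-compatible (JAX-based, vectorized) environment is built in your setup; everything else follows the signatures documented above.

```python
import jax

from arlbench.core.algorithms.sac.sac import SAC

# Assumption: `make_env` stands in for your ARLBench-compatible environment factory.
env = make_env("Pendulum-v1")

# Construct the agent with its default hyperparameter configuration.
hpo_config = SAC.get_default_hpo_config()
agent = SAC(hpo_config, env)

# Initialize the algorithm state (runner state + buffer state) from a PRNG key.
rng = jax.random.PRNGKey(0)
sac_state = agent.init(rng)

# Run one full training loop; returns the updated algorithm state and the training result.
sac_state, result = agent.train(
    sac_state.runner_state,
    sac_state.buffer_state,
    n_total_timesteps=100_000,
    n_eval_steps=10,
    n_eval_episodes=5,
)

# Greedy action selection with the trained actor.
action = agent.predict(
    sac_state.runner_state, sac_state.runner_state.obs, rng, deterministic=True
)

# Factory of checkpointable quantities (actor/critic/alpha parameters, optimizer states, ...).
checkpoint_factory = SAC.get_checkpoint_factory(sac_state.runner_state, result)
```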
- class arlbench.core.algorithms.sac.sac.SACMetrics(actor_loss: jnp.ndarray, critic_loss: jnp.ndarray, alpha_loss: jnp.ndarray, td_error: jnp.ndarray, actor_grads: FrozenDict, critic_grads: FrozenDict)[source]¶
Bases:
NamedTuple
SAC metrics returned by the train function. Consists of (actor_loss, critic_loss, alpha_loss, td_error, actor_grads, critic_grads).
-
actor_grads:
FrozenDict
¶ Alias for field number 4
-
actor_loss:
Array
¶ Alias for field number 0
-
alpha_loss:
Array
¶ Alias for field number 2
-
critic_grads:
FrozenDict
¶ Alias for field number 5
-
critic_loss:
Array
¶ Alias for field number 1
-
td_error:
Array
¶ Alias for field number 3
- class arlbench.core.algorithms.sac.sac.SACRunnerState(rng: chex.PRNGKey, actor_train_state: SACTrainState, critic_train_state: SACTrainState, alpha_train_state: SACTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: chex.Array, global_step: int)[source]¶
Bases:
NamedTuple
SAC runner state. Consists of (rng, actor_train_state, critic_train_state, alpha_train_state, normalizer_state, env_state, obs, global_step).
-
actor_train_state:
SACTrainState
¶ Alias for field number 1
-
alpha_train_state:
SACTrainState
¶ Alias for field number 3
-
critic_train_state:
SACTrainState
¶ Alias for field number 2
-
env_state:
Any
¶ Alias for field number 5
-
global_step:
int
¶ Alias for field number 7
-
normalizer_state:
RunningStatisticsState
¶ Alias for field number 4
-
obs:
Union[Array, ndarray, bool_, number]
¶ Alias for field number 6
-
rng:
Array
¶ Alias for field number 0
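Because SACRunnerState is a NamedTuple, its fields can be read directly, for example to log progress or to drive predict. A short sketch, assuming agent and sac_state come from the SAC usage example above:

```python
runner_state = sac_state.runner_state

step = runner_state.global_step  # environment steps taken so far
current_obs = runner_state.obs   # last observation(s) from the vectorized environment

# Reuse the stored observation and RNG key for greedy prediction.
action = agent.predict(runner_state, current_obs, runner_state.rng, deterministic=True)
```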
- class arlbench.core.algorithms.sac.sac.SACState(runner_state: SACRunnerState, buffer_state: PrioritisedTrajectoryBufferState)[source]¶
Bases:
NamedTuple
SAC algorithm state. Consists of (runner_state, buffer_state).
-
buffer_state:
PrioritisedTrajectoryBufferState
¶ Alias for field number 1
-
runner_state:
SACRunnerState
¶ Alias for field number 0
- class arlbench.core.algorithms.sac.sac.SACTrainState(step, apply_fn, params, tx, opt_state, target_params=None)[source]¶
Bases:
TrainState
SAC training state.
- classmethod create_with_opt_state(*, apply_fn, params, target_params, tx, opt_state, **kwargs)[source]¶
Instantiates with optimizer state.
- replace(**updates)¶
Returns a new object replacing the specified fields with new values.
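A sketch of constructing a SACTrainState directly via create_with_opt_state, assuming a Flax module and an optax optimizer. The QNetwork below is purely illustrative and not the network architecture used by ARLBench.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

from arlbench.core.algorithms.sac.sac import SACTrainState


class QNetwork(nn.Module):
    """Illustrative Q-network: maps (obs, action) to a scalar value."""

    @nn.compact
    def __call__(self, obs, action):
        x = jnp.concatenate([obs, action], axis=-1)
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(1)(x)


rng = jax.random.PRNGKey(0)
network = QNetwork()
dummy_obs = jnp.zeros((1, 3))
dummy_action = jnp.zeros((1, 1))

params = network.init(rng, dummy_obs, dummy_action)
target_params = params  # the target network starts as a copy of the online network
tx = optax.adam(3e-4)

train_state = SACTrainState.create_with_opt_state(
    apply_fn=network.apply,
    params=params,
    target_params=target_params,
    tx=tx,
    opt_state=tx.init(params),
)
```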
- class arlbench.core.algorithms.sac.sac.SACTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: SACMetrics | None)[source]¶
Bases:
NamedTuple
SAC training result. Consists of (eval_rewards, trajectories, metrics).
-
eval_rewards:
Array
¶ Alias for field number 0
-
metrics:
SACMetrics
|None
¶ Alias for field number 2
-
trajectories:
Transition
|None
¶ Alias for field number 1
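Since SACTrainingResult and SACMetrics are NamedTuples, the outputs of train can be inspected by field name. A small sketch, assuming result is the SACTrainingResult returned in the SAC usage example above; metrics and trajectories are only populated when the corresponding tracking flags are enabled on the agent.

```python
# Evaluation returns collected at the evaluation points during training.
eval_rewards = result.eval_rewards

if result.metrics is not None:
    # Per-update losses and gradients recorded when metric tracking is enabled.
    actor_loss = result.metrics.actor_loss
    critic_loss = result.metrics.critic_loss
    alpha_loss = result.metrics.alpha_loss

if result.trajectories is not None:
    # Recorded transitions when trajectory tracking is enabled.
    observations = result.trajectories.obs
```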
- class arlbench.core.algorithms.sac.sac.Transition(done: jnp.ndarray, action: jnp.ndarray, value: jnp.ndarray, reward: jnp.ndarray, obs: jnp.ndarray, info: jnp.ndarray)[source]¶
Bases:
NamedTuple
SAC Transition. Consists of (done, action, value, reward, obs, info).
-
action:
Array
¶ Alias for field number 1
-
done:
Array
¶ Alias for field number 0
-
info:
Array
¶ Alias for field number 5
-
obs:
Array
¶ Alias for field number 4
-
reward:
Array
¶ Alias for field number 3
-
value:
Array
¶ Alias for field number 2
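Transition is likewise a plain NamedTuple, so a (batched) transition can be assembled from JAX arrays directly. The shapes below are illustrative only; they follow the field order documented above.

```python
import jax.numpy as jnp

from arlbench.core.algorithms.sac.sac import Transition

transition = Transition(
    done=jnp.zeros((8,), dtype=jnp.bool_),  # episode-termination flags
    action=jnp.zeros((8, 1)),               # actions taken
    value=jnp.zeros((8,)),                  # value estimates
    reward=jnp.zeros((8,)),                 # rewards received
    obs=jnp.zeros((8, 3)),                  # observations
    info=jnp.zeros((8,)),                   # auxiliary step information
)
```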