arlbench.core.algorithms.sac package

Submodules

arlbench.core.algorithms.sac.models module

SAC models for the actor and critic networks.

class arlbench.core.algorithms.sac.models.AlphaCoef(alpha_init=1.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: Module

Alpha coefficient for SAC.

__call__()[source]

Returns the alpha coefficient.

Return type:

Array

alpha_init: float = 1.0
name: Optional[str] = None
parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
scope: Optional[Scope] = None
setup()[source]

Initializes the alpha coefficient.
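The class holds a single temperature value initialised from alpha_init. This page does not say how the value is stored. A minimal sketch, assuming a log-space parameterisation (a common choice for SAC temperatures, not confirmed here), shows why such a coefficient stays positive under unconstrained updates. LogAlpha is a hypothetical plain-Python stand-in, not the Flax module.

```python
import math


class LogAlpha:
    """Hypothetical stand-in for a learnable temperature (not the Flax module)."""

    def __init__(self, alpha_init: float = 1.0) -> None:
        # Assumption: store log(alpha) so any real-valued update keeps alpha > 0.
        self.log_alpha = math.log(alpha_init)

    def __call__(self) -> float:
        # Map the unconstrained parameter back to the positive coefficient.
        return math.exp(self.log_alpha)


alpha = LogAlpha(alpha_init=1.0)
alpha.log_alpha -= 0.5  # e.g. one gradient step on the unconstrained parameter
```

After the step, alpha() is still strictly positive, which would not be guaranteed if the coefficient itself were updated directly.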

class arlbench.core.algorithms.sac.models.SACCNNActor(action_dim, activation, hidden_size=64, log_std_min=-20, log_std_max=2, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: Module

A CNN-based actor network for SAC. Based on NatureCNN https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/torch_layers.py#L48.

__call__(x)[source]

Applies the actor to the input.

action_dim: int
activation: int
hidden_size: int = 64
log_std_max: float = 2
log_std_min: float = -20
name: Optional[str] = None
parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
scope: Optional[Scope] = None
setup()[source]

Initializes the actor network.

class arlbench.core.algorithms.sac.models.SACCNNCritic(action_dim, activation, hidden_size=512, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: Module

A CNN-based critic network for SAC. Based on NatureCNN https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/torch_layers.py#L48.

__call__(x, action)[source]

Applies the critic to the input.

action_dim: int
activation: int
hidden_size: int = 512
name: Optional[str] = None
parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
scope: Optional[Scope] = None
setup()[source]

Initializes the critic network.

class arlbench.core.algorithms.sac.models.SACMLPActor(action_dim, activation, hidden_size=64, log_std_min=-20, log_std_max=2, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: Module

An MLP-based actor network for SAC.

__call__(x)[source]

Applies the actor to the input.

action_dim: int
activation: int
hidden_size: int = 64
log_std_max: float = 2
log_std_min: float = -20
name: Optional[str] = None
parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
scope: Optional[Scope] = None
setup()[source]

Initializes the actor network.
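Both actor classes accept log_std_min and log_std_max, but this page does not say how the bounds are applied. A minimal sketch, assuming the raw log-standard-deviation output is clamped to [log_std_min, log_std_max] before exponentiation, could look like this. bounded_std is a hypothetical helper name.

```python
import math

LOG_STD_MIN = -20.0  # default log_std_min from the signatures above
LOG_STD_MAX = 2.0    # default log_std_max from the signatures above


def bounded_std(log_std: float,
                log_std_min: float = LOG_STD_MIN,
                log_std_max: float = LOG_STD_MAX) -> float:
    """Clamp a raw log-std value (assumed behaviour) and convert it to a std."""
    clipped = min(max(log_std, log_std_min), log_std_max)
    return math.exp(clipped)
```

With the defaults, the resulting standard deviation lies between exp(-20) and exp(2), which keeps the policy from collapsing to a point or exploding.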

class arlbench.core.algorithms.sac.models.SACMLPCritic(action_dim, activation, hidden_size=64, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: Module

An MLP-based critic network for SAC.

__call__(x, action)[source]

Applies the critic to the input.

action_dim: int
activation: int
hidden_size: int = 64
name: Optional[str] = None
parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
scope: Optional[Scope] = None
setup()[source]

Initializes the critic network.

class arlbench.core.algorithms.sac.models.SACVectorCritic(critic, action_dim, activation, hidden_size=64, n_critics=2, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: Module

A vectorized critic network for SAC.

__call__(x, action)[source]

Applies the critic to the input.

action_dim: int
activation: int
critic: type[SACMLPCritic] | type[SACCNNCritic]
hidden_size: int = 64
n_critics: int = 2
name: Optional[str] = None
parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
scope: Optional[Scope] = None
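The n_critics field suggests an ensemble of critics evaluated on the same input. How the module batches its members is not stated on this page. The sketch below only illustrates the idea with plain Python callables; make_linear_critic and vector_critic are hypothetical names, and the minimum reduction is the usual clipped double-Q choice rather than something this page confirms.

```python
from typing import Callable, List, Sequence

Critic = Callable[[Sequence[float], Sequence[float]], float]


def make_linear_critic(weight: float) -> Critic:
    """Hypothetical toy critic: Q(x, a) = weight * (sum(x) + sum(a))."""
    def critic(x: Sequence[float], action: Sequence[float]) -> float:
        return weight * (sum(x) + sum(action))
    return critic


def vector_critic(critics: Sequence[Critic],
                  x: Sequence[float],
                  action: Sequence[float]) -> List[float]:
    """Evaluate every ensemble member on the same (x, action) pair."""
    return [q(x, action) for q in critics]


critics = [make_linear_critic(w) for w in (1.0, 0.5)]  # n_critics = 2
q_values = vector_critic(critics, x=[1.0, 2.0], action=[1.0])
pessimistic_q = min(q_values)  # take the most conservative estimate
```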
class arlbench.core.algorithms.sac.models.TanhTransformedDistribution(*args, **kwargs)[source]

Bases: Transformed

Tanh transformation of a distrax distribution.

mode()[source]

Returns the mode of the distribution.

Return type:

Array
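A tanh transformation squashes samples into (-1, 1) and changes the density by the Jacobian of tanh. The sketch below works through that change of variables for a one-dimensional Gaussian in plain Python. It is not the distrax implementation. The function names are hypothetical, the log-determinant is computed naively (it is not numerically safe for large |u|), and returning tanh(mean) for the mode is an assumption about how mode() might be defined.

```python
import math


def normal_log_prob(u: float, mean: float, std: float) -> float:
    """Log-density of a univariate Normal(mean, std) at u."""
    z = (u - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def tanh_normal_log_prob(u: float, mean: float, std: float) -> float:
    """Log-density of a = tanh(u) where u ~ Normal(mean, std).

    Change of variables: log p(a) = log p(u) - log|d tanh(u)/du|,
    with d tanh(u)/du = 1 - tanh(u)**2.
    """
    log_det = math.log(1.0 - math.tanh(u) ** 2)
    return normal_log_prob(u, mean, std) - log_det


def tanh_normal_mode(mean: float) -> float:
    """Squashed location tanh(mean); an assumed definition of the mode."""
    return math.tanh(mean)
```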

arlbench.core.algorithms.sac.sac module

SAC algorithm.

class arlbench.core.algorithms.sac.sac.SAC(hpo_config, env, eval_env=None, deterministic_eval=True, cnn_policy=False, nas_config=None, track_metrics=False, track_trajectories=False)[source]

Bases: Algorithm

JAX-based implementation of Soft-Actor-Critic (SAC).

static get_checkpoint_factory(runner_state, train_result)[source]

Creates a factory dictionary of all possible checkpointing options for SAC.

Parameters:
  • runner_state – Algorithm runner state.

  • train_result – Training result.

Returns:

Dictionary of factory functions containing:
  • actor_opt_state

  • critic_opt_state

  • alpha_opt_state

  • actor_network_params

  • critic_network_params

  • critic_target_params

  • alpha_network_params

  • actor_loss

  • critic_loss

  • alpha_loss

  • trajectories

Return type:

dict[str, Callable]
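A dictionary of factory functions lets a caller compute only the checkpoint entries it actually needs. The sketch below illustrates that pattern with plain Python. It assumes the callables take no arguments, which this page does not confirm; build_checkpoint and the placeholder values are hypothetical.

```python
from typing import Any, Callable, Dict, List


def build_checkpoint(factory: Dict[str, Callable[[], Any]],
                     requested: List[str]) -> Dict[str, Any]:
    """Call only the factory functions whose keys were requested."""
    return {key: factory[key]() for key in requested}


# Hypothetical factory using two of the keys listed above; values are placeholders.
factory = {
    "actor_loss": lambda: [0.3, 0.2],
    "critic_loss": lambda: [1.5, 1.1],
}
checkpoint = build_checkpoint(factory, ["actor_loss"])
```

Entries that are not requested are never evaluated, so expensive items such as trajectories cost nothing unless asked for.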

static get_default_hpo_config()[source]

Returns the default HPO configuration for SAC.

Return type:

Configuration

static get_default_nas_config()[source]

Returns the default NAS configuration for SAC.

Return type:

Configuration

static get_hpo_config_space(seed=None)[source]

Returns the hyperparameter configuration space for SAC.

Return type:

ConfigurationSpace

static get_nas_config_space(seed=None)[source]

Returns the neural architecture search (NAS) configuration space for SAC.

Return type:

ConfigurationSpace

init(rng, buffer_state=None, actor_network_params=None, critic_network_params=None, critic_target_params=None, alpha_network_params=None, actor_opt_state=None, critic_opt_state=None, alpha_opt_state=None)[source]

Initializes SAC state. Parameters that are passed in are not re-initialized and are included in the final state as given.

Parameters:
  • actor_network_params (FrozenDict | dict | None, optional) – Actor network parameters. Defaults to None.

  • critic_network_params (FrozenDict | dict | None, optional) – Critic network parameters. Defaults to None.

  • critic_target_params (FrozenDict | dict | None, optional) – Critic target network parameters. Defaults to None.

  • alpha_network_params (FrozenDict | dict | None, optional) – Alpha network parameters. Defaults to None.

  • actor_opt_state (optax.OptState | None, optional) – Actor optimizer state. Defaults to None.

  • critic_opt_state (optax.OptState | None, optional) – Critic optimizer state. Defaults to None.

  • alpha_opt_state (optax.OptState | None, optional) – Alpha optimizer state. Defaults to None.

Returns:

SAC state.

Return type:

SACState

name: str = 'sac'
predict(runner_state, obs, rng, deterministic=True)[source]

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (SACRunnerState) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None, optional) – Random number generator key, used when a non-deterministic action is sampled.

  • deterministic (bool) – Whether to return a deterministic action. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray
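The deterministic flag switches between a fixed action and a sampled one. The sketch below shows that distinction for a toy one-dimensional tanh-squashed Gaussian policy. It is an illustration under assumptions, not the library code: select_action is a hypothetical name, and Python's random.Random stands in for an explicit JAX PRNG key.

```python
import math
import random


def select_action(mean: float, std: float, rng: random.Random,
                  deterministic: bool = True) -> float:
    """Toy 1-D policy head: tanh-squashed Gaussian (assumed form)."""
    if deterministic:
        # Ignore the random generator and return the squashed location.
        return math.tanh(mean)
    # Draw a pre-squash sample, then squash it into [-1, 1].
    return math.tanh(rng.gauss(mean, std))


greedy = select_action(0.0, 1.0, random.Random(0))
sampled = select_action(0.5, 1.0, random.Random(0), deterministic=False)
```

Reusing the same seed reproduces the same sampled action, which mirrors how an explicit key makes stochastic prediction repeatable.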

train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]

Performs one full training.

Parameters:
  • runner_state (SACRunnerState) – SAC runner state.

  • buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.

  • n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

Tuple of SAC algorithm state and training result.

Return type:

SACTrainReturnT
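The relation between timesteps and update steps given in the parameter list is integer division, so any remainder is dropped. A short worked example, using a hypothetical n_envs of 8:

```python
def n_update_steps(n_total_timesteps: int, n_envs: int) -> int:
    """Update steps = n_total_timesteps // n_envs, as stated in the parameter list."""
    return n_total_timesteps // n_envs


steps = n_update_steps(1_000_000, 8)   # default timesteps, hypothetical n_envs
truncated = n_update_steps(1_000, 3)   # the remainder of 1 timestep is discarded
```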

update_actor(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]

Updates the actor network parameters.

Parameters:
  • actor_train_state (SACTrainState) – Actor train state.

  • critic_train_state (SACTrainState) – Critic train state.

  • alpha_train_state (SACTrainState) – Alpha train state.

  • experience (TimeStep) – Experience (batch of TimeSteps).

  • is_weights (jnp.ndarray) – Importance sampling weights for prioritized experience replay (PER).

  • rng (chex.PRNGKey) – Random number generator key.

Returns:

Updated training state and metrics.

Return type:

tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]

update_alpha(alpha_train_state, entropy)[source]

Updates the alpha network parameters.

Parameters:
  • alpha_train_state (SACTrainState) – Alpha training state.

  • entropy (jnp.ndarray) – Entropy values.

Returns:

Updated training state and metrics.

Return type:

tuple[SACTrainState, jnp.ndarray]
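This page does not give the loss that update_alpha minimises. As a hedged sketch, the standard SAC temperature objective is the batch mean of alpha * (entropy - target_entropy). The function names and the target_entropy argument are assumptions; the gradient shown is taken with respect to alpha itself.

```python
from typing import List


def alpha_loss(alpha: float, entropy: List[float], target_entropy: float) -> float:
    """Batch mean of alpha * (entropy - target_entropy)."""
    return sum(alpha * (h - target_entropy) for h in entropy) / len(entropy)


def alpha_grad(entropy: List[float], target_entropy: float) -> float:
    """d(loss)/d(alpha): positive when the policy is more random than the target."""
    return sum(h - target_entropy for h in entropy) / len(entropy)


loss = alpha_loss(0.5, [1.0, 3.0], target_entropy=1.0)
grad = alpha_grad([1.0, 3.0], target_entropy=1.0)
```

A positive gradient means gradient descent lowers alpha, which reduces the entropy bonus when the policy is already more random than the target.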

update_critic(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]

Updates the critic network parameters.

Parameters:
  • actor_train_state (SACTrainState) – Actor train state.

  • critic_train_state (SACTrainState) – Critic train state.

  • alpha_train_state (SACTrainState) – Alpha train state.

  • experience (Transition) – Experience (batch of transitions).

  • is_weights (jnp.ndarray) – Importance sampling weights for prioritized experience replay (PER).

  • rng (chex.PRNGKey) – Random number generator key.

Returns:

Updated training state and metrics.

Return type:

tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]
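The critic update takes the actor, the temperature, and per-sample weights, which matches the usual SAC critic objective. The sketch below writes out that textbook form in plain Python. It is not taken from the library: the function names and the gamma argument are assumptions, and the weighting scheme is the common importance-weighted squared error.

```python
from typing import List, Tuple


def soft_td_target(reward: float, done: float, next_q_values: List[float],
                   next_log_prob: float, alpha: float, gamma: float = 0.99) -> float:
    """y = r + gamma * (1 - done) * (min_i Q_i(s', a') - alpha * log pi(a'|s'))."""
    soft_value = min(next_q_values) - alpha * next_log_prob
    return reward + gamma * (1.0 - done) * soft_value


def weighted_critic_loss(q_values: List[float], targets: List[float],
                         is_weights: List[float]) -> Tuple[float, List[float]]:
    """Importance-weighted mean squared TD error for one critic."""
    td_errors = [q - y for q, y in zip(q_values, targets)]
    loss = sum(w * e * e for w, e in zip(is_weights, td_errors)) / len(td_errors)
    return loss, td_errors


target = soft_td_target(1.0, 0.0, [2.0, 3.0], next_log_prob=-1.0, alpha=0.5, gamma=0.5)
loss, td_errors = weighted_critic_loss([1.0, 2.0], [0.0, 2.0], [1.0, 1.0])
```

When done is 1, the bootstrap term vanishes and the target reduces to the reward.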

class arlbench.core.algorithms.sac.sac.SACMetrics(actor_loss: jnp.ndarray, critic_loss: jnp.ndarray, alpha_loss: jnp.ndarray, td_error: jnp.ndarray, actor_grads: FrozenDict, critic_grads: FrozenDict)[source]

Bases: NamedTuple

SAC metrics returned by train function. Consists of (actor_loss, critic_loss, alpha_loss, td_error, actor_grads, critic_grads).

actor_grads: FrozenDict

Alias for field number 4

actor_loss: Array

Alias for field number 0

alpha_loss: Array

Alias for field number 2

critic_grads: FrozenDict

Alias for field number 5

critic_loss: Array

Alias for field number 1

td_error: Array

Alias for field number 3
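The "Alias for field number" entries follow from NamedTuple semantics: each field is also reachable by its position. The sketch below uses a hypothetical stand-in class with the same field order as SACMetrics and placeholder values, so it runs without the library.

```python
from typing import Any, NamedTuple


class Metrics(NamedTuple):
    """Hypothetical stand-in with the same field order as SACMetrics."""
    actor_loss: Any
    critic_loss: Any
    alpha_loss: Any
    td_error: Any
    actor_grads: Any
    critic_grads: Any


m = Metrics(0.1, 0.2, 0.3, 0.4, {"w": 0.0}, {"w": 1.0})
actor_loss, critic_loss, *rest = m    # positional unpacking follows field order
updated = m._replace(alpha_loss=0.0)  # NamedTuples are immutable; this returns a copy
```

The same access patterns apply to the other NamedTuple classes documented on this page.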

class arlbench.core.algorithms.sac.sac.SACRunnerState(rng: chex.PRNGKey, actor_train_state: SACTrainState, critic_train_state: SACTrainState, alpha_train_state: SACTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: chex.Array, global_step: int)[source]

Bases: NamedTuple

SAC runner state. Consists of (rng, actor_train_state, critic_train_state, alpha_train_state, normalizer_state, env_state, obs, global_step).

actor_train_state: SACTrainState

Alias for field number 1

alpha_train_state: SACTrainState

Alias for field number 3

critic_train_state: SACTrainState

Alias for field number 2

env_state: Any

Alias for field number 5

global_step: int

Alias for field number 7

normalizer_state: RunningStatisticsState

Alias for field number 4

obs: Union[Array, ndarray, bool_, number]

Alias for field number 6

rng: Array

Alias for field number 0

class arlbench.core.algorithms.sac.sac.SACState(runner_state: SACRunnerState, buffer_state: PrioritisedTrajectoryBufferState)[source]

Bases: NamedTuple

SAC algorithm state. Consists of (runner_state, buffer_state).

buffer_state: PrioritisedTrajectoryBufferState

Alias for field number 1

runner_state: SACRunnerState

Alias for field number 0

class arlbench.core.algorithms.sac.sac.SACTrainState(step, apply_fn, params, tx, opt_state, target_params=None)[source]

Bases: TrainState

SAC training state.

classmethod create_with_opt_state(*, apply_fn, params, target_params, tx, opt_state, **kwargs)[source]

Instantiates with optimizer state.

network_state = None
replace(**updates)

Returns a new object replacing the specified fields with new values.

target_params: None | chex.Array | dict = None
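The target_params field holds a second copy of the parameters, which SAC implementations typically move slowly towards the online parameters. This page does not describe the update rule. A minimal sketch, assuming Polyak averaging with a hypothetical coefficient tau over a flat dictionary of scalars:

```python
from typing import Dict


def soft_update(params: Dict[str, float], target_params: Dict[str, float],
                tau: float) -> Dict[str, float]:
    """Return tau * params + (1 - tau) * target_params for every entry."""
    return {k: tau * params[k] + (1.0 - tau) * target_params[k] for k in params}


new_target = soft_update({"w": 1.0}, {"w": 0.0}, tau=0.25)
hard_copy = soft_update({"w": 1.0}, {"w": 0.0}, tau=1.0)
```

Setting tau to 1 copies the online parameters outright, while small values make the target lag behind and stabilise the critic's bootstrap targets.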
class arlbench.core.algorithms.sac.sac.SACTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: SACMetrics | None)[source]

Bases: NamedTuple

SAC training result. Consists of (eval_rewards, trajectories, metrics).

eval_rewards: Array

Alias for field number 0

metrics: SACMetrics | None

Alias for field number 2

trajectories: Transition | None

Alias for field number 1

class arlbench.core.algorithms.sac.sac.Transition(done: jnp.ndarray, action: jnp.ndarray, value: jnp.ndarray, reward: jnp.ndarray, obs: jnp.ndarray, info: jnp.ndarray)[source]

Bases: NamedTuple

SAC Transition. Consists of (done, action, value, reward, obs, info).

action: Array

Alias for field number 1

done: Array

Alias for field number 0

info: Array

Alias for field number 5

obs: Array

Alias for field number 4

reward: Array

Alias for field number 3

value: Array

Alias for field number 2

Module contents

class arlbench.core.algorithms.sac.SAC(hpo_config, env, eval_env=None, deterministic_eval=True, cnn_policy=False, nas_config=None, track_metrics=False, track_trajectories=False)[source]

Bases: Algorithm

JAX-based implementation of Soft-Actor-Critic (SAC).

static get_checkpoint_factory(runner_state, train_result)[source]

Creates a factory dictionary of all possible checkpointing options for SAC.

Parameters:
  • runner_state – Algorithm runner state.

  • train_result – Training result.

Returns:

Dictionary of factory functions containing:
  • actor_opt_state

  • critic_opt_state

  • alpha_opt_state

  • actor_network_params

  • critic_network_params

  • critic_target_params

  • alpha_network_params

  • actor_loss

  • critic_loss

  • alpha_loss

  • trajectories

Return type:

dict[str, Callable]

static get_default_hpo_config()[source]

Returns the default HPO configuration for SAC.

Return type:

Configuration

static get_default_nas_config()[source]

Returns the default NAS configuration for SAC.

Return type:

Configuration

static get_hpo_config_space(seed=None)[source]

Returns the hyperparameter configuration space for SAC.

Return type:

ConfigurationSpace

static get_nas_config_space(seed=None)[source]

Returns the neural architecture search (NAS) configuration space for SAC.

Return type:

ConfigurationSpace

init(rng, buffer_state=None, actor_network_params=None, critic_network_params=None, critic_target_params=None, alpha_network_params=None, actor_opt_state=None, critic_opt_state=None, alpha_opt_state=None)[source]

Initializes SAC state. Parameters that are passed in are not re-initialized and are included in the final state as given.

Parameters:
  • actor_network_params (FrozenDict | dict | None, optional) – Actor network parameters. Defaults to None.

  • critic_network_params (FrozenDict | dict | None, optional) – Critic network parameters. Defaults to None.

  • critic_target_params (FrozenDict | dict | None, optional) – Critic target network parameters. Defaults to None.

  • alpha_network_params (FrozenDict | dict | None, optional) – Alpha network parameters. Defaults to None.

  • actor_opt_state (optax.OptState | None, optional) – Actor optimizer state. Defaults to None.

  • critic_opt_state (optax.OptState | None, optional) – Critic optimizer state. Defaults to None.

  • alpha_opt_state (optax.OptState | None, optional) – Alpha optimizer state. Defaults to None.

Returns:

SAC state.

Return type:

SACState

name: str = 'sac'
predict(runner_state, obs, rng, deterministic=True)[source]

Predict action(s) based on the current observation(s).

Parameters:
  • runner_state (SACRunnerState) – Algorithm runner state.

  • obs (jnp.ndarray) – Observation(s).

  • rng (chex.PRNGKey | None, optional) – Random number generator key, used when a non-deterministic action is sampled.

  • deterministic (bool) – Whether to return a deterministic action. Defaults to True.

Returns:

Action(s).

Return type:

jnp.ndarray

train(runner_state, buffer_state, n_total_timesteps=1000000, n_eval_steps=100, n_eval_episodes=10)[source]

Performs one full training.

Parameters:
  • runner_state (SACRunnerState) – SAC runner state.

  • buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.

  • n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.

  • n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.

  • n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.

Returns:

Tuple of SAC algorithm state and training result.

Return type:

SACTrainReturnT

update_actor(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]

Updates the actor network parameters.

Parameters:
  • actor_train_state (SACTrainState) – Actor train state.

  • critic_train_state (SACTrainState) – Critic train state.

  • alpha_train_state (SACTrainState) – Alpha train state.

  • experience (TimeStep) – Experience (batch of TimeSteps).

  • is_weights (jnp.ndarray) – Importance sampling weights for prioritized experience replay (PER).

  • rng (chex.PRNGKey) – Random number generator key.

Returns:

Updated training state and metrics.

Return type:

tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]

update_alpha(alpha_train_state, entropy)[source]

Updates the alpha network parameters.

Parameters:
  • alpha_train_state (SACTrainState) – Alpha training state.

  • entropy (jnp.ndarray) – Entropy values.

Returns:

Updated training state and metrics.

Return type:

tuple[SACTrainState, jnp.ndarray]

update_critic(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)[source]

Updates the critic network parameters.

Parameters:
  • actor_train_state (SACTrainState) – Actor train state.

  • critic_train_state (SACTrainState) – Critic train state.

  • alpha_train_state (SACTrainState) – Alpha train state.

  • experience (Transition) – Experience (batch of transitions).

  • is_weights (jnp.ndarray) – Importance sampling weights for prioritized experience replay (PER).

  • rng (chex.PRNGKey) – Random number generator key.

Returns:

Updated training state and metrics.

Return type:

tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]

class arlbench.core.algorithms.sac.SACMetrics(actor_loss: jnp.ndarray, critic_loss: jnp.ndarray, alpha_loss: jnp.ndarray, td_error: jnp.ndarray, actor_grads: FrozenDict, critic_grads: FrozenDict)[source]

Bases: NamedTuple

SAC metrics returned by train function. Consists of (actor_loss, critic_loss, alpha_loss, td_error, actor_grads, critic_grads).

actor_grads: FrozenDict

Alias for field number 4

actor_loss: Array

Alias for field number 0

alpha_loss: Array

Alias for field number 2

critic_grads: FrozenDict

Alias for field number 5

critic_loss: Array

Alias for field number 1

td_error: Array

Alias for field number 3

class arlbench.core.algorithms.sac.SACRunnerState(rng: chex.PRNGKey, actor_train_state: SACTrainState, critic_train_state: SACTrainState, alpha_train_state: SACTrainState, normalizer_state: RunningStatisticsState, env_state: Any, obs: chex.Array, global_step: int)[source]

Bases: NamedTuple

SAC runner state. Consists of (rng, actor_train_state, critic_train_state, alpha_train_state, normalizer_state, env_state, obs, global_step).

actor_train_state: SACTrainState

Alias for field number 1

alpha_train_state: SACTrainState

Alias for field number 3

critic_train_state: SACTrainState

Alias for field number 2

env_state: Any

Alias for field number 5

global_step: int

Alias for field number 7

normalizer_state: RunningStatisticsState

Alias for field number 4

obs: Union[Array, ndarray, bool_, number]

Alias for field number 6

rng: Array

Alias for field number 0

class arlbench.core.algorithms.sac.SACState(runner_state: SACRunnerState, buffer_state: PrioritisedTrajectoryBufferState)[source]

Bases: NamedTuple

SAC algorithm state. Consists of (runner_state, buffer_state).

buffer_state: PrioritisedTrajectoryBufferState

Alias for field number 1

runner_state: SACRunnerState

Alias for field number 0

arlbench.core.algorithms.sac.SACTrainReturnT

alias of tuple[SACState, SACTrainingResult]

class arlbench.core.algorithms.sac.SACTrainingResult(eval_rewards: jnp.ndarray, trajectories: Transition | None, metrics: SACMetrics | None)[source]

Bases: NamedTuple

SAC training result. Consists of (eval_rewards, trajectories, metrics).

eval_rewards: Array

Alias for field number 0

metrics: SACMetrics | None

Alias for field number 2

trajectories: Transition | None

Alias for field number 1