arlbench.core.algorithms.sac.models

SAC models: the actor and critic networks, the entropy coefficient (alpha), and the tanh-squashed action distribution.

Classes

AlphaCoef([alpha_init, parent, name])

Alpha coefficient for SAC.

SACCNNActor(action_dim, activation[, ...])

A CNN-based actor network for SAC.

SACCNNCritic(action_dim, activation[, ...])

A CNN-based critic network for SAC.

SACMLPActor(action_dim, activation[, ...])

An MLP-based actor network for SAC.

SACMLPCritic(action_dim, activation[, ...])

An MLP-based critic network for SAC.

SACVectorCritic(critic, action_dim, activation)

A vectorized critic network for SAC.

TanhTransformedDistribution(*args, **kwargs)

Tanh transformation of a distrax distribution.

class arlbench.core.algorithms.sac.models.AlphaCoef(alpha_init=1.0, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Alpha coefficient for SAC.

__call__()

Returns the alpha coefficient.

Return type:

Array

setup()

Initializes the alpha coefficient.

class arlbench.core.algorithms.sac.models.SACCNNActor(action_dim, activation, hidden_size=64, log_std_min=-20, log_std_max=2, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A CNN-based actor network for SAC. Based on NatureCNN https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/torch_layers.py#L48.

__call__(x)

Applies the actor to the input.

setup()

Initializes the actor network.

class arlbench.core.algorithms.sac.models.SACCNNCritic(action_dim, activation, hidden_size=512, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A CNN-based critic network for SAC. Based on NatureCNN https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/torch_layers.py#L48.

__call__(x, action)

Applies the critic to the input.

setup()

Initializes the critic network.

class arlbench.core.algorithms.sac.models.SACMLPActor(action_dim, activation, hidden_size=64, log_std_min=-20, log_std_max=2, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

An MLP-based actor network for SAC.

__call__(x)

Applies the actor to the input.

setup()

Initializes the actor network.

class arlbench.core.algorithms.sac.models.SACMLPCritic(action_dim, activation, hidden_size=64, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

An MLP-based critic network for SAC.

__call__(x, action)

Applies the critic to the input.

setup()

Initializes the critic network.

class arlbench.core.algorithms.sac.models.SACVectorCritic(critic, action_dim, activation, hidden_size=64, n_critics=2, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A vectorized critic network for SAC.

__call__(x, action)

Applies the critic to the input.

class arlbench.core.algorithms.sac.models.TanhTransformedDistribution(*args, **kwargs)

Bases: Transformed

Tanh transformation of a distrax distribution.

mode()

Returns the mode of the distribution.

Return type:

Array
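`TanhTransformedDistribution` subclasses distrax's `Transformed`, so its log-density follows the change-of-variables formula for a = tanh(u): log pi(a) = log N(u; mu, sigma) - sum_i log(1 - tanh(u_i)^2). The plain-JAX sketch below computes that quantity without distrax. It uses the identity log(1 - tanh(u)^2) = 2 (log 2 - u - softplus(-2u)) to stay finite when tanh saturates, and takes tanh of the Gaussian mean as the mode. The helper names are hypothetical, and treating tanh(mean) as the result of `mode()` is an assumption rather than something this page states.

```python
import math

import jax
import jax.numpy as jnp


def tanh_gaussian_log_prob(u, mean, log_std):
    """log pi(a) for a = tanh(u), u ~ N(mean, exp(log_std)^2), summed over action dims."""
    gaussian = (
        -0.5 * ((u - mean) / jnp.exp(log_std)) ** 2
        - log_std
        - 0.5 * math.log(2.0 * math.pi)
    )
    # log|d tanh(u)/du| = log(1 - tanh(u)^2), in a form that stays finite for large |u|.
    log_det_jacobian = 2.0 * (math.log(2.0) - u - jax.nn.softplus(-2.0 * u))
    return jnp.sum(gaussian - log_det_jacobian, axis=-1)


def tanh_gaussian_mode(mean):
    """Deterministic action: squash the Gaussian mean (assumed convention for mode())."""
    return jnp.tanh(mean)
```

In SAC the pre-squash sample `u` is usually kept from the sampling step, so the log-probability can be computed from `u` directly instead of recovering it with `atanh`, which is unstable near -1 and 1.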