arlbench.core.algorithms.ppo.models

Models for the PPO algorithm.

Classes

CNNActorCritic(action_dim[, activation, ...])

A CNN-based Actor-Critic network for PPO.

MLPActorCritic(action_dim[, activation, ...])

An MLP-based Actor-Critic network for PPO.

class arlbench.core.algorithms.ppo.models.CNNActorCritic(action_dim, activation='tanh', hidden_size=512, discrete=True, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A CNN-based Actor-Critic network for PPO. Its architecture is based on the NatureCNN feature extractor from Stable-Baselines3 (https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/torch_layers.py#L48).
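For reference, the Stable-Baselines3 NatureCNN linked above stacks three unpadded convolutions (32 filters of 8×8 with stride 4, 64 filters of 4×4 with stride 2, 64 filters of 3×3 with stride 1) before a 512-unit linear layer, which lines up with the default hidden_size=512. The arithmetic below traces an 84×84 frame through those layers. It assumes CNNActorCritic keeps SB3's no-padding convention, which this page does not state; the helper names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    """Spatial output size of a convolution with no padding (VALID)."""
    return (size - kernel) // stride + 1


def nature_cnn_flat_features(height: int = 84, width: int = 84) -> int:
    """Number of features after the three NatureCNN conv layers, flattened."""
    # (filters, kernel, stride) per conv layer, as in SB3's NatureCNN.
    layers = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
    h, w, c = height, width, 0
    for filters, kernel, stride in layers:
        h = conv_out(h, kernel, stride)
        w = conv_out(w, kernel, stride)
        c = filters
    return h * w * c


print(nature_cnn_flat_features())  # 84 -> 20 -> 9 -> 7, so 7 * 7 * 64 = 3136
```

With "SAME" padding (the default of Flax's nn.Conv) the flattened size would differ, so it is worth checking which padding the model actually uses when sizing observations.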

__call__(x)

Applies the actor-critic to the input.

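setup()

Defines the layers of the actor-critic network. Flax calls setup() lazily when the module is bound; parameter values are created by Module.init().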

class arlbench.core.algorithms.ppo.models.MLPActorCritic(action_dim, activation='tanh', hidden_size=64, discrete=True, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

An MLP-based Actor-Critic network for PPO.

__call__(x)

Applies the actor-critic to the input.
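Assuming the discrete policy output is a batch of categorical logits (the exact return type of __call__ is not documented on this page), PPO typically samples actions and stores their log-probabilities for the probability ratio in its clipped objective. A minimal JAX sketch; the helper name sample_discrete_action is hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_discrete_action(key, logits):
    """Sample one action per row of logits and return it with its log-probability."""
    actions = jax.random.categorical(key, logits, axis=-1)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Select the log-probability of each sampled action.
    action_log_probs = jnp.take_along_axis(log_probs, actions[..., None], axis=-1)[..., 0]
    return actions, action_log_probs


# Uniform logits over 3 actions for a batch of 4 observations.
actions, log_probs = sample_discrete_action(jax.random.PRNGKey(0), jnp.zeros((4, 3)))
```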

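setup()

Defines the layers of the actor-critic network. Flax calls setup() lazily when the module is bound; parameter values are created by Module.init().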