Skip to content

Sac

mighty.mighty_models.sac #

SACModel #

SACModel(
    obs_size,
    action_size,
    hidden_sizes=[64, 64],
    activation="relu",
)

Bases: Module

SAC model with a policy network, two Q-networks, and a value network.

Source code in mighty/mighty_models/sac.py
def __init__(self, obs_size, action_size, hidden_sizes=None, activation="relu"):
    """Build the policy network, the two Q-networks and the value network.

    Every network is an MLP feature extractor obtained from
    ``make_feature_extractor`` followed by a single ``nn.Linear`` output layer.

    Args:
        obs_size: Observation size, forwarded as ``obs_shape`` to the feature
            extractor of the policy and value networks.
        action_size: Action size; it is added to ``obs_size`` to form the
            input size of both Q-networks.
        hidden_sizes: Hidden layer widths of every feature extractor. Must be
            non-empty, because its last entry is used as the input width of
            each output layer. ``None`` (the default) means ``[64, 64]``.
        activation: Activation identifier forwarded to
            ``make_feature_extractor``.
    """
    super().__init__()

    # Resolve the default here rather than in the signature, so that no list
    # object is shared between separate constructor calls.
    if hidden_sizes is None:
        hidden_sizes = [64, 64]

    self.obs_size = obs_size
    self.action_size = action_size

    def build_network(input_size, output_size):
        # Only the first element returned by make_feature_extractor is used
        # as the feature-extracting module; a linear head is stacked on top.
        extractor = make_feature_extractor(
            architecture="mlp",
            obs_shape=input_size,
            n_layers=len(hidden_sizes),
            hidden_sizes=hidden_sizes,
            activation=activation,
        )[0]
        head = nn.Linear(hidden_sizes[-1], output_size)
        return nn.Sequential(extractor, head)

    # Policy network operating on observations.
    # NOTE(review): the output width is fixed at 2 and does not depend on
    # action_size — confirm against the code that consumes policy_net.
    self.policy_net = build_network(obs_size, 2)

    # Two Q-networks; each takes an input of size obs_size + action_size and
    # produces a single output.
    q_input_size = obs_size + action_size
    self.q_net1 = build_network(q_input_size, 1)
    self.q_net2 = build_network(q_input_size, 1)

    # Value network operating on observations, with a single output.
    self.value_net = build_network(obs_size, 1)