

mighty.mighty_models.sac #

SACModel #

SACModel(
    obs_size: int,
    action_size: int,
    hidden_sizes: list[int] = [64, 64],
    activation: str = "relu",
    log_std_min: float = -20,
    log_std_max: float = 2,
)

Bases: Module

SAC Model with squashed Gaussian policy and twin Q-networks.

Source code in mighty/mighty_models/sac.py
def __init__(
    self,
    obs_size: int,
    action_size: int,
    hidden_sizes: list[int] = [64, 64],
    activation: str = "relu",
    log_std_min: float = -20,
    log_std_max: float = 2,
):
    super().__init__()
    self.obs_size = obs_size
    self.action_size = action_size
    self.log_std_min = log_std_min
    self.log_std_max = log_std_max
    self.hidden_sizes = hidden_sizes
    self.activation = activation

    # Shared feature extractor for policy and Q-networks
    extractor, out_dim = make_feature_extractor(
        architecture="mlp",
        obs_shape=obs_size,
        n_layers=len(hidden_sizes),
        hidden_sizes=hidden_sizes,
        activation=activation,
    )

    # Policy network outputs mean and log_std
    self.policy_net = nn.Sequential(
        extractor,
        nn.Linear(out_dim, action_size * 2),
    )

    # Twin Q-networks
    # — live Q-nets —
    self.q_net1 = self._make_q_net()
    self.q_net2 = self._make_q_net()

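    # Target Q-nets: frozen copies of the live nets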
    self.target_q_net1 = self._make_q_net()
    self.target_q_net1.load_state_dict(self.q_net1.state_dict())
    self.target_q_net2 = self._make_q_net()
    self.target_q_net2.load_state_dict(self.q_net2.state_dict())
    for p in self.target_q_net1.parameters():
        p.requires_grad = False
    for p in self.target_q_net2.parameters():
        p.requires_grad = False
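
A minimal usage sketch (not from the mighty sources; the 3-dimensional observation and 1-dimensional action are assumptions for illustration):

import torch
from mighty.mighty_models.sac import SACModel

model = SACModel(obs_size=3, action_size=1)

obs = torch.randn(32, 3)  # batch of 32 observations

# Stochastic sample for exploration; deterministic=True returns tanh(mean)
action, z, mean, log_std = model(obs)
eval_action, _, _, _ = model(obs, deterministic=True)

assert action.shape == (32, 1)
assert action.abs().max().item() <= 1.0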

forward #

forward(
    state: Tensor, deterministic: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor]

Forward pass for policy sampling.

RETURNS DESCRIPTION

action: torch.Tensor in [-1, 1]
z: raw Gaussian sample before tanh
mean: Gaussian mean
log_std: Gaussian log std

TYPE: Tuple[Tensor, Tensor, Tensor, Tensor]

Source code in mighty/mighty_models/sac.py
def forward(
    self,
    state: torch.Tensor,
    deterministic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Forward pass for policy sampling.

    Returns:
      action: torch.Tensor in [-1,1]
      z: raw Gaussian sample before tanh
      mean: Gaussian mean
      log_std: Gaussian log std
    """
    x = self.policy_net(state)
    mean, log_std = x.chunk(2, dim=-1)
    log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
    std = torch.exp(log_std)

    if deterministic:
        z = mean
    else:
        z = mean + std * torch.randn_like(mean)
    action = torch.tanh(z)
    return action, z, mean, log_std
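
The stochastic branch uses the reparameterization trick (mean + std * noise), so gradients flow back through mean and log_std into policy_net, and the raw sample z is returned so that policy_log_prob (below) can be evaluated without inverting the tanh. A small check, reusing model and obs from the sketch above:

action, z, mean, log_std = model(obs)
assert torch.allclose(action, torch.tanh(z))

# With deterministic=True the action is simply the squashed mean
det_action, _, _, _ = model(obs, deterministic=True)
assert torch.allclose(det_action, torch.tanh(mean))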

policy_log_prob #

policy_log_prob(
    z: Tensor, mean: Tensor, log_std: Tensor
) -> Tensor

Compute log-prob of action a = tanh(z), correcting for tanh transform.

Source code in mighty/mighty_models/sac.py
def policy_log_prob(
    self, z: torch.Tensor, mean: torch.Tensor, log_std: torch.Tensor
) -> torch.Tensor:
    """
    Compute log-prob of action a = tanh(z), correcting for tanh transform.
    """
    std = torch.exp(log_std)
    dist = torch.distributions.Normal(mean, std)
    log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True)
    # Tanh correction: log|d tanh(z)/dz| = log(1 - tanh(z)^2),
    # evaluated in the numerically stable form 2 * (log1p(tanh|z|) - |z|)
    log_pa = log_pz - (2 * (torch.tanh(z).abs().log1p() - z.abs())).sum(
        dim=-1, keepdim=True
    )
    return log_pa
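
In the SAC objectives this log-probability supplies the entropy term (for example alpha * log_pa in the actor and temperature losses). A minimal sketch of combining it with forward, again reusing model and obs from above; how mighty's trainer consumes it is not shown on this page:

action, z, mean, log_std = model(obs)
log_pa = model.policy_log_prob(z, mean, log_std)  # log pi(a|s), shape (32, 1)

# Monte-Carlo estimate of the policy entropy: H(pi(.|s)) ~ -E[log pi(a|s)]
entropy_estimate = -log_pa.mean()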