Stochastic policy

mighty.mighty_exploration.stochastic_policy #

StochasticPolicy #

StochasticPolicy(
    algo,
    model,
    entropy_coefficient: float = 0.2,
    discrete: bool = True,
)

Bases: MightyExplorationPolicy

Entropy-Based Exploration for discrete and continuous action spaces.

:param algo: the RL algorithm instance
:param model: the policy model
:param entropy_coefficient: weight on entropy term
:param discrete: whether the action space is discrete

Source code in mighty/mighty_exploration/stochastic_policy.py
def __init__(
    self, algo, model, entropy_coefficient: float = 0.2, discrete: bool = True
):
    """
    :param algo: the RL algorithm instance
    :param model: the policy model
    :param entropy_coefficient: weight on entropy term
    :param discrete: whether the action space is discrete
    """
    super().__init__(algo, model, discrete)
    self.entropy_coefficient = entropy_coefficient
    self.discrete = discrete

    # --- override sample_action only for continuous SAC ---
    if not discrete and isinstance(model, SACModel):
        # for evaluation use deterministic=True; training will go through .explore()
        def _sac_sample(state_np):
            state = torch.as_tensor(state_np, dtype=torch.float32)
            # forward returns (action, z, mean, log_std)
            action, z, mean, log_std = model(state, deterministic=True)
            logp = model.policy_log_prob(z, mean, log_std)
            return action, logp

        self.sample_action = _sac_sample
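
A minimal construction sketch for the discrete case. The toy network and the algo placeholder are illustrative assumptions, not part of the library:

import torch

from mighty.mighty_exploration.stochastic_policy import StochasticPolicy

# Toy policy network: 4-dim observations -> logits over 2 discrete actions.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 2),
)

policy = StochasticPolicy(
    algo="ppo",  # placeholder; the base class expects the RL algorithm instance
    model=model,
    entropy_coefficient=0.2,
    discrete=True,
)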

__call__ #

__call__(
    s, return_logp=False, metrics=None, evaluate=False
)

Get action.

:param s: state
:param return_logp: whether to return logprobs
:param metrics: current metric dict
:param evaluate: whether to run in evaluation mode
:return: action or (action, logprobs)

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def __call__(self, s, return_logp=False, metrics=None, evaluate=False):
    """Get action.

    :param s: state
    :param return_logp: whether to return logprobs
    :param metrics: current metric dict
    :param evaluate: whether to run in evaluation mode
    :return: action or (action, logprobs)
    """
    if metrics is None:
        metrics = {}
    if evaluate:
        action, logprobs = self.sample_action(s)
        action = action.detach().cpu().numpy()
        output = (action, logprobs) if return_logp else action
    else:
        output = self.explore(s, return_logp, metrics)

    return output
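
A usage sketch, assuming the policy constructed above: in training mode the call is routed through explore and yields (action, weighted_log_prob); with evaluate=True it samples via sample_action and returns a detached numpy action.

import numpy as np

obs = np.zeros((1, 4), dtype=np.float32)  # batch of one 4-dim observation

# Training path: delegates to explore(), which returns the pair.
action, weighted_logp = policy(obs, return_logp=True)

# Evaluation path: samples via sample_action(); numpy action only.
eval_action = policy(obs, evaluate=True)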

explore #

explore(
    s, return_logp=False, metrics=None
) -> Tuple[ndarray, Tensor]

Given observations s, sample an exploratory action and compute a weighted log-prob.

RETURNS:
    action: numpy array of actions
    weighted_log_prob: Tensor of shape [batch, 1]

TYPE: Tuple[ndarray, Tensor]

Source code in mighty/mighty_exploration/stochastic_policy.py
def explore(self, s, return_logp=False, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
    """
    Given observations `s`, sample an exploratory action and compute a weighted log-prob.

    Returns:
      action: numpy array of actions
      weighted_log_prob: Tensor of shape [batch, 1]
    """
    state = torch.as_tensor(s, dtype=torch.float32)
    if self.discrete:
        logits = self.model(state)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        log_prob = dist.log_prob(action).unsqueeze(-1)
    else:
        # Model returns: action (tanh-squashed), z (pre-tanh), mean, log_std
        action, z, mean, log_std = self.model(state)
        std = torch.exp(log_std)
        dist = torch.distributions.Normal(mean, std)
        # log_prob and entropy over pre-tanh sample
        log_prob = dist.log_prob(z).sum(dim=-1, keepdim=True)
    # Weighted by entropy coefficient
    weighted_log_prob = log_prob * self.entropy_coefficient
    return action.detach().cpu().numpy(), weighted_log_prob
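
For intuition, here is the discrete branch in isolation, a standalone sketch of what explore computes before returning (the logits are made up):

import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])       # batch of 1, 3 actions
dist = torch.distributions.Categorical(logits=logits)
action = dist.sample()                          # index of the sampled action
log_prob = dist.log_prob(action).unsqueeze(-1)  # shape [1, 1]
weighted_log_prob = log_prob * 0.2              # entropy_coefficient = 0.2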

explore_func #

explore_func(s)

Explore function; subclasses must implement this.

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def explore_func(self, s):
    """Explore function; subclasses must implement this."""
    raise NotImplementedError
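
The base class leaves explore_func abstract. A hypothetical subclass sketch follows; the class name, the noise scale, and the assumption that sample_action returns an (action, log_prob) pair (as set up in __init__ above) are illustrative, not library API:

import numpy as np

from mighty.mighty_exploration.mighty_exploration_policy import (
    MightyExplorationPolicy,
)


class UniformNoisePolicy(MightyExplorationPolicy):
    """Hypothetical: perturbs the sampled action with uniform noise."""

    def explore_func(self, s):
        action, log_prob = self.sample_action(s)
        noisy = action.detach().cpu().numpy()
        noisy = noisy + np.random.uniform(-0.1, 0.1, size=noisy.shape)
        return noisy, log_prob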

forward #

forward(s)

Alias for explore, so policy(s) returns (action, weighted_log_prob).

Source code in mighty/mighty_exploration/stochastic_policy.py
def forward(self, s):
    """
    Alias for explore, so policy(s) returns (action, weighted_log_prob).
    """
    return self.explore(s)