Stochastic policy

mighty.mighty_exploration.stochastic_policy #

Stochastic Policy for Entropy-Based Exploration.

StochasticPolicy #

StochasticPolicy(
    algo,
    model,
    entropy_coefficient: float = 0.2,
    discrete: bool = True,
)

Bases: MightyExplorationPolicy

Entropy-Based Exploration for discrete and continuous action spaces.

:param entropy_coefficient: weight on entropy term
:param discrete: whether the action space is discrete

Source code in mighty/mighty_exploration/stochastic_policy.py
def __init__(
    self, algo, model, entropy_coefficient: float = 0.2, discrete: bool = True
):
    """
    :param algo: the RL algorithm instance
    :param model: the policy model
    :param entropy_coefficient: weight on entropy term
    :param discrete: whether the action space is discrete
    """

    self.model = model

    super().__init__(algo, model, discrete)
    self.entropy_coefficient = entropy_coefficient
    self.discrete = discrete

    # --- override sample_action only for continuous SAC ---
    if not discrete and isinstance(model, SACModel):
        # for evaluation use deterministic=True; training will go through .explore()
        def _sac_sample(state_np):
            state = torch.as_tensor(state_np, dtype=torch.float32)
            # forward returns (action, z, mean, log_std)
            action, z, mean, log_std = model(state, deterministic=True)
            logp = model.policy_log_prob(z, mean, log_std)
            return action.detach().cpu().numpy(), logp

        self.sample_action = _sac_sample
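
Example usage (a minimal sketch; the `policy_net` model and the `"ppo"` algorithm handle are illustrative assumptions, not part of this module):

from mighty.mighty_exploration.stochastic_policy import StochasticPolicy

policy = StochasticPolicy(
    algo="ppo",               # algorithm handle; "sac" switches explore() to SAC-specific log-probs
    model=policy_net,         # hypothetical policy network (logits for discrete, tuples for continuous)
    entropy_coefficient=0.2,  # weight applied to log-probabilities during exploration
    discrete=True,            # set to False for continuous action spaces
)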

__call__ #

__call__(
    s, return_logp=False, metrics=None, evaluate=False
)

Get action.

:param s: state
:param return_logp: return logprobs
:param metrics: current metric dict
:param evaluate: eval mode
:return: action or (action, logprobs)

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def __call__(self, s, return_logp=False, metrics=None, evaluate=False):
    """Get action.

    :param s: state
    :param return_logp: return logprobs
    :param metrics: current metric dict
    :param evaluate: eval mode
    :return: action or (action, logprobs)
    """
    if metrics is None:
        metrics = {}
    if evaluate:
        action, logprobs = self.sample_action(s)
        output = (action, logprobs) if return_logp else action
    else:
        output = self.explore(s, return_logp, metrics)

    return output
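
A minimal usage sketch of the two call paths, assuming the `policy` instance and a batched observation `obs` from above:

# Evaluation: greedy/deterministic sampling via sample_action
action = policy(obs, evaluate=True)

# Training: routes through explore(), which returns (action, weighted_log_prob)
action, weighted_log_prob = policy(obs)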

explore #

explore(
    s, return_logp, metrics=None
) -> Tuple[ndarray, Tensor]

Given observations s, sample an exploratory action and compute a weighted log-prob.

Returns (Tuple[ndarray, Tensor]):

action: numpy array of actions
weighted_log_prob: Tensor of shape [batch, 1]

Source code in mighty/mighty_exploration/stochastic_policy.py
def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
    """
    Given observations `s`, sample an exploratory action and compute a weighted log-prob.

    Returns:
      action: numpy array of actions
      weighted_log_prob: Tensor of shape [batch, 1]
    """
    state = torch.as_tensor(s, dtype=torch.float32)

    if self.discrete:
        logits = self.model(state)
        dist = Categorical(logits=logits)
        action = dist.sample()
        log_prob = dist.log_prob(action).unsqueeze(-1)
        return action.detach().cpu().numpy(), log_prob * self.entropy_coefficient

    else:
        # Get model output
        model_output = self.model(state)

        # Handle different model output formats

        # NEW: 3-tuple case (Standard PPO): (action, mean, log_std)
        if isinstance(model_output, tuple) and len(model_output) == 3:
            action, mean, log_std = model_output
            std = torch.exp(log_std)
            dist = Normal(mean, std)

            # Direct log prob (no tanh correction)
            log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)

            if return_logp:
                return action.detach().cpu().numpy(), log_prob
            else:
                weighted_log_prob = log_prob * self.entropy_coefficient
                return action.detach().cpu().numpy(), weighted_log_prob

        # 4-tuple case (Tanh squashing): (action, z, mean, log_std)
        elif isinstance(model_output, tuple) and len(model_output) == 4:
            action, z, mean, log_std = model_output

            if not self.algo == "sac":

                log_prob = sample_nondeterministic_logprobs(
                    z=z,
                    mean=mean,
                    log_std=log_std,
                    sac=False,
                )
            else:
                log_prob = self.model.policy_log_prob(z, mean, log_std)

            if return_logp:
                return action.detach().cpu().numpy(), log_prob
            else:
                weighted_log_prob = log_prob * self.entropy_coefficient
                return action.detach().cpu().numpy(), weighted_log_prob

        # Check for model attribute-based approaches
        elif getattr(self.model, "continuous_action", False):
            # This handles the case where model has continuous_action attribute
            # but we need to determine the output format dynamically
            if len(model_output) == 3:
                # Standard PPO mode: (action, mean, log_std)
                action, mean, log_std = model_output
                std = torch.exp(log_std)
                dist = Normal(mean, std)
                log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)
            elif len(model_output) == 4:
                # Tanh squashing mode: (action, z, mean, log_std)
                action, z, mean, log_std = model_output
                if self.algo != "sac":
                    log_prob = sample_nondeterministic_logprobs(
                        z=z,
                        mean=mean,
                        log_std=log_std,
                        sac=False,
                    )
                else:
                    log_prob = self.model.policy_log_prob(z, mean, log_std)
            else:
                raise ValueError(
                    f"Unexpected model output length: {len(model_output)}"
                )

            if return_logp:
                return action.detach().cpu().numpy(), log_prob
            else:
                weighted_log_prob = log_prob * self.entropy_coefficient
                return action.detach().cpu().numpy(), weighted_log_prob

        # Check for output_style attribute (backwards compatibility)
        elif hasattr(self.model, "output_style"):
            if self.model.output_style == "squashed_gaussian":
                # Should be 4-tuple: (action, z, mean, log_std)
                action, z, mean, log_std = model_output
                if self.algo != "sac":
                    log_prob = sample_nondeterministic_logprobs(
                        z=z,
                        mean=mean,
                        log_std=log_std,
                        sac=False,
                    )
                else:
                    log_prob = self.model.policy_log_prob(z, mean, log_std)

                if return_logp:
                    return action.detach().cpu().numpy(), log_prob
                else:
                    weighted_log_prob = log_prob * self.entropy_coefficient
                    return action.detach().cpu().numpy(), weighted_log_prob

            elif self.model.output_style == "mean_std":
                # Should be 2-tuple: (mean, std)
                mean, std = model_output
                log_std = torch.log(std)  # recover log_std for the log-prob helpers below
                dist = Normal(mean, std)
                z = dist.rsample()
                action = torch.tanh(z)

                if self.algo != "sac":
                    log_prob = sample_nondeterministic_logprobs(
                        z=z,
                        mean=mean,
                        log_std=log_std,
                        sac=False,
                    )
                else:
                    log_prob = self.model.policy_log_prob(z, mean, log_std)

                entropy = dist.entropy().sum(dim=-1, keepdim=True)
                weighted_log_prob = log_prob * entropy
                return action.detach().cpu().numpy(), weighted_log_prob

            else:
                raise RuntimeError(
                    f"StochasticPolicy: unknown output_style '{self.model.output_style}'"
                )

        # Special handling for SACModel
        elif self.algo == "sac" and isinstance(self.model, SACModel):
            action, z, mean, log_std = self.model(state, deterministic=False)
            # CRITICAL: Use the model's policy_log_prob which includes tanh correction
            log_prob = self.model.policy_log_prob(z, mean, log_std)
            return action.detach().cpu().numpy(), log_prob

        else:
            raise RuntimeError(
                "StochasticPolicy: cannot interpret model(state) output of type "
                f"{type(model_output)} with length {len(model_output) if isinstance(model_output, tuple) else 'N/A'}"
            )
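
The branches above differ mainly in how the log-probability is computed. Below is a standalone sketch of the two computations: a plain Gaussian log-prob for the 3-tuple case, and a tanh-squashed Gaussian with a change-of-variables correction for the 4-tuple case (the module itself delegates the latter to sample_nondeterministic_logprobs or the model's policy_log_prob):

import torch
from torch.distributions import Normal

mean = torch.zeros(1, 2)
log_std = torch.zeros(1, 2)
dist = Normal(mean, log_std.exp())

# 3-tuple / standard case: the action is the Gaussian sample itself
action = dist.rsample()
log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)

# 4-tuple / squashed-Gaussian case: action = tanh(z), so the Gaussian log-prob
# is corrected by -sum(log(1 - tanh(z)^2 + eps))
z = dist.rsample()
squashed = torch.tanh(z)
log_prob_squashed = dist.log_prob(z).sum(dim=-1, keepdim=True) - torch.log(
    1.0 - squashed.pow(2) + 1e-6
).sum(dim=-1, keepdim=True)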

explore_func #

explore_func(s)

Explore function.

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def explore_func(self, s):
    """Explore function."""
    raise NotImplementedError

forward #

forward(s)

Alias for explore, so policy(s) returns (action, weighted_log_prob).

Source code in mighty/mighty_exploration/stochastic_policy.py
def forward(self, s):
    """
    Alias for explore, so policy(s) returns (action, weighted_log_prob).
    """
    return self.explore(s, return_logp=False)

sample_func_logits #

sample_func_logits(state_array)

state_array: np.ndarray of shape [batch, obs_dim]
Returns: (action_np, log_prob_tensor)

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def sample_func_logits(self, state_array):
    """
    state_array: np.ndarray of shape [batch, obs_dim]
    Returns: (action_np, log_prob_tensor)
    """
    state = torch.as_tensor(state_array, dtype=torch.float32)

    # ─── Discrete action branch ─────────────────────────────────────────
    if self.discrete:
        logits = self.model(state)  # [batch, n_actions]
        dist = Categorical(logits=logits)
        action = dist.sample()  # [batch]
        log_prob = dist.log_prob(action)  # [batch]
        return action.detach().cpu().numpy(), log_prob

    # ─── Continuous action branches ─────────────────────────────────────
    out = self.model(state)

    # NEW: Handle 3-tuple (Standard PPO)
    if isinstance(out, tuple) and len(out) == 3:
        action, mean, log_std = out
        std = torch.exp(log_std)
        dist = Normal(mean, std)
        log_prob = dist.log_prob(action).sum(dim=-1)  # Direct log prob
        return action.detach().cpu().numpy(), log_prob

    # ─── Continuous squashed‐Gaussian (4‐tuple) ──────────────────────────
    elif isinstance(out, tuple) and len(out) == 4:
        action = out[0]  # [batch, action_dim]
        log_prob = sample_nondeterministic_logprobs(
            z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac"
        )
        return action.detach().cpu().numpy(), log_prob

    # ─── Legacy continuous branch (model returns (mean, std)) ────────────
    elif isinstance(out, tuple) and len(out) == 2:
        mean, std = out  # both [batch, action_dim]
        dist = Normal(mean, std)
        z = dist.rsample()  # [batch, action_dim]
        action = torch.tanh(z)  # [batch, action_dim]

        # 3a) log_pz = ∑ᵢ log N(zᵢ; μᵢ, σᵢ)
        log_pz = dist.log_prob(z).sum(dim=-1)  # [batch]

        # 3b) tanh‐correction
        eps = 1e-6
        log_correction = torch.log(1.0 - action.pow(2) + eps).sum(dim=-1)  # [batch]

        log_prob = log_pz - log_correction  # [batch]
        return action.detach().cpu().numpy(), log_prob

    # ─── Fallback: if model(state) returns a Distribution ────────────────
    elif isinstance(out, torch.distributions.Distribution):
        dist = out  # user returned a Distribution
        action = dist.sample()  # [batch]
        log_prob = dist.log_prob(action)  # [batch]
        return action.detach().cpu().numpy(), log_prob

    # ─── Otherwise, we don't know how to sample ─────────────────────────
    else:
        raise RuntimeError(
            "MightyExplorationPolicy: cannot interpret model(state) output of type "
            f"{type(out)}"
        )
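
A minimal usage sketch, assuming the `policy` instance from above and a hypothetical observation batch:

import numpy as np

obs = np.random.randn(8, 4).astype(np.float32)  # hypothetical [batch, obs_dim] batch

# Returns a NumPy action array and a log-prob tensor of shape [batch]
action_np, log_prob = policy.sample_func_logits(obs)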

sample_func_q #

sample_func_q(state_array)

Q-learning branch:

• state_array: np.ndarray of shape [batch, obs_dim]
• model(state) returns Q-values: tensor [batch, n_actions]

We choose action = argmax(Q), and also return the full Q‐vector.

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def sample_func_q(self, state_array):
    """
    Q-learning branch:
      • state_array: np.ndarray of shape [batch, obs_dim]
      • model(state) returns Q-values: tensor [batch, n_actions]
    We choose action = argmax(Q), and also return the full Q‐vector.
    """
    state = torch.as_tensor(state_array, dtype=torch.float32)
    qs = self.model(state)  # [batch, n_actions]
    # Choose greedy action
    action = torch.argmax(qs, dim=1)  # [batch]
    return action.detach().cpu().numpy(), qs  # action_np, Q‐vector
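
A standalone illustration of the greedy selection performed above (hypothetical Q-values):

import torch

qs = torch.tensor([[0.1, 0.7, 0.2],
                   [0.9, 0.0, 0.4]])  # [batch=2, n_actions=3]
greedy = torch.argmax(qs, dim=1)      # tensor([1, 0])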

seed #

seed(seed: int) -> None

Set the random seed for reproducibility.

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def seed(self, seed: int) -> None:
    """Set the random seed for reproducibility."""
    self.rng = np.random.default_rng(seed)