Decaying epsilon greedy

mighty.mighty_exploration.decaying_epsilon_greedy #

Decaying Epsilon‐Greedy Exploration.

DecayingEpsilonGreedy #

DecayingEpsilonGreedy(
    algo,
    model,
    epsilon_start: float = 1.0,
    epsilon_final: float = 0.01,
    epsilon_decay_steps: int = 10000,
)

Bases: EpsilonGreedy

Epsilon-Greedy Exploration with linear decay schedule.

:param epsilon_start: Initial ε (at time step 0)
:param epsilon_final: Final ε (after epsilon_decay_steps)
:param epsilon_decay_steps: Number of steps over which to linearly decay ε from epsilon_start → epsilon_final.

Source code in mighty/mighty_exploration/decaying_epsilon_greedy.py
def __init__(
    self,
    algo,
    model,
    epsilon_start: float = 1.0,
    epsilon_final: float = 0.01,
    epsilon_decay_steps: int = 10000,
):
    """
    :param algo:       algorithm name
    :param model:      policy model (e.g. Q-network)
    :param epsilon_start: Initial ε (at time step 0)
    :param epsilon_final: Final ε (after epsilon_decay_steps)
    :param epsilon_decay_steps: Number of steps over which to linearly
                                 decay ε from epsilon_start → epsilon_final.
    """
    super().__init__(algo=algo, model=model, epsilon=epsilon_start)
    self.epsilon_start = epsilon_start
    self.epsilon_final = epsilon_final
    self.epsilon_decay_steps = epsilon_decay_steps
    self.total_steps = 0
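
A minimal usage sketch, assuming a small Q-network and a placeholder value for algo (the exact string your Mighty configuration expects may differ):

import torch.nn as nn

from mighty.mighty_exploration.decaying_epsilon_greedy import DecayingEpsilonGreedy

# Hypothetical Q-network: 4-dimensional observations, 2 discrete actions.
q_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))

policy = DecayingEpsilonGreedy(
    algo="q",  # placeholder algorithm name
    model=q_net,
    epsilon_start=1.0,
    epsilon_final=0.01,
    epsilon_decay_steps=50_000,
)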

__call__ #

__call__(
    s, return_logp=False, metrics=None, evaluate=False
)

Get action.

:param s: state
:param return_logp: return logprobs
:param metrics: current metric dict
:param evaluate: evaluation mode
:return: action or (action, logprobs)

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def __call__(self, s, return_logp=False, metrics=None, evaluate=False):
    """Get action.

    :param s: state
    :param return_logp: return logprobs
    :param metrics: current metric dict
    :param evaluate: evaluation mode
    :return: action or (action, logprobs)
    """
    if metrics is None:
        metrics = {}
    if evaluate:
        action, logprobs = self.sample_action(s)
        output = (action, logprobs) if return_logp else action
    else:
        output = self.explore(s, return_logp, metrics)

    return output
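
For illustration, calling the policy on a batch of observations (continuing the hypothetical policy object from the sketch above; shapes are placeholders):

import numpy as np

states = np.random.randn(8, 4).astype(np.float32)  # hypothetical batch of 8 observations

# Training mode: routes through explore() and applies epsilon-greedy exploration.
actions = policy(states)

# Evaluation mode: samples greedily from the model; optionally returns logprobs/Q-values too.
actions, logprobs = policy(states, return_logp=True, evaluate=True)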

explore #

explore(s, return_logp, metrics=None)

Explore.

:param s: state
:param return_logp: return logprobs
:param metrics: not used
:return: action or (action, logprobs)

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def explore(self, s, return_logp, metrics=None):
    """Explore.

    :param s: state
    :param return_logp: return logprobs
    :param metrics: not used
    :return: action or (action, logprobs)
    """
    action, logprobs = self.explore_func(s)
    return (action, logprobs) if return_logp else action

explore_func #

explore_func(s)

Same as EpsilonGreedy, except uses decayed ε each time.

Source code in mighty/mighty_exploration/decaying_epsilon_greedy.py
def explore_func(self, s):
    """Same as EpsilonGreedy, except uses decayed ε each time."""
    greedy_actions, qvals = self.sample_action(s)
    exploration_flags, random_actions = self.get_random_actions(
        len(greedy_actions), len(qvals[0])
    )
    actions = np.where(exploration_flags, random_actions, greedy_actions)
    return actions.astype(int), qvals
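
The final selection step mixes the two action sources element-wise; a toy illustration of the np.where call (values made up):

import numpy as np

greedy_actions = np.array([2, 0, 1, 3])
random_actions = np.array([1, 1, 0, 2])
exploration_flags = np.array([False, True, False, True])  # True -> take the random action

actions = np.where(exploration_flags, random_actions, greedy_actions)
print(actions)  # [2 1 1 2]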

get_random_actions #

get_random_actions(n_actions, action_length)

Override to recompute ε at each call, then delegate to EpsilonGreedy's logic.

Source code in mighty/mighty_exploration/decaying_epsilon_greedy.py
def get_random_actions(self, n_actions, action_length):
    """
    Override to recompute ε at each call, then delegate to EpsilonGreedy's logic.
    """
    # 1) Update ε based on total_steps
    current_epsilon = self._compute_epsilon()
    self.epsilon = current_epsilon

    # 2) Call parent method to build exploration flags & random actions
    exploration_flags, random_actions = super().get_random_actions(
        n_actions, action_length
    )

    # 3) Advance the step counter (so subsequent calls see a smaller ε)
    self.total_steps += n_actions

    return exploration_flags, random_actions
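
_compute_epsilon itself is not shown on this page. Given the constructor parameters, a plausible linear schedule interpolates from epsilon_start to epsilon_final and clamps once epsilon_decay_steps is reached; a minimal sketch under that assumption:

def _compute_epsilon(self) -> float:
    # Assumed linear interpolation, clamped at epsilon_final after epsilon_decay_steps.
    fraction = min(self.total_steps / self.epsilon_decay_steps, 1.0)
    return self.epsilon_start + fraction * (self.epsilon_final - self.epsilon_start)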

sample_func_logits #

sample_func_logits(state_array)

state_array: np.ndarray of shape [batch, obs_dim]
Returns: (action_array, log_prob_tensor)

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def sample_func_logits(self, state_array):
    """
    state_array: np.ndarray of shape [batch, obs_dim]
    Returns: (action_array, log_prob_tensor)
    """
    state = torch.as_tensor(state_array, dtype=torch.float32)

    # ─── Discrete action branch ─────────────────────────────────────────
    if self.discrete:
        logits = self.model(state)  # [batch, n_actions]
        dist = Categorical(logits=logits)
        action = dist.sample()  # [batch]
        log_prob = dist.log_prob(action)  # [batch]
        return action.detach().cpu().numpy(), log_prob

    # ─── Continuous action branches ─────────────────────────────────────
    out = self.model(state)

    # NEW: Handle 3-tuple (Standard PPO)
    if isinstance(out, tuple) and len(out) == 3:
        action, mean, log_std = out
        std = torch.exp(log_std)
        dist = Normal(mean, std)
        log_prob = dist.log_prob(action).sum(dim=-1)  # Direct log prob
        return action.detach().cpu().numpy(), log_prob

    # ─── Continuous squashed‐Gaussian (4‐tuple) ──────────────────────────
    elif isinstance(out, tuple) and len(out) == 4:
        action = out[0]  # [batch, action_dim]
        log_prob = sample_nondeterministic_logprobs(
            z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac"
        )
        return action.detach().cpu().numpy(), log_prob

    # ─── Legacy continuous branch (model returns (mean, std)) ────────────
    elif isinstance(out, tuple) and len(out) == 2:
        mean, std = out  # both [batch, action_dim]
        dist = Normal(mean, std)
        z = dist.rsample()  # [batch, action_dim]
        action = torch.tanh(z)  # [batch, action_dim]

        # 3a) log_pz = ∑ᵢ log N(zᵢ; μᵢ, σᵢ)
        log_pz = dist.log_prob(z).sum(dim=-1)  # [batch]

        # 3b) tanh‐correction
        eps = 1e-6
        log_correction = torch.log(1.0 - action.pow(2) + eps).sum(dim=-1)  # [batch]

        log_prob = log_pz - log_correction  # [batch]
        return action.detach().cpu().numpy(), log_prob

    # ─── Fallback: if model(state) returns a Distribution ────────────────
    elif isinstance(out, torch.distributions.Distribution):
        dist = out  # user returned a Distribution
        action = dist.sample()  # [batch]
        log_prob = dist.log_prob(action)  # [batch]
        return action.detach().cpu().numpy(), log_prob

    # ─── Otherwise, we don't know how to sample ─────────────────────────
    else:
        raise RuntimeError(
            "MightyExplorationPolicy: cannot interpret model(state) output of type "
            f"{type(out)}"
        )
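
For reference, the discrete branch above is equivalent to this standalone snippet (network and shapes are hypothetical):

import numpy as np
import torch
from torch.distributions import Categorical

# Hypothetical logits network: 4-dimensional observations, 3 discrete actions.
model = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 3))

state = torch.as_tensor(np.random.randn(5, 4), dtype=torch.float32)
logits = model(state)             # [5, 3]
dist = Categorical(logits=logits)
action = dist.sample()            # [5]
log_prob = dist.log_prob(action)  # [5]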

sample_func_q #

sample_func_q(state_array)

Q-learning branch:

• state_array: np.ndarray of shape [batch, obs_dim]
• model(state) returns Q-values: tensor [batch, n_actions]

We choose action = argmax(Q), and also return the full Q‐vector.

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def sample_func_q(self, state_array):
    """
    Q-learning branch:
      • state_array: np.ndarray of shape [batch, obs_dim]
      • model(state) returns Q-values: tensor [batch, n_actions]
    We choose action = argmax(Q), and also return the full Q‐vector.
    """
    state = torch.as_tensor(state_array, dtype=torch.float32)
    qs = self.model(state)  # [batch, n_actions]
    # Choose greedy action
    action = torch.argmax(qs, dim=1)  # [batch]
    return action.detach().cpu().numpy(), qs  # action_np, Q‐vector
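
The greedy choice reduces to a row-wise argmax over the Q-values; a toy example:

import torch

qs = torch.tensor([[0.1, 0.9, 0.3, 0.2],
                   [1.5, 0.0, 0.7, 0.4],
                   [0.2, 0.2, 0.8, 0.1]])
action = torch.argmax(qs, dim=1)  # tensor([1, 0, 2])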

seed #

seed(seed: int) -> None

Set the random seed for reproducibility.

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def seed(self, seed: int) -> None:
    """Set the random seed for reproducibility."""
    self.rng = np.random.default_rng(seed)