Stochastic policy

mighty.mighty_exploration.stochastic_policy #

StochasticPolicy #

StochasticPolicy(
    algo, model, entropy_coefficient=0.2, discrete=True
)

Bases: MightyExplorationPolicy

Entropy Based Exploration.

:param algo: algorithm name
:param model: policy model
:param entropy_coefficient: entropy coefficient
:param discrete: whether the action space is discrete

Source code in mighty/mighty_exploration/stochastic_policy.py
def __init__(self, algo, model, entropy_coefficient=0.2, discrete=True):
    """Initialize Entropy Based Exploration.

    :param algo: algorithm name
    :param model: policy model
    :param entropy_coefficient: entropy coefficient
    :param discrete: whether the action space is discrete
    :return:
    """
    super().__init__(algo, model, discrete)
    self.entropy_coefficient = entropy_coefficient

    # FIXME: I did this already for the other exploration functions, but this would be nicer as a separate function
    def explore_func(s):
        state = torch.FloatTensor(s)  # convert the observation(s) to a float tensor
        if discrete:
            logits = self.model(state)
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            entropy = dist.entropy()
            weighted_log_prob = log_prob * entropy
        else:
            mean, std = self.model(state)
            dist = torch.distributions.Normal(mean, std)
            action = dist.sample()
            log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)
            entropy = dist.entropy().sum(dim=-1, keepdim=True)
            weighted_log_prob = log_prob * entropy

        return action.detach().numpy(), weighted_log_prob.detach().numpy()

    self.explore_func = explore_func
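
The discrete branch of explore_func can be reproduced in isolation. The sketch below is illustrative only; the logits tensor stands in for the output of self.model(state) and is not part of mighty:

import torch

# Stand-in for self.model(state): raw logits over four discrete actions (assumption).
logits = torch.tensor([[0.5, 1.2, -0.3, 0.1]])

dist = torch.distributions.Categorical(logits=logits)
action = dist.sample()                  # sampled action index
log_prob = dist.log_prob(action)        # log-probability of the sampled action
entropy = dist.entropy()                # entropy of the action distribution
weighted_log_prob = log_prob * entropy  # entropy-weighted log-prob, as in explore_func

print(action.item(), weighted_log_prob.item())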

__call__ #

__call__(
    state, return_logp=False, metrics=None, evaluate=False
)

Get action.

:param state: state
:param return_logp: return logprobs
:param metrics: current metric dict
:param evaluate: evaluation mode
:return: action or (action, logprobs)

Source code in mighty/mighty_exploration/stochastic_policy.py
def __call__(self, state, return_logp=False, metrics=None, evaluate=False):
    """Get action.

    :param state: state
    :param return_logp: return logprobs
    :param metrics: current metric dict
    :param evaluate: evaluation mode
    :return: action or (action, logprobs)
    """
    if metrics is None:
        metrics = {}
    if evaluate:
        action, logprobs = self.sample_action(state)
        action = action.detach().numpy()
        output = (action, logprobs.detach().numpy()) if return_logp else action
    else:
        output = self.explore(state, return_logp, metrics)

    return output
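
A minimal usage sketch. The two-layer policy network, observation size, and algorithm name below are assumptions for illustration, not part of mighty:

import torch
from mighty.mighty_exploration.stochastic_policy import StochasticPolicy

# Hypothetical policy network mapping an 8-dimensional observation to logits over 4 actions.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4)
)

policy = StochasticPolicy(algo="ppo", model=model, entropy_coefficient=0.2, discrete=True)

state = torch.zeros(1, 8).numpy()      # dummy observation with a batch dimension
action = policy(state)                 # exploration mode: samples via explore_func
greedy = policy(state, evaluate=True)  # evaluation mode: uses the base class' sample_action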

explore #

explore(s, return_logp, _)

Explore.

:param s: state
:param return_logp: return logprobs
:param _: not used
:return: action or (action, logprobs)

Source code in mighty/mighty_exploration/mighty_exploration_policy.py
def explore(self, s, return_logp, _):
    """Explore.

    :param s: state
    :param return_logp: return logprobs
    :param _: not used
    :return: action or (action, logprobs)
    """
    action, logprobs = self.explore_func(s)
    return (action, logprobs) if return_logp else action
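
When the entropy-weighted log-probabilities are needed alongside the action (continuing the hypothetical setup above):

action, weighted_logp = policy(state, return_logp=True)

In exploration mode, weighted_logp is the log-probability of the sampled action scaled by the entropy of the action distribution, as computed in explore_func.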