def update(self, batch: RolloutBatch) -> Dict[str, float]:
"""Update the PPO model."""
states = batch.observations.squeeze(0)
actions = batch.actions.squeeze(0)
old_log_probs = batch.log_probs.squeeze(0)
advantages = batch.advantages.squeeze(0)
returns = batch.returns.squeeze(0)
    # Value loss: MSE between the critic's predicted state values and the empirical returns
    values = self.model.forward_value(states).squeeze()
    value_loss = torch.nn.functional.mse_loss(values, returns)
    # Normalizing the advantages is not meaningful for a mini-batch of size 1
    # (the std is undefined there), see GH issue #325
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    # Evaluate the current policy: new log-probabilities and entropy of the sampled actions
if self.model.continuous_action:
means, stds = self.model(states)
dist = torch.distributions.Normal(means, stds)
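        # The joint log-prob of a diagonal Gaussian is the sum of the
        # per-dimension log-probs, hence the sum over the action dimension.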
log_probs = dist.log_prob(actions).sum(dim=-1, keepdim=True)
entropy = dist.entropy().sum(dim=-1, keepdim=True)
    else:
        logits = self.model(states)
        dist = torch.distributions.Categorical(logits=logits)  # type: ignore
        # Categorical log-probs and entropies are already one scalar per sample
        # (assuming `actions` holds integer action indices), so only a trailing
        # dim is added to match old_log_probs; summing over dim=-1 here would
        # collapse the batch dimension.
        log_probs = dist.log_prob(actions).unsqueeze(-1)
        entropy = dist.entropy().unsqueeze(-1)
ratios = torch.exp(log_probs - old_log_probs)
    # Drop a stray trailing dimension so the ratios broadcast cleanly with the advantages
    if ratios.dim() > 2:
        ratios = ratios.squeeze(-1)
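    # PPO clipped surrogate objective (Schulman et al., 2017):
    #   L_clip = E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]
    # the policy loss below is its negation, since the optimizer minimizes.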
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1.0 - self.epsilon, 1.0 + self.epsilon) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
# Total loss
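    # The entropy term enters with a negative sign so that minimizing the total
    # loss maximizes entropy, encouraging exploration.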
loss = policy_loss + self.vf_coef * value_loss - self.ent_coef * entropy.mean()
    # Backpropagate the combined loss and step the optimizer; gradients of the
    # policy network are clipped to max_grad_norm beforehand.
self.policy_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model.policy_net.parameters(), self.max_grad_norm
)
self.policy_optimizer.step()
return {
"Update/policy_loss": policy_loss.item(),
"Update/value_loss": value_loss.item(),
"Update/entropy": entropy.mean().item(),
}
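    # Usage sketch (names `agent`, `env`, and `collect_rollout` are illustrative,
    # not defined in this module):
    #   batch = collect_rollout(env, agent)    # gather a RolloutBatch
    #   metrics = agent.update(batch)          # one PPO update step
    #   print(metrics["Update/policy_loss"])   # scalar metrics for logging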