def update(self, batch: MaxiBatch) -> Dict[str, float]:
"""
PPO update continuous branch now evaluates π_new on the *old* latent
actions, which is required for a correct importance-sampling ratio.
Assumes each minibatch has a `latents` Tensor containing the pre-tanh
actions (`z_old`). If you do not store that in your buffer yet, either:
• add it during rollout (recommended), or
• reconstruct it on-the-fly via atanh(squashed_action.clamp(-0.999,0.999)).
"""
# ─────────────────── cache old values & log-probs ───────────────────
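    # Values and log-probs are frozen before any gradient step; they serve as fixed
    # references for value clipping, the importance ratio, and the KL estimate
    # across all epochs of this update.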
with torch.no_grad():
old_values = [
self.model.forward_value(mb.observations) for mb in batch.minibatches
]
old_log_probs = [mb.log_probs.clone() for mb in batch.minibatches]
# ───────────────────── advantage normalisation ──────────────────────
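    # Mean/std are computed once over the whole batch so every minibatch is
    # normalised on the same scale; 1e-8 guards against division by zero.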
flat_adv = batch.advantages.view(-1)
adv_mean, adv_std = flat_adv.mean(), flat_adv.std() + 1e-8
    metrics = {"policy_loss": 0.0, "value_loss": 0.0, "entropy": 0.0}
    mb_updates = 0
# ───────────────────────── main PPO loop ────────────────────────────
for epoch in range(self.n_epochs):
epoch_kls = []
for i, mb in enumerate(batch.minibatches):
adv = ((mb.advantages - adv_mean) / adv_std).detach()
# ---- value loss ---------------------------------------------------
values = self.model.forward_value(mb.observations)
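            # PPO-style value clipping: the new value estimate is kept within
            # ±value_clip_eps of the cached old value, and the loss takes the
            # element-wise max of the clipped and unclipped squared errors.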
if self.use_value_clip:
clipped = old_values[i] + (values - old_values[i]).clamp(
-self.value_clip_eps, self.value_clip_eps
)
v_loss = (
0.5
* torch.max(
(mb.returns - values).pow(2), (mb.returns - clipped).pow(2)
).mean()
)
else:
v_loss = 0.5 * (mb.returns - values).pow(2).mean()
# ---- policy loss (continuous & discrete share the same surr) ----
if self.model.continuous_action:
# Get model output
model_output = self.model(mb.observations)
                # Handle both policy heads: tanh-squashed Gaussian and standard Gaussian
if hasattr(self.model, "tanh_squash") and self.model.tanh_squash:
                    # Tanh-squashed Gaussian: evaluate π_new on the stored pre-tanh latents
_, _, mean, log_std = model_output # 4-tuple
dist = torch.distributions.Normal(mean, log_std.exp())
z_old = mb.latents # stored pre-tanh
log_pz = dist.log_prob(z_old).sum(-1)
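                    # Change-of-variables correction for the tanh squashing:
                    # log π(a) = log N(z) - Σ log(1 - tanh(z)^2); the 1e-6 avoids log(0).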
log_corr = torch.log(1 - torch.tanh(z_old).pow(2) + 1e-6).sum(
-1
)
log_probs = log_pz - log_corr
else:
                    # Standard Gaussian policy: log-prob evaluated directly on the stored actions
_, mean, log_std = model_output # 3-tuple
dist = torch.distributions.Normal(mean, log_std.exp())
# Direct log prob on actions (no latents needed)
log_probs = dist.log_prob(mb.actions).sum(-1)
                # NOTE: entropy of the pre-squash Gaussian; for the tanh-squashed head
                # this is a common approximation of the squashed policy's entropy.
                entropy = dist.entropy().sum(-1).mean()
else:
logits = self.model(mb.observations)
dist = torch.distributions.Categorical(logits=logits)
log_probs = dist.log_prob(mb.actions)
entropy = dist.entropy().mean()
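            # Clipped surrogate objective: ratio = exp(log π_new - log π_old),
            # then take the pessimistic minimum of the unclipped and clipped terms.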
ratios = torch.exp(log_probs - old_log_probs[i])
surr1 = ratios * adv
surr2 = torch.clamp(ratios, 1 - self.epsilon, 1 + self.epsilon) * adv
p_loss = -torch.min(surr1, surr2).mean()
# ---- combined loss & optimisation --------------------------------
loss = p_loss + self.vf_coef * v_loss - self.ent_coef * entropy
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.max_grad_norm
)
self.optimizer.step()
# ---- KL divergence on *same* z_old --------------------------------
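            # Re-evaluate the updated policy on the same latents/actions and use the
            # simple estimator E[log π_old - log π_new] as an approximation of
            # KL(π_old ‖ π_new); it drives adaptive LR and early stopping below.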
with torch.no_grad():
if self.model.continuous_action:
model_output_new = self.model(mb.observations)
if (
hasattr(self.model, "tanh_squash")
and self.model.tanh_squash
):
                        # Tanh-squashed Gaussian head
_, _, mean_new, log_std_new = model_output_new
dist_new = torch.distributions.Normal(
mean_new, log_std_new.exp()
)
log_pz_new = dist_new.log_prob(z_old).sum(-1)
log_corr_n = torch.log(
1 - torch.tanh(z_old).pow(2) + 1e-6
).sum(-1)
new_lp = log_pz_new - log_corr_n
else:
                        # Standard Gaussian head
_, mean_new, log_std_new = model_output_new
dist_new = torch.distributions.Normal(
mean_new, log_std_new.exp()
)
new_lp = dist_new.log_prob(mb.actions).sum(-1)
else:
logits_new = self.model(mb.observations)
new_lp = torch.distributions.Categorical(
logits=logits_new
).log_prob(mb.actions)
kl = (old_log_probs[i] - new_lp).mean()
epoch_kls.append(kl)
# ---- bookkeeping --------------------------------------------------
metrics["policy_loss"] += p_loss.item()
metrics["value_loss"] += v_loss.item()
metrics["entropy"] += entropy.item()
mb_updates += 1
# ───────── epoch end: LR adaptation, early stop, logging ─────────
if len(epoch_kls) > 0:
mean_kl = torch.stack(epoch_kls).mean()
else:
# If no minibatches were processed, set a default KL value
mean_kl = torch.tensor(0.0)
print("Warning: No minibatches processed in this epoch")
# adaptive LR
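        # If KL overshoots 1.5× the target, shrink the LR by 0.8 (floored at min_lr);
        # if it undershoots 0.5× the target on the first epoch, grow it by 1.1,
        # capped at the group's initial LR.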
if self.adaptive_lr and self.kl_target:
for g in self.optimizer.param_groups[:2]: # policy & value groups
if mean_kl > 1.5 * self.kl_target:
g["lr"] = max(g["lr"] * 0.8, self.min_lr)
elif mean_kl < 0.5 * self.kl_target and epoch == 0:
g["lr"] = min(
g["lr"] * 1.1,
(
self.initial_policy_lr
if g is self.optimizer.param_groups[0]
else self.initial_value_lr
),
)
# early-stop if KL already large
if mean_kl > self.kl_target:
break
        # Step the scheduler after the adaptive-LR block (it will not drive LR below 0.1× the initial value)
self.scheduler.step()
# clip param-group LR to min_lr
for g in self.optimizer.param_groups:
if g["lr"] < self.min_lr:
g["lr"] = self.min_lr
# final averaged metrics
for k in metrics:
        metrics[k] /= max(mb_updates, 1)
metrics["approx_kl"] = mean_kl.item() # final KL of the run
return {
"Update/policy_loss": metrics["policy_loss"],
"Update/value_loss": metrics["value_loss"],
"Update/entropy": metrics["entropy"],
"Update/approx_kl": metrics["approx_kl"],
}
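

# Hedged rollout-side sketch (an illustration, not the project's rollout API):
# shows how the pre-tanh latents assumed by `update` could be produced and recorded
# for a tanh-squashed Gaussian head, using the same change-of-variables correction
# as above. Assumes `torch` is imported at module level, as elsewhere in this file.
def sample_squashed_action(mean: torch.Tensor, log_std: torch.Tensor):
    """Sample z ~ N(mean, std), squash with tanh, and return (action, latent, log_prob)."""
    dist = torch.distributions.Normal(mean, log_std.exp())
    z = dist.rsample()        # pre-tanh latent -> store as `latents` in the buffer
    action = torch.tanh(z)    # squashed action in (-1, 1) -> store as `actions`
    # log π(a) = log N(z) - Σ log(1 - tanh(z)^2), matching the correction in `update`
    log_prob = dist.log_prob(z).sum(-1) - torch.log(1 - action.pow(2) + 1e-6).sum(-1)
    return action, z, log_prob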