def __init__(
self,
observations: np.ndarray,
actions: np.ndarray,
*,
latents: np.ndarray | None = None, # ← NEW (optional for discrete)
rewards: np.ndarray,
advantages: np.ndarray,
returns: np.ndarray,
episode_starts: np.ndarray,
log_probs: np.ndarray,
values: np.ndarray,
device: torch.device | str = "cpu",
):
self.device = device
obs_t = torch.from_numpy(observations.astype(np.float32))
act_t = torch.from_numpy(actions.astype(np.float32))
lat_t = (
torch.from_numpy(latents.astype(np.float32))
if latents is not None
else None
) # may stay None for discrete
rew_t = torch.from_numpy(rewards.astype(np.float32))
adv_t = torch.from_numpy(advantages.astype(np.float32))
ret_t = torch.from_numpy(returns.astype(np.float32))
eps_t = torch.from_numpy(episode_starts.astype(np.float32))
logp_t = torch.from_numpy(log_probs.astype(np.float32))
val_t = torch.from_numpy(values.astype(np.float32))
# Promote obs from [n_envs, obs_dim] → [1, n_envs, obs_dim] if needed
if obs_t.dim() == 2:
obs_t = obs_t.unsqueeze(0)
elif obs_t.dim() < 2:
raise RuntimeError(
f"RolloutBatch: `observations` must be ≥2‑D, got {obs_t.shape}"
)
def _promote(x: torch.Tensor | None, name: str):
if x is None:
return None
if x.dim() == 1: # (n_envs,) → (1, n_envs)
return x.unsqueeze(0)
elif x.dim() == 2: # (timesteps, n_envs) - already correct
return x
elif x.dim() == 3 and name in [
"actions",
"observations",
]: # (timesteps, n_envs, features)
return x
else:
raise RuntimeError(f"Unexpected shape for {name}: {x.shape}")
act_t = _promote(act_t, "actions")
lat_t = _promote(lat_t, "latents")
rew_t = _promote(rew_t, "rewards")
adv_t = _promote(adv_t, "advantages")
ret_t = _promote(ret_t, "returns")
eps_t = _promote(eps_t, "episode_starts")
logp_t = _promote(logp_t, "log_probs")
val_t = _promote(val_t, "values")
# Move to device ---------------------------------------------------
self.observations = obs_t.to(self.device)
self.actions = act_t.to(self.device)
self.latents = lat_t.to(self.device) if lat_t is not None else None
self.rewards = rew_t.to(self.device)
self.advantages = adv_t.to(self.device)
self.returns = ret_t.to(self.device)
self.episode_starts = eps_t.to(self.device)
self.log_probs = logp_t.to(self.device)
self.values = val_t.to(self.device)