Skip to content

Mighty rollout buffer

mighty.mighty_replay.mighty_rollout_buffer #

MightyRolloutBuffer #

MightyRolloutBuffer(
    buffer_size: int,
    obs_shape,
    act_dim,
    device: device | str = "cpu",
    *,
    gae_lambda: float = 1.0,
    gamma: float = 0.99,
    n_envs: int = 1,
    discrete_action: bool = False,
    use_latents: bool = False,
)

Bases: MightyBuffer

Pre-allocated rollout buffer: all storage tensors are allocated once up front rather than grown by repeated concatenation.

Source code in mighty/mighty_replay/mighty_rollout_buffer.py
def __init__(
    self,
    buffer_size: int,
    obs_shape,
    act_dim,
    device: torch.device | str = "cpu",
    *,
    gae_lambda: float = 1.0,
    gamma: float = 0.99,
    n_envs: int = 1,
    discrete_action: bool = False,
    use_latents: bool = False,
):
    """Pre-allocate zero-filled float32 rollout storage on ``device``.

    Every storage tensor has leading axes ``(buffer_size, n_envs)``.

    Args:
        buffer_size: Length of the first axis of every tensor
            (presumably the number of rollout steps).
        obs_shape: Trailing shape of ``observations``; a bare ``int`` is
            treated as the 1-D shape ``(obs_shape,)``.
        act_dim: Trailing size of ``actions`` (and ``latents``) for
            continuous actions; not used when ``discrete_action`` is true.
        device: Device every tensor is allocated on.
        gae_lambda: Stored as ``self.gae_lambda``; not used in this method
            (presumably the GAE lambda for advantage estimation).
        gamma: Stored as ``self.gamma``; not used in this method
            (presumably the discount factor).
        n_envs: Length of the second axis of every tensor.
        discrete_action: If true, ``actions`` has no trailing feature axis
            and ``latents`` is ``None``.
        use_latents: For continuous actions only, also allocate ``latents``
            with the same shape as ``actions``. With discrete actions it is
            still stored on ``self`` but no latents tensor is allocated.
    """
    super().__init__()

    # Configuration -----------------------------------------------------
    self.buffer_size = buffer_size
    self.n_envs = n_envs
    self.device = device
    self.gamma = gamma
    self.gae_lambda = gae_lambda
    self.discrete_action = discrete_action
    self.use_latents = use_latents
    self.rng = np.random.default_rng()

    # Storage -----------------------------------------------------------
    feature_shape = (obs_shape,) if isinstance(obs_shape, int) else obs_shape
    leading = (buffer_size, n_envs)

    def _alloc(*trailing):
        # Zero tensor of shape (buffer_size, n_envs, *trailing).
        return torch.zeros(
            (*leading, *trailing), dtype=torch.float32, device=device
        )

    self.observations = _alloc(*feature_shape)

    if discrete_action:
        # Discrete actions: one scalar per (step, env); latents never used.
        self.actions = _alloc()
        self.latents = None
    else:
        # Continuous actions carry an act_dim axis; latents mirror that shape.
        self.actions = _alloc(act_dim)
        self.latents = _alloc(act_dim) if use_latents else None

    # One scalar per (step, env) for each of these quantities.
    self.rewards = _alloc()
    self.advantages = _alloc()
    self.returns = _alloc()
    self.episode_starts = _alloc()
    self.log_probs = _alloc()
    self.values = _alloc()

    # Position cursor starts at 0 (presumably the next write index — confirm
    # against the add/reset methods).
    self.pos = 0

seed #

seed(seed: int)

Set random seed.

Source code in mighty/mighty_replay/buffer.py
def seed(self, seed: int):
    """Replace ``self.rng`` with a new NumPy generator seeded by ``seed``.

    Args:
        seed: Seed handed to :func:`numpy.random.default_rng`.
    """
    generator = np.random.default_rng(seed)
    self.rng = generator

RolloutBatch #

RolloutBatch(
    observations: ndarray,
    actions: ndarray,
    *,
    latents: ndarray | None = None,
    rewards: ndarray,
    advantages: ndarray,
    returns: ndarray,
    episode_starts: ndarray,
    log_probs: ndarray,
    values: ndarray,
    device: device | str = "cpu",
)

A contiguous slice of experience, optionally including the latent z alongside the actions.

Source code in mighty/mighty_replay/mighty_rollout_buffer.py
def __init__(
    self,
    observations: np.ndarray,
    actions: np.ndarray,
    *,
    latents: np.ndarray | None = None,  # ← NEW (optional for discrete)
    rewards: np.ndarray,
    advantages: np.ndarray,
    returns: np.ndarray,
    episode_starts: np.ndarray,
    log_probs: np.ndarray,
    values: np.ndarray,
    device: torch.device | str = "cpu",
):
    self.device = device

    obs_t = torch.from_numpy(observations.astype(np.float32))
    act_t = torch.from_numpy(actions.astype(np.float32))
    lat_t = (
        torch.from_numpy(latents.astype(np.float32))
        if latents is not None
        else None
    )  # may stay None for discrete
    rew_t = torch.from_numpy(rewards.astype(np.float32))
    adv_t = torch.from_numpy(advantages.astype(np.float32))
    ret_t = torch.from_numpy(returns.astype(np.float32))
    eps_t = torch.from_numpy(episode_starts.astype(np.float32))
    logp_t = torch.from_numpy(log_probs.astype(np.float32))
    val_t = torch.from_numpy(values.astype(np.float32))

    # Promote obs from [n_envs, obs_dim] → [1, n_envs, obs_dim] if needed
    if obs_t.dim() == 2:
        obs_t = obs_t.unsqueeze(0)
    elif obs_t.dim() < 2:
        raise RuntimeError(
            f"RolloutBatch: `observations` must be ≥2‑D, got {obs_t.shape}"
        )

    def _promote(x: torch.Tensor | None, name: str):
        if x is None:
            return None
        if x.dim() == 1:  # (n_envs,) → (1, n_envs)
            return x.unsqueeze(0)
        elif x.dim() == 2:  # (timesteps, n_envs) - already correct
            return x
        elif x.dim() == 3 and name in [
            "actions",
            "observations",
        ]:  # (timesteps, n_envs, features)
            return x
        else:
            raise RuntimeError(f"Unexpected shape for {name}: {x.shape}")

    act_t = _promote(act_t, "actions")
    lat_t = _promote(lat_t, "latents")
    rew_t = _promote(rew_t, "rewards")
    adv_t = _promote(adv_t, "advantages")
    ret_t = _promote(ret_t, "returns")
    eps_t = _promote(eps_t, "episode_starts")
    logp_t = _promote(logp_t, "log_probs")
    val_t = _promote(val_t, "values")

    # Move to device ---------------------------------------------------
    self.observations = obs_t.to(self.device)
    self.actions = act_t.to(self.device)
    self.latents = lat_t.to(self.device) if lat_t is not None else None
    self.rewards = rew_t.to(self.device)
    self.advantages = adv_t.to(self.device)
    self.returns = ret_t.to(self.device)
    self.episode_starts = eps_t.to(self.device)
    self.log_probs = logp_t.to(self.device)
    self.values = val_t.to(self.device)