Skip to content

Seed snapshot

neps.state.seed_snapshot #

Snapshot of the global rng state.

SeedSnapshot dataclass #

SeedSnapshot(
    np_rng: NP_RNG_STATE,
    py_rng: PY_RNG_STATE,
    torch_rng: TORCH_RNG_STATE | None,
    torch_cuda_rng: TORCH_CUDA_RNG_STATE | None,
)

State of the global rng.

Primarily enables storing the rng state to disk using each library's native binary format. This tolerates version mismatches between the processes that save and load the state, as long as each of them can read that binary format.

new_capture classmethod #

new_capture() -> SeedSnapshot

Current state of the global rng.

Takes a snapshot, including cloning or copying any arrays, tensors, etc.

Source code in neps/state/seed_snapshot.py
@classmethod
def new_capture(cls) -> SeedSnapshot:
    """Build a fresh snapshot of the current global rng state.

    The instance is created with placeholder fields and immediately filled
    by `recapture`, which copies/clones any arrays or tensors it reads.
    """
    snapshot = cls(None, None, None, None)  # type: ignore
    snapshot.recapture()
    return snapshot

recapture #

recapture() -> None

Reread the state of the global rng into this snapshot.

Source code in neps/state/seed_snapshot.py
def recapture(self) -> None:
    """Reread the state of the global rng into this snapshot.

    Captures the stdlib ``random`` state, numpy's legacy global generator
    and, if torch is importable, torch's CPU generator plus (when CUDA is
    available) the generator of every CUDA device. Arrays and tensors are
    copied/cloned so later draws do not mutate the snapshot.

    The torch fields are always overwritten. They become ``None`` when torch
    is not installed, and the CUDA field becomes ``None`` when the CUDA state
    cannot be read. A recapture therefore never keeps stale torch state from
    an earlier capture.

    Raises:
        RuntimeError: If numpy's global generator is not a legacy MT19937
            ``RandomState`` (its state cannot be stored as a legacy tuple).
    """
    # https://numpy.org/doc/stable/reference/random/generated/numpy.random.get_state.html
    self.py_rng = random.getstate()

    # With legacy=True numpy returns a tuple for MT19937 but a dict for other
    # bit generators. Check explicitly: an `assert` is stripped under -O, and
    # indexing a dict with [0] would raise KeyError before the check anyway.
    np_keys = np.random.get_state(legacy=True)
    if not isinstance(np_keys, tuple) or np_keys[0] != "MT19937":
        raise RuntimeError(
            "Expected numpy's global rng to be a legacy MT19937 RandomState, "
            f"got a state of type {type(np_keys).__name__}"
        )
    self.np_rng = (np_keys[0], np_keys[1].copy(), *np_keys[2:])  # type: ignore

    # Capture into locals and assign both fields at the end, so the CPU and
    # CUDA torch states always come from the same capture.
    torch_rng = None
    torch_cuda_rng = None
    try:
        import torch
    except ImportError:
        pass  # torch is optional; there is no torch state to capture.
    else:
        torch_rng = torch.random.get_rng_state().clone()
        # Best-effort: CUDA may report itself available yet fail to
        # initialise (e.g. in a forked subprocess); skip CUDA state then.
        with contextlib.suppress(RuntimeError):
            if torch.cuda.is_available():
                torch_cuda_rng = [c.clone() for c in torch.cuda.get_rng_state_all()]
    self.torch_rng = torch_rng
    self.torch_cuda_rng = torch_cuda_rng

set_as_global_seed_state #

set_as_global_seed_state() -> None

Set the global rng to the given state.

Source code in neps/state/seed_snapshot.py
def set_as_global_seed_state(self) -> None:
    """Restore this snapshot into the process-global generators.

    numpy and the stdlib ``random`` module are always restored. torch is only
    imported when the snapshot actually holds torch state. The CUDA states
    are applied only if CUDA is available in the current process.
    """
    np.random.set_state(self.np_rng)
    random.setstate(self.py_rng)

    cpu_state = self.torch_rng
    cuda_states = self.torch_cuda_rng
    if cpu_state is None and cuda_states is None:
        return  # nothing torch-related to restore; avoid importing torch

    import torch

    if cpu_state is not None:
        torch.random.set_rng_state(cpu_state)

    if cuda_states is None:
        return
    if torch.cuda.is_available():
        torch.cuda.set_rng_state_all(cuda_states)