MightyReplay(
capacity,
keep_infos=False,
flatten_infos=False,
device: torch.device | str = "cpu",
)
Bases: MightyBuffer
Simple replay buffer.
:param capacity: Buffer size
:param device: Device (torch.device or its string name) used for sampled batches; seeding of the sampling RNG is done via seed(), not the constructor
:param keep_infos: Keep the extra info dict. Required for some algorithms.
:param flatten_infos: Make flat list from infos.
Might be necessary, depending on info content.
:return:
Source code in mighty/mighty_replay/mighty_replay_buffer.py
def __init__(
    self,
    capacity: int,
    keep_infos: bool = False,
    flatten_infos: bool = False,
    device: torch.device | str = "cpu",
):
    """Initialize the replay buffer.

    Note: the sampling RNG starts unseeded; call ``seed()`` for
    reproducible sampling (the previous docstring described a
    ``random_seed`` parameter that does not exist in the signature).

    :param capacity: Maximum number of transitions kept in the buffer
    :param keep_infos: Keep the extra info dict. Required for some algorithms.
    :param flatten_infos: Make a flat list from infos.
        Might be necessary, depending on info content.
    :param device: Device (torch.device or its string name) onto which
        sampled batches are placed
    :return:
    """
    self.capacity = capacity
    self.keep_infos = keep_infos
    self.flatten_infos = flatten_infos
    # Normalize string names ("cpu", "cuda:0", ...) into a torch.device.
    self.device = torch.device(device)
    # Unseeded by default; seed() replaces this generator deterministically.
    self.rng = np.random.default_rng()
    self.reset()
full
property
Check if the buffer is full.
add
Add transition(s).
:param transition_batch: Transition(s) to add
:param metrics: Current metrics dict
:return:
Source code in mighty/mighty_replay/mighty_replay_buffer.py
def add(self, transition_batch, _):
    """Add transition(s) to the buffer, evicting the oldest on overflow.

    :param transition_batch: Transition(s) to add
    :param _: Current metrics dict (unused, kept for the buffer interface)
    :return:
    """
    if not self.keep_infos:
        transition_batch.extra_info = []
    elif self.flatten_infos:
        # Some algorithms need a flat list rather than nested info dicts.
        transition_batch.extra_info = [
            list(flatten_infos(transition_batch.extra_info))
        ]
    self.index += transition_batch.size

    if len(self.obs) == 0:
        # First insertion: adopt the incoming tensors directly.
        self.obs = transition_batch.observations
        self.next_obs = transition_batch.next_obs
        self.actions = transition_batch.actions
        self.rewards = transition_batch.rewards
        self.dones = transition_batch.dones
    else:
        self.obs = torch.cat((self.obs, transition_batch.observations))
        self.next_obs = torch.cat((self.next_obs, transition_batch.next_obs))
        self.actions = torch.cat((self.actions, transition_batch.actions))
        self.rewards = torch.cat((self.rewards, transition_batch.rewards))
        self.dones = torch.cat((self.dones, transition_batch.dones))

    # Compute the overflow ONCE before trimming. The previous code
    # re-evaluated `len(self) - self.capacity` on every line, so after
    # trimming self.obs the remaining fields could be trimmed by a
    # different (possibly zero) amount whenever __len__ depends on the
    # already-trimmed storage, leaving the buffers with unequal lengths.
    overflow = len(self) - self.capacity
    if overflow > 0:
        self.obs = self.obs[overflow:]
        self.next_obs = self.next_obs[overflow:]
        self.actions = self.actions[overflow:]
        self.rewards = self.rewards[overflow:]
        self.dones = self.dones[overflow:]
        self.index = self.capacity
reset
Reset the buffer.
Source code in mighty/mighty_replay/mighty_replay_buffer.py
def reset(self):
    """Drop all stored transitions and rewind the write index to zero."""
    self.obs, self.next_obs, self.actions, self.rewards, self.dones = (
        [],
        [],
        [],
        [],
        [],
    )
    self.index = 0
sample
Sample transitions.
Source code in mighty/mighty_replay/mighty_replay_buffer.py
def sample(self, batch_size=32):
    """Sample a random batch of transitions (with replacement).

    :param batch_size: Number of transitions to draw
    :return: TransitionBatch of the sampled transitions on ``self.device``
    """
    # Passing the int directly is documented as equivalent to
    # rng.choice(np.arange(n), ...) but avoids materializing an index
    # array the size of the whole buffer on every sample call.
    batch_indices = self.rng.choice(len(self), size=batch_size)
    return TransitionBatch(
        self.obs[batch_indices],
        self.actions[batch_indices],
        self.rewards[batch_indices],
        self.next_obs[batch_indices],
        self.dones[batch_indices],
        device=self.device,
    )
save
save(filename='buffer.pkl')
Save the buffer to a file.
Source code in mighty/mighty_replay/mighty_replay_buffer.py
def save(self, filename="buffer.pkl"):
    """Serialize the whole buffer object to *filename* using pickle."""
    with open(filename, "wb") as out:
        pickle.dump(self, out)
seed
Set random seed.
Source code in mighty/mighty_replay/buffer.py
def seed(self, seed: int):
    """Re-seed the sampling RNG so subsequent draws are reproducible.

    :param seed: Seed value handed to numpy's default generator.
    """
    generator = np.random.default_rng(seed)
    self.rng = generator