arlbench.autorl.checkpointing
Contains all checkpointing-related methods for the AutoRL environment.
Classes
Checkpointer | Contains all checkpointing-related methods for the AutoRL environment.
- class arlbench.autorl.checkpointing.Checkpointer
Bases: object
Contains all checkpointing-related methods for the AutoRL environment.
- static load(checkpoint_path, algorithm_state)
Loads an AutoRL environment checkpoint.
- Parameters:
checkpoint_path (str) – Path of the checkpoint.
algorithm_state (AlgorithmState) – Current algorithm state; certain attributes will be overridden by the checkpoint.
- Returns:
The common AutoRL environment attributes (hp_config, c_step, c_episode), together with a dictionary of keyword arguments used to restore the algorithm state (algorithm_kw_args).
- Return type:
tuple[tuple[dict[str, Any], int, int], dict]
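A minimal sketch of how a caller might unpack the nested return value of load and apply the restored keyword arguments. `ToyAlgorithmState` and `toy_load` are hypothetical stand-ins, not part of arlbench: the stub only mimics the documented return type, and overriding fields with `NamedTuple._replace` is an assumption about how the keyword arguments could be applied to an immutable state.

```python
from __future__ import annotations

from typing import Any, NamedTuple


class ToyAlgorithmState(NamedTuple):
    """Hypothetical stand-in for AlgorithmState (an immutable record)."""
    params: dict[str, float]
    buffer_state: Any = None


def toy_load(checkpoint_path: str, algorithm_state: ToyAlgorithmState
             ) -> tuple[tuple[dict[str, Any], int, int], dict]:
    """Hypothetical stub that only mimics the documented return shape of load."""
    hp_config = {"learning_rate": 3e-4}
    algorithm_kw_args = {"params": {"w": 0.7}}
    # Documented order: (hp_config, c_step, c_episode), algorithm_kw_args
    return (hp_config, 5000, 12), algorithm_kw_args


state = ToyAlgorithmState(params={"w": 0.0})
(hp_config, c_step, c_episode), algorithm_kw_args = toy_load("ckpt/", state)

# Override only the fields the checkpoint provides; other fields are kept.
state = state._replace(**algorithm_kw_args)
print(c_step, c_episode, state.params)  # → 5000 12 {'w': 0.7}
```

Because the state is immutable in this sketch, `_replace` returns a new record rather than mutating the old one, which matches the documentation's note that only certain attributes are overridden.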
- static load_buffer(dummy_buffer_state, priority_state_path, buffer_dir, vault_uuid)
Loads the buffer state from a checkpoint.
- Parameters:
dummy_buffer_state (PrioritisedTrajectoryBufferState) – Dummy buffer state. It is required to determine the size and data types of the buffer.
priority_state_path (str) – Path where the priorities are stored.
buffer_dir (str) – The directory where the buffer data is stored.
vault_uuid (str) – The unique ID of the vault containing the buffer data.
- Returns:
The buffer state that was loaded from disk.
- Return type:
PrioritisedTrajectoryBufferState
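The dummy buffer state acts as a template: its contents are irrelevant, but its size and data types tell the loader how to interpret the raw data on disk. The sketch below illustrates that idea with Python's `array` module. `load_from_template` is a hypothetical helper and does not reflect how arlbench actually reads the vault or the priority file.

```python
import array
import os
import tempfile


def load_from_template(template: array.array, path: str) -> array.array:
    """Hypothetical helper: rebuild an array from raw bytes, using `template`
    only for its element type (typecode) and its expected length."""
    restored = array.array(template.typecode)
    with open(path, "rb") as f:
        restored.frombytes(f.read())  # raises ValueError on a partial element
    if len(restored) != len(template):
        raise ValueError(f"expected {len(template)} elements, got {len(restored)}")
    return restored


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "buffer.bin")
    original = array.array("f", [0.5, 1.5, 2.5, 3.5])
    with open(path, "wb") as f:
        f.write(original.tobytes())

    # Same size and type as the saved data; the values themselves are ignored.
    dummy = array.array("f", [0.0] * 4)
    restored = load_from_template(dummy, path)
    print(restored.tolist())  # → [0.5, 1.5, 2.5, 3.5]
```

Without the template, the raw bytes alone would not say whether they hold four 32-bit floats, two 64-bit floats, or something else entirely.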
- static save(algorithm, algorithm_state, autorl_config, hp_config, done, c_episode, c_step, train_result, tag=None)
Saves the current state of an AutoRL environment.
- Parameters:
algorithm (str) – Name of the algorithm.
algorithm_state (AlgorithmState) – Current algorithm state.
autorl_config (dict) – AutoRL configuration.
hp_config (Configuration) – Hyperparameter configuration of the algorithm.
done (bool) – Whether the environment is done.
c_episode (int) – Current episode of the AutoRL environment.
c_step (int) – Current step of the AutoRL environment.
train_result (TrainResult | None) – Last training result of the algorithm.
tag (str | None, optional) – Checkpoint tag which is appended to the checkpoint name. Defaults to None.
- Returns:
Path of the checkpoint.
- Return type:
str
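A rough sketch of the save side, assuming the common environment attributes are written to a JSON file inside a per-checkpoint directory. `toy_save`, the `autorl_state.json` file name, and the underscore used to append the tag are all hypothetical; only the ideas that an optional tag is appended to the checkpoint name and that the checkpoint path is returned come from the documentation above.

```python
from __future__ import annotations

import json
import os
import tempfile
from typing import Any


def toy_save(checkpoint_dir: str, name: str, hp_config: dict[str, Any],
             done: bool, c_episode: int, c_step: int,
             tag: str | None = None) -> str:
    """Hypothetical stand-in for save: writes the common environment
    attributes as JSON and returns the checkpoint path."""
    if tag is not None:
        name = f"{name}_{tag}"  # assumption: tag appended with an underscore
    path = os.path.join(checkpoint_dir, name)
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "autorl_state.json"), "w") as f:
        json.dump({"hp_config": hp_config, "done": done,
                   "c_episode": c_episode, "c_step": c_step}, f)
    return path


with tempfile.TemporaryDirectory() as tmp:
    path = toy_save(tmp, "dqn", {"learning_rate": 3e-4}, False, 2, 1000, tag="best")
    with open(os.path.join(path, "autorl_state.json")) as f:
        saved = json.load(f)
    print(os.path.basename(path))  # → dqn_best
```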
- static save_buffer(buffer_state, checkpoint_dir, checkpoint_name)
Saves the buffer state of an algorithm.
- Parameters:
buffer_state (TrajectoryBufferState | PrioritisedTrajectoryBufferState) – Buffer state.
checkpoint_dir (str) – Checkpoint directory.
checkpoint_name (str) – Checkpoint name.
- Returns:
Dictionary containing the identifiers of the individual parts of the buffer; these identifiers are required to load the checkpoint.
- Return type:
dict
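The dictionary returned by save_buffer is what later lets load_buffer find the buffer parts again. The round trip below assumes, based on the parameters of load_buffer, that this dictionary holds a priority file path, a buffer directory, and a vault UUID; the real keys and file layout may differ. `toy_save_buffer` and `toy_load_buffer` are hypothetical, use plain JSON instead of the real buffer state types, and take the data and priorities as separate lists.

```python
from __future__ import annotations

import json
import os
import tempfile
import uuid


def toy_save_buffer(data: list[float], priorities: list[float],
                    checkpoint_dir: str, checkpoint_name: str) -> dict[str, str]:
    """Hypothetical stand-in for save_buffer: stores the parts separately
    and returns the identifiers needed to find them again."""
    vault_uuid = uuid.uuid4().hex
    buffer_dir = os.path.join(checkpoint_dir, checkpoint_name, "buffer")
    os.makedirs(os.path.join(buffer_dir, vault_uuid), exist_ok=True)
    with open(os.path.join(buffer_dir, vault_uuid, "data.json"), "w") as f:
        json.dump(data, f)
    priority_state_path = os.path.join(checkpoint_dir, checkpoint_name, "priorities.json")
    with open(priority_state_path, "w") as f:
        json.dump(priorities, f)
    return {"priority_state_path": priority_state_path,
            "buffer_dir": buffer_dir,
            "vault_uuid": vault_uuid}


def toy_load_buffer(priority_state_path: str, buffer_dir: str,
                    vault_uuid: str) -> tuple[list[float], list[float]]:
    """Hypothetical counterpart that reads the parts back using the identifiers."""
    with open(os.path.join(buffer_dir, vault_uuid, "data.json")) as f:
        data = json.load(f)
    with open(priority_state_path) as f:
        priorities = json.load(f)
    return data, priorities


with tempfile.TemporaryDirectory() as tmp:
    ids = toy_save_buffer([1.0, 2.0, 3.0], [0.2, 0.3, 0.5], tmp, "ckpt")
    data, priorities = toy_load_buffer(**ids)
    print(data, priorities)  # → [1.0, 2.0, 3.0] [0.2, 0.3, 0.5]
```

Unpacking the dictionary with `**ids` keeps the saver and loader in sync in this sketch: every identifier the saver records is exactly a keyword the loader expects.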