arlbench.core package
Subpackages
- arlbench.core.algorithms package
  - Subpackages
  - Submodules
    - arlbench.core.algorithms.algorithm module
      - Algorithm
        - Algorithm.action_type
        - Algorithm.eval()
        - Algorithm.get_checkpoint_factory()
        - Algorithm.get_default_hpo_config()
        - Algorithm.get_default_nas_config()
        - Algorithm.get_hpo_config_space()
        - Algorithm.get_nas_config_space()
        - Algorithm.init()
        - Algorithm.name
        - Algorithm.predict()
        - Algorithm.train()
        - Algorithm.update_hpo_config()
    - arlbench.core.algorithms.buffers module
    - arlbench.core.algorithms.common module
    - arlbench.core.algorithms.prioritised_item_buffer module
  - Module contents
    - Algorithm
      - Algorithm.action_type
      - Algorithm.eval()
      - Algorithm.get_checkpoint_factory()
      - Algorithm.get_default_hpo_config()
      - Algorithm.get_default_nas_config()
      - Algorithm.get_hpo_config_space()
      - Algorithm.get_nas_config_space()
      - Algorithm.init()
      - Algorithm.name
      - Algorithm.predict()
      - Algorithm.train()
      - Algorithm.update_hpo_config()
    - DQN
    - PPO
    - SAC

- arlbench.core.environments package
  - Submodules
    - arlbench.core.environments.autorl_env module
    - arlbench.core.environments.brax_env module
    - arlbench.core.environments.envpool_env module
    - arlbench.core.environments.gymnasium_env module
    - arlbench.core.environments.gymnax_env module
    - arlbench.core.environments.make_env module
    - arlbench.core.environments.xland_env module
  - Module contents

- arlbench.core.wrappers package
Submodules
arlbench.core.running_statistics module
Running statistics (mean and standard deviation) of possibly nested data, used to normalize and denormalize batches.
- class arlbench.core.running_statistics.Array(shape, dtype)
  - Bases: object
  - Describes a NumPy array or scalar shape and dtype. Similar to dm_env.specs.Array.
  - dtype: dtype
  - shape: tuple[int, ...]
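To make the spec idea concrete, here is a minimal sketch in plain Python/NumPy; it is not the arlbench class itself. `ArraySpec` and `spec_of` are hypothetical names, and nesting is assumed to be plain dicts. The sketch records only each leaf's shape and dtype, which is the information the `Array` class describes.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ArraySpec:
    """Hypothetical stand-in for running_statistics.Array: shape and dtype only."""

    shape: tuple
    dtype: np.dtype


def spec_of(nest):
    """Describe each leaf of a nested dict of arrays by its shape and dtype."""
    if isinstance(nest, dict):
        return {key: spec_of(value) for key, value in nest.items()}
    leaf = np.asarray(nest)
    return ArraySpec(shape=leaf.shape, dtype=leaf.dtype)


# Usage: the specs keep the nesting of the observation but drop the values.
obs = {"position": np.zeros(3, dtype=np.float32), "step": np.int32(7)}
specs = spec_of(obs)
```

A scalar leaf such as `step` gets the empty shape `()`, matching the "array or scalar" wording above.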
 
- class arlbench.core.running_statistics.NestedMeanStd(mean, std)
  - Bases: object
  - A container for running statistics (mean, std) of possibly nested data.
  - replace(**updates): Returns a new object replacing the specified fields with new values.
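The `replace(**updates)` method follows the usual immutable-dataclass pattern: instead of mutating the state, you get a new object. The sketch below mimics it with the standard library's `dataclasses.replace`; `MeanStd` is a hypothetical stand-in, and whether arlbench builds its classes this way is an assumption.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class MeanStd:
    """Hypothetical immutable container with the same two fields as NestedMeanStd."""

    mean: np.ndarray
    std: np.ndarray

    def replace(self, **updates):
        # Return a new instance with the given fields swapped; self is unchanged.
        return dataclasses.replace(self, **updates)


old = MeanStd(mean=np.zeros(2), std=np.ones(2))
new = old.replace(std=np.full(2, 0.5))  # old.std is still all ones
```

Fields that are not updated are shared with the old object rather than copied, so `replace` is cheap even for large nested statistics.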
 
- class arlbench.core.running_statistics.RunningStatisticsState(mean, std, count, summed_variance)
  - Bases: NestedMeanStd
  - Full state of running statistics computation.
  - count: Array
  - replace(**updates): Returns a new object replacing the specified fields with new values.
 
- arlbench.core.running_statistics.denormalize(batch, mean_std)
  - Denormalizes values in a nested structure using the given mean/std. Only values of inexact (floating-point or complex) types are denormalized; see https://numpy.org/doc/stable/_images/dtype-hierarchy.png for the NumPy type hierarchy.
  - Parameters:
    - batch (Array) – a nested structure containing a batch of data.
    - mean_std (NestedMeanStd) – mean and standard deviation used for denormalization.
  - Return type: Array
  - Returns: Nested structure with denormalized values.
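Denormalization inverts standardization, x * std + mean, applied leaf by leaf. The sketch below assumes nested dicts and takes `mean` and `std` as two separate nests instead of a `NestedMeanStd`; `denormalize_sketch` is a hypothetical name. It follows the documented rule of leaving non-inexact leaves untouched.

```python
import numpy as np


def denormalize_sketch(batch, mean, std):
    """Map normalized data back to the original scale: x * std + mean.

    batch, mean and std are nested dicts with the same structure (an assumed
    layout; the documented function takes a NestedMeanStd instead).
    """
    if isinstance(batch, dict):
        return {k: denormalize_sketch(batch[k], mean[k], std[k]) for k in batch}
    if not np.issubdtype(batch.dtype, np.inexact):
        return batch  # integer and boolean leaves are passed through untouched
    return batch * std + mean


batch = {"obs": np.array([0.0, 1.0, -1.0]), "done": np.array([0, 1, 0])}
mean = {"obs": np.array(10.0), "done": np.array(0.0)}
std = {"obs": np.array(2.0), "done": np.array(1.0)}
restored = denormalize_sketch(batch, mean, std)
# restored["obs"] -> [10., 12., 8.]; restored["done"] is returned unchanged
```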
 
- arlbench.core.running_statistics.init_state(nest)
  - Initializes the running statistics for the given nested structure.
  - Return type: RunningStatisticsState
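As an illustration of what an initial state might contain, the sketch below assumes zero mean, unit standard deviation, zero count and zero summed variance for every leaf. `StatsState`, `_filled_like` and `init_state_sketch` are hypothetical names; nesting is assumed to be plain dicts, and the sketch uses float64 throughout for simplicity rather than the int32/float32 defaults described under update.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class StatsState:
    """Hypothetical stand-in for RunningStatisticsState."""

    mean: dict
    std: dict
    count: np.ndarray
    summed_variance: dict


def _filled_like(nest, value):
    # Mirror the nested dict structure with float64 arrays of each leaf's shape.
    if isinstance(nest, dict):
        return {k: _filled_like(v, value) for k, v in nest.items()}
    return np.full(np.shape(nest), value, dtype=np.float64)


def init_state_sketch(nest):
    """Zero mean, unit std, zero count and zero summed variance (assumed defaults)."""
    return StatsState(
        mean=_filled_like(nest, 0.0),
        std=_filled_like(nest, 1.0),
        count=np.zeros((), dtype=np.float64),
        summed_variance=_filled_like(nest, 0.0),
    )


state = init_state_sketch({"obs": np.zeros(3), "inner": {"vel": np.zeros((2, 2))}})
```

Starting from a unit standard deviation means that normalizing before the first update leaves the data unchanged instead of dividing by zero.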
 
- arlbench.core.running_statistics.normalize(batch, mean_std, max_abs_value=None)
  - Normalizes data using running statistics.
  - Return type: Array
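Normalization is the forward direction, (x - mean) / std. The sketch below is a hypothetical `normalize_sketch` over nested dicts. Two behaviours are assumptions, not documented above: non-inexact leaves are passed through, mirroring `denormalize`, and `max_abs_value` is treated as a symmetric clipping bound on the normalized values.

```python
import numpy as np


def normalize_sketch(batch, mean, std, max_abs_value=None):
    """Standardize leaf-wise as (x - mean) / std, optionally clipping the result.

    Passing non-inexact leaves through and clipping to
    [-max_abs_value, max_abs_value] are assumptions, not documented behaviour.
    """
    if isinstance(batch, dict):
        return {
            k: normalize_sketch(batch[k], mean[k], std[k], max_abs_value)
            for k in batch
        }
    if not np.issubdtype(batch.dtype, np.inexact):
        return batch  # leave integer/boolean data as-is
    out = (batch - mean) / std
    if max_abs_value is not None:
        out = np.clip(out, -max_abs_value, max_abs_value)
    return out


batch = {"obs": np.array([10.0, 14.0, 0.0])}
stats_mean = {"obs": np.array(10.0)}
stats_std = {"obs": np.array(2.0)}
plain = normalize_sketch(batch, stats_mean, stats_std)  # [0., 2., -5.]
clipped = normalize_sketch(batch, stats_mean, stats_std, max_abs_value=3.0)  # [0., 2., -3.]
```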
 
- arlbench.core.running_statistics.update(state, batch, *, weights=None, std_min_value=1e-06, std_max_value=1000000.0, pmap_axis_name=None, validate_shapes=True)
  - Updates the running statistics with the given batch of data.
  - Note: the data batch and the state elements (mean, etc.) must have the same structure.
  - Note: by default, int32 is used for counts and float32 for the accumulated variance. This results in an integer overflow after 2^31 data points, and in degrading precision after 2^24 batch updates, or even earlier if the variance updates have a large dynamic range. To improve precision, consider setting jax_enable_x64 to True; see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
  - Parameters:
    - state (RunningStatisticsState) – The running statistics before the update.
    - batch (Union[Array, Any, Array, Iterable[NestedSpec], Mapping[Any, NestedSpec]]) – The data used to update the running statistics.
    - weights (Optional[Array]) – Weights of the batch data; should match the batch dimensions. Passing a weight of 2.0 should be equivalent to updating on the corresponding data point twice.
    - std_min_value (float) – Minimum value for the standard deviation.
    - std_max_value (float) – Maximum value for the standard deviation.
    - pmap_axis_name (Optional[str]) – Name of the pmapped axis, if any.
    - validate_shapes (bool) – If true, the shapes of all leaves of the batch are validated. Enabled by default; does not impact performance when jitted.
  - Return type: RunningStatisticsState
  - Returns: Updated running statistics.
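The weighted update can be sketched for a single array leaf using the standard incremental (Welford/Chan-style) merge: the mean moves by the weighted average deviation from the old mean, and the summed variance grows by the weighted sum of (x - old mean) * (x - new mean). `update_sketch`, its dict-based state layout, and the single leading batch axis are assumptions. Nesting, `pmap_axis_name` and shape validation are omitted, and whether arlbench uses exactly this formula is not stated in the documentation above.

```python
import numpy as np


def update_sketch(state, batch, weights=None, std_min_value=1e-6, std_max_value=1e6):
    """Fold one batch into running mean/std statistics for a single array leaf.

    state: dict with "mean", "std", "count", "summed_variance" (hypothetical layout).
    batch: array whose leading axis indexes samples; weights: one per sample.
    """
    batch = np.asarray(batch, dtype=np.float64)
    if weights is None:
        weights = np.ones(batch.shape[0])
    # Reshape per-sample weights so they broadcast over the feature axes.
    w = np.asarray(weights, dtype=np.float64).reshape((-1,) + (1,) * (batch.ndim - 1))

    count = state["count"] + w.sum()
    diff_to_old_mean = batch - state["mean"]
    mean = state["mean"] + (w * diff_to_old_mean).sum(axis=0) / count
    diff_to_new_mean = batch - mean
    # Incremental variance: add sum of w * (x - old_mean) * (x - new_mean).
    summed_variance = state["summed_variance"] + (
        w * diff_to_old_mean * diff_to_new_mean
    ).sum(axis=0)
    variance = np.maximum(summed_variance, 0.0) / count  # guard tiny negatives
    std = np.clip(np.sqrt(variance), std_min_value, std_max_value)
    return {"mean": mean, "std": std, "count": count, "summed_variance": summed_variance}


rng = np.random.default_rng(0)
data = rng.normal(3.0, 2.0, size=(100, 2))
state = {"mean": np.zeros(2), "std": np.ones(2), "count": 0.0, "summed_variance": np.zeros(2)}
for chunk in np.split(data, 4):  # four batches of 25 samples
    state = update_sketch(state, chunk)
```

After the four batches, the running mean and std match the mean and population std of all 100 samples. Because each weight multiplies a sample's contribution, a weight of 2.0 yields the same statistics as feeding that sample twice, which is the equivalence the `weights` parameter describes.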