arlbench.core.running_statistics
Running statistics.
Functions
| denormalize | Denormalizes values in a nested structure using the given mean/std. |
| init_state | Initializes the running statistics for the given nested structure. |
| normalize | Normalizes data using running statistics. |
| update | Updates the running statistics with the given batch of data. |
Classes
| Array | Describes a numpy array or scalar shape and dtype. |
| NestedMeanStd | A container for running statistics (mean, std) of possibly nested data. |
| RunningStatisticsState | Full state of running statistics computation. |
- class arlbench.core.running_statistics.Array(shape, dtype)
- Bases: object
- Describes a NumPy array or scalar shape and dtype. Similar to dm_env.specs.Array.
- class arlbench.core.running_statistics.NestedMeanStd(mean, std)
- Bases: object
- A container for running statistics (mean, std) of possibly nested data.
- replace(**updates)
- Returns a new object replacing the specified fields with new values.
 
- class arlbench.core.running_statistics.RunningStatisticsState(mean, std, count, summed_variance)
- Bases: NestedMeanStd
- Full state of running statistics computation.
- replace(**updates)
- Returns a new object replacing the specified fields with new values.
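The `replace(**updates)` behavior described for both classes can be illustrated with a small stand-in. This is only a sketch: `MeanStdSketch` is a hypothetical class built on the standard library's frozen dataclasses, and the real containers may be constructed differently.

```python
import dataclasses


# Hypothetical stand-in for NestedMeanStd; not the library's own class.
@dataclasses.dataclass(frozen=True)
class MeanStdSketch:
    mean: float
    std: float

    def replace(self, **updates):
        """Returns a new object with the specified fields replaced."""
        return dataclasses.replace(self, **updates)


stats = MeanStdSketch(mean=0.0, std=1.0)
rescaled = stats.replace(std=2.0)  # new object; `stats` itself is unchanged
```

Because the container is immutable, every modification goes through `replace`, which is the usual pattern for state that is passed through jitted functions.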
 
- arlbench.core.running_statistics.denormalize(batch, mean_std)
- Denormalizes values in a nested structure using the given mean/std.
- Only values of inexact types (floating-point and complex) are denormalized; other values are left as they are. See https://numpy.org/doc/stable/_images/dtype-hierarchy.png for the NumPy type hierarchy.
- Parameters:
- batch (Array) – A nested structure containing a batch of data.
- mean_std (NestedMeanStd) – Mean and standard deviation used for denormalization.
 
- Return type:
- Array
- Returns:
- Nested structure with denormalized values. 
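The following sketch shows one way the inexact-type rule could look in practice. It assumes that normalization is `(x - mean) / std`, so that denormalization is `x * std + mean`; the helper `denormalize_sketch` and the flat dict used as the "nested structure" are illustrative assumptions, not the library's implementation.

```python
import numpy as np


def denormalize_sketch(batch, mean, std):
    """Maps normalized leaves back to the original scale (hypothetical helper)."""
    out = {}
    for key, value in batch.items():
        value = np.asarray(value)
        if np.issubdtype(value.dtype, np.inexact):
            # Floating-point (and complex) leaves are rescaled.
            out[key] = value * std[key] + mean[key]
        else:
            # Integer and boolean leaves pass through unchanged.
            out[key] = value
    return out


batch = {
    "obs": np.array([-1.0, 0.0, 1.0], dtype=np.float32),
    "step": np.array([3, 4, 5], dtype=np.int32),
}
mean = {"obs": np.float32(10.0), "step": np.float32(0.0)}
std = {"obs": np.float32(2.0), "step": np.float32(1.0)}

restored = denormalize_sketch(batch, mean, std)
# restored["obs"] holds 8, 10, 12; restored["step"] is still an int32 array.
```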
 
- arlbench.core.running_statistics.init_state(nest)
- Initializes the running statistics for the given nested structure. - Return type:
 
- arlbench.core.running_statistics.normalize(batch, mean_std, max_abs_value=None)
- Normalizes data using running statistics. - Return type:
- Array
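As a rough illustration of what normalization with `max_abs_value` might do, the sketch below standardizes a single array and then clips the result. Both the formula `(x - mean) / std` and the symmetric clipping to `[-max_abs_value, max_abs_value]` are assumptions here, since the documentation above does not spell them out; `normalize_sketch` is a hypothetical helper.

```python
import numpy as np


def normalize_sketch(x, mean, std, max_abs_value=None):
    """Standardizes x and optionally clips the result (hypothetical helper)."""
    z = (np.asarray(x) - mean) / std
    if max_abs_value is not None:
        # Assumed semantics: bound the normalized values symmetrically.
        z = np.clip(z, -max_abs_value, max_abs_value)
    return z


obs = np.array([0.0, 10.0, 100.0])
z = normalize_sketch(obs, mean=10.0, std=5.0, max_abs_value=3.0)
# The outlier 100.0 would map to 18.0 and is clipped to 3.0.
```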
 
- arlbench.core.running_statistics.update(state, batch, *, weights=None, std_min_value=1e-06, std_max_value=1000000.0, pmap_axis_name=None, validate_shapes=True)
- Updates the running statistics with the given batch of data.
- Note: the data batch and the state elements (mean, etc.) must have the same structure.
- Note: by default, counts are stored as int32 and the accumulated variance as float32. This results in an integer overflow after 2^31 data points and in degrading precision after 2^24 batch updates, or even earlier if the variance updates have a large dynamic range. To improve precision, consider setting jax_enable_x64 to True; see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
- Parameters:
- state (RunningStatisticsState) – The running statistics before the update.
- batch (Union[Array, Any, Array, Iterable[NestedSpec], Mapping[Any, NestedSpec]]) – The data to be used to update the running statistics.
- weights (Optional[Array]) – Weights of the batch data. Should match the batch dimensions. Passing a weight of 2.0 should be equivalent to updating on the corresponding data point twice.
- std_min_value (float) – Minimum value for the standard deviation.
- std_max_value (float) – Maximum value for the standard deviation.
- pmap_axis_name (Optional[str]) – Name of the pmapped axis, if any.
- validate_shapes (bool) – If true, the shapes of all leaves of the batch will be validated. Enabled by default. Doesn’t impact performance when jitted.
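The precision limits mentioned in the note can be reproduced with plain NumPy scalars. The snippet below is only a demonstration of the underlying number formats, not of the library's code: float32 has a 24-bit significand, so an increment of 1 is lost once an accumulator reaches 2^24, and int32 cannot count past 2^31 - 1.

```python
import numpy as np

# float32 cannot represent 2**24 + 1, so the increment is rounded away.
accumulator = np.float32(2**24)
float32_increment_lost = bool(accumulator + np.float32(1.0) == accumulator)

# float64 still resolves the same increment.
float64_increment_kept = bool(np.float64(2**24) + 1.0 != np.float64(2**24))

# Largest count an int32 can hold before overflowing.
int32_max = int(np.iinfo(np.int32).max)

print(float32_increment_lost, float64_increment_kept, int32_max)
```

In JAX, 64-bit types are typically enabled at startup with `jax.config.update("jax_enable_x64", True)`, as described on the page linked in the note.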
 
- Return type:
- Returns:
- Updated running statistics.
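To make the update step more concrete, here is a minimal NumPy sketch of a weighted running mean and standard deviation for a single array leaf. It is an assumption about how such an update can be computed (a batched Welford-style recurrence on `count`, `mean` and `summed_variance`), not the library's implementation; `empty_state` and `update_sketch` are hypothetical names, and float64 is used throughout to sidestep the precision issues noted above.

```python
import numpy as np


def empty_state(shape=()):
    """Hypothetical initial state: zero mean, unit std, nothing counted yet."""
    return {"mean": np.zeros(shape), "std": np.ones(shape),
            "count": 0.0, "summed_variance": np.zeros(shape)}


def update_sketch(state, batch, weights=None,
                  std_min_value=1e-6, std_max_value=1e6):
    """Folds one batch (leading axis = batch axis) into the running statistics."""
    batch = np.asarray(batch, dtype=np.float64)
    if weights is None:
        weights = np.ones(batch.shape[0])
    weights = np.asarray(weights, dtype=np.float64)
    # Reshape weights so they broadcast over the non-batch dimensions.
    w = weights.reshape((-1,) + (1,) * (batch.ndim - 1))

    count = state["count"] + weights.sum()
    diff_to_old_mean = batch - state["mean"]
    mean = state["mean"] + (w * diff_to_old_mean).sum(axis=0) / count
    diff_to_new_mean = batch - mean
    summed_variance = state["summed_variance"] + (
        w * diff_to_old_mean * diff_to_new_mean).sum(axis=0)

    # Clip the standard deviation into [std_min_value, std_max_value].
    std = np.sqrt(np.maximum(summed_variance, 0.0) / count)
    std = np.clip(std, std_min_value, std_max_value)
    return {"mean": mean, "std": std, "count": count,
            "summed_variance": summed_variance}


# Feeding the data in four chunks matches the statistics of the full data set.
data = np.random.default_rng(0).normal(loc=5.0, scale=2.0, size=(100, 3))
state = empty_state((3,))
for chunk in np.split(data, 4):
    state = update_sketch(state, chunk)

# A weight of 2.0 behaves like seeing the same data point twice.
weighted = update_sketch(empty_state(), [1.0, 3.0], weights=[2.0, 1.0])
repeated = update_sketch(empty_state(), [1.0, 1.0, 3.0])
```

Keeping `summed_variance` rather than the variance itself lets each update add a single correction term, which is why the full state carries more fields than the `(mean, std)` pair used for normalization.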