arlbench.core.running_statistics
Running statistics.
Functions
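| denormalize(batch, mean_std) | Denormalizes values in a nested structure using the given mean/std. |
| init_state(nest) | Initializes the running statistics for the given nested structure. |
| normalize(batch, mean_std[, max_abs_value]) | Normalizes data using running statistics. |
| update(state, batch, *[, weights, ...]) | Updates the running statistics with the given batch of data. |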
Classes
| Array(shape, dtype) | Describes a numpy array or scalar shape and dtype. |
| NestedMeanStd(mean, std) | A container for running statistics (mean, std) of possibly nested data. |
| RunningStatisticsState(mean, std, count, ...) | Full state of running statistics computation. |
- class arlbench.core.running_statistics.Array(shape, dtype)
Bases: object
Describes a numpy array or scalar shape and dtype.
Similar to dm_env.specs.Array.
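A spec of this kind records metadata only, not values. The sketch below uses a hypothetical frozen dataclass, ArraySpec (a stand-in, not arlbench's class), to describe a nested observation, plus a simplified name-based check for the inexact (floating-point or complex) dtypes that denormalize acts on.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ArraySpec:
    """Hypothetical stand-in: shape and dtype only, no data (like dm_env.specs.Array)."""
    shape: Tuple[int, ...]
    dtype: str


# A nested spec: a 3-vector of floats and an integer scalar (shape ()).
obs_spec = {
    "velocity": ArraySpec(shape=(3,), dtype="float32"),
    "step": ArraySpec(shape=(), dtype="int32"),
}


def is_inexact(spec: ArraySpec) -> bool:
    # Simplified check on the dtype name; with NumPy the real test would be
    # np.issubdtype(dtype, np.inexact).
    return spec.dtype.startswith(("float", "complex"))


print({name: is_inexact(spec) for name, spec in obs_spec.items()})
```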
- class arlbench.core.running_statistics.NestedMeanStd(mean, std)
Bases: object
A container for running statistics (mean, std) of possibly nested data.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
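The replace method follows the immutable-dataclass pattern: the original object is left untouched. The sketch below assumes the classes behave like frozen dataclasses (for example flax.struct.dataclass, which provides a replace method with this docstring); MeanStdSketch is a hypothetical stand-in built on the standard library's dataclasses.replace.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MeanStdSketch:
    """Hypothetical stand-in for NestedMeanStd; fields hold nested dicts."""
    mean: dict
    std: dict

    def replace(self, **updates):
        # Build a new instance with the given fields swapped in;
        # unspecified fields are shared with the original.
        return dataclasses.replace(self, **updates)


stats = MeanStdSketch(mean={"x": 0.0}, std={"x": 1.0})
shifted = stats.replace(mean={"x": 2.5})
print(stats.mean, shifted.mean, shifted.std)
```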
- class arlbench.core.running_statistics.RunningStatisticsState(mean, std, count, summed_variance)
Bases: NestedMeanStd
Full state of running statistics computation.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
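One way to read the four fields, assuming the usual Welford-style bookkeeping (the page does not spell this out), is the following, for n accumulated samples x_1 to x_n of a single leaf. Here σ_min and σ_max stand for the std_min_value and std_max_value arguments of update; with weights, n would be the sum of the weights and each term weighted accordingly.

```latex
\begin{aligned}
\texttt{count} &= n, \qquad
\texttt{mean} = \mu_n = \frac{1}{n}\sum_{i=1}^{n} x_i, \\
\texttt{summed\_variance} &= S_n = \sum_{i=1}^{n} \left(x_i - \mu_n\right)^2, \\
\texttt{std} &= \operatorname{clip}\!\left(\sqrt{S_n / n},\ \sigma_{\min},\ \sigma_{\max}\right).
\end{aligned}
```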
- arlbench.core.running_statistics.denormalize(batch, mean_std)
Denormalizes values in a nested structure using the given mean/std.
Only values of inexact (floating-point or complex) types are denormalized. See https://numpy.org/doc/stable/_images/dtype-hierarchy.png for the NumPy type hierarchy.
- Parameters:
batch (Array) – a nested structure containing a batch of data.
mean_std (NestedMeanStd) – mean and standard deviation used for denormalization.
- Return type: Array
- Returns: Nested structure with denormalized values.
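Denormalization inverts the standardization leaf by leaf, x = z * std + mean, and leaves non-inexact leaves alone. The pure-Python sketch below assumes that behavior; tree_map is a small hypothetical stand-in for jax.tree_util.tree_map, plain floats stand in for inexact arrays, and ints stand in for integer arrays.

```python
def tree_map(fn, *trees):
    """Apply fn leaf by leaf over nested dicts/lists/tuples of identical structure."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(tree_map(fn, *leaves) for leaves in zip(*trees))
    return fn(*trees)


def denormalize_sketch(batch, mean, std):
    def leaf(x, m, s):
        # Only inexact (here: float) leaves are rescaled; ints pass through.
        return x * s + m if isinstance(x, float) else x

    return tree_map(leaf, batch, mean, std)


batch = {"obs": [0.5, -1.0], "step": 7}
mean = {"obs": [10.0, 0.0], "step": 0.0}
std = {"obs": [2.0, 4.0], "step": 1.0}
print(denormalize_sketch(batch, mean, std))
```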
- arlbench.core.running_statistics.init_state(nest)[source]¶
Initializes the running statistics for the given nested structure.
- Return type:
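A natural starting state is zero mean, unit standard deviation, zero count, and zero accumulated variance, with mean, std, and summed_variance mirroring the structure of nest. The sketch below assumes exactly that initialization (the page does not state it); init_state_sketch and tree_map are hypothetical helpers, and a plain dict stands in for RunningStatisticsState.

```python
def tree_map(fn, *trees):
    """Apply fn leaf by leaf over nested dicts/lists/tuples of identical structure."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(tree_map(fn, *leaves) for leaves in zip(*trees))
    return fn(*trees)


def init_state_sketch(nest):
    # Each statistic copies the structure of `nest`; leaf values are ignored.
    return {
        "mean": tree_map(lambda x: 0.0, nest),
        "std": tree_map(lambda x: 1.0, nest),
        "count": 0,  # integer count, matching the int32 default noted for update
        "summed_variance": tree_map(lambda x: 0.0, nest),
    }


example = {"obs": [0.3, -0.7], "reward": 1.5}
state = init_state_sketch(example)
print(state)
```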
- arlbench.core.running_statistics.normalize(batch, mean_std, max_abs_value=None)
Normalizes data using running statistics.
- Return type: Array
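Normalization standardizes each leaf, z = (x - mean) / std. The sketch below additionally assumes two things the page does not document: that max_abs_value, when given, clips the result symmetrically to [-max_abs_value, max_abs_value], and that, as with denormalize, only inexact leaves are touched. normalize_sketch and tree_map are hypothetical helpers.

```python
def tree_map(fn, *trees):
    """Apply fn leaf by leaf over nested dicts/lists/tuples of identical structure."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(tree_map(fn, *leaves) for leaves in zip(*trees))
    return fn(*trees)


def normalize_sketch(batch, mean, std, max_abs_value=None):
    def leaf(x, m, s):
        if not isinstance(x, float):
            return x  # non-float leaves are returned unchanged
        z = (x - m) / s
        if max_abs_value is not None:
            # Assumed behavior: symmetric clipping of the standardized value.
            z = max(-max_abs_value, min(max_abs_value, z))
        return z

    return tree_map(leaf, batch, mean, std)


batch = {"obs": [11.0, -40.0], "step": 3}
mean = {"obs": [10.0, 0.0], "step": 0.0}
std = {"obs": [2.0, 4.0], "step": 1.0}
print(normalize_sketch(batch, mean, std, max_abs_value=5.0))
```

Without clipping, the two functions are inverses: denormalizing (x - mean) / std with the same mean and std returns x.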
- arlbench.core.running_statistics.update(state, batch, *, weights=None, std_min_value=1e-06, std_max_value=1000000.0, pmap_axis_name=None, validate_shapes=True)
Updates the running statistics with the given batch of data.
Note: the data batch and the state elements (mean, etc.) must have the same structure.
Note: by default, counts are stored as int32 and the accumulated variance as float32. This results in an integer overflow after 2^31 data points, and in degrading precision after 2^24 batch updates, or even earlier if the variance updates have a large dynamic range. To improve precision, consider setting jax_enable_x64 to True; see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
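Both limits in the note follow from the storage types: float32 has a 24-bit significand, so beyond 2^24 adding 1 is no longer exact, and int32 holds at most 2^31 - 1. The standard-library sketch below mimics both types with struct; it only illustrates the number formats and does not call arlbench or JAX.

```python
import struct


def as_float32(x: float) -> float:
    # Round-trip through a 4-byte IEEE 754 single to mimic float32 storage.
    return struct.unpack("<f", struct.pack("<f", x))[0]


def wrap_int32(x: int) -> int:
    # Reinterpret the low 32 bits as a signed int32 (two's-complement wrap-around).
    return struct.unpack("<i", struct.pack("<I", x % 2**32))[0]


print(as_float32(2.0**24 + 1.0))  # the +1 is lost in float32
print(wrap_int32(2**31))          # one past the int32 maximum wraps negative
```

In JAX, 64-bit mode is typically enabled at startup with `jax.config.update("jax_enable_x64", True)`, as described on the linked JAX page.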
- Parameters:
state (RunningStatisticsState) – The running statistics before the update.
batch (Union[Array, Any, Array, Iterable[NestedSpec], Mapping[Any, NestedSpec]]) – The data to be used to update the running statistics.
weights (Optional[Array]) – Weights of the batch data. Should match the batch dimensions. Passing a weight of 2.0 should be equivalent to updating on the corresponding data point twice.
std_min_value (float) – Minimum value for the standard deviation.
std_max_value (float) – Maximum value for the standard deviation.
pmap_axis_name (Optional[str]) – Name of the pmapped axis, if any.
validate_shapes (bool) – If True, the shapes of all leaves of the batch are validated. Enabled by default. Doesn’t impact performance when jitted.
- Return type:
- Returns: Updated running statistics.
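The update can be sketched for a single scalar feature as a weighted, batched variant of Welford's algorithm. This is an assumption about the internals (the page does not name the algorithm), modeled on the running-statistics code in DeepMind's Acme, whose docstrings this page closely resembles. update_sketch is a hypothetical helper; the real function applies its arithmetic leaf by leaf over nested arrays and also takes pmap_axis_name and validate_shapes, which the sketch omits.

```python
import math


def update_sketch(state, batch, weights=None, std_min_value=1e-6, std_max_value=1e6):
    """Weighted batch update for one scalar feature; `batch` is a list of floats."""
    if weights is None:
        weights = [1.0] * len(batch)
    count = state["count"] + sum(weights)
    # Move the mean by the weighted sum of differences to the old mean.
    diff_to_old_mean = [x - state["mean"] for x in batch]
    mean = state["mean"] + sum(w * d for w, d in zip(weights, diff_to_old_mean)) / count
    # Welford-style increment: (x - old_mean) * (x - new_mean), weighted.
    summed_variance = state["summed_variance"] + sum(
        w * d * (x - mean) for w, d, x in zip(weights, diff_to_old_mean, batch)
    )
    # Guard against tiny negative values from floating-point rounding.
    variance = max(summed_variance / count, 0.0)
    std = min(max(math.sqrt(variance), std_min_value), std_max_value)
    return {"mean": mean, "std": std, "count": count, "summed_variance": summed_variance}


s0 = {"mean": 0.0, "std": 1.0, "count": 0, "summed_variance": 0.0}
full = update_sketch(s0, [1.0, 2.0, 3.0, 4.0])
print(full)
```

Under these assumptions, splitting the data into two batches yields the same statistics as one combined batch, and a weight of 2.0 on a point matches including that point twice, which is the equivalence the weights parameter describes.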