arlbench.core.running_statistics

Running statistics.

Functions

denormalize(batch, mean_std)

Denormalizes values in a nested structure using the given mean/std.

init_state(nest)

Initializes the running statistics for the given nested structure.

normalize(batch, mean_std[, max_abs_value])

Normalizes data using running statistics.

update(state, batch, *[, weights, ...])

Updates the running statistics with the given batch of data.

Classes

Array(shape, dtype)

Describes a numpy array or scalar shape and dtype.

NestedMeanStd(mean, std)

A container for running statistics (mean, std) of possibly nested data.

RunningStatisticsState(mean, std, count, ...)

Full state of running statistics computation.

class arlbench.core.running_statistics.Array(shape, dtype)

Bases: object

Describes a numpy array or scalar shape and dtype.

Similar to dm_env.specs.Array.

class arlbench.core.running_statistics.NestedMeanStd(mean, std)

Bases: object

A container for running statistics (mean, std) of possibly nested data.

replace(**updates)

Returns a new object replacing the specified fields with new values.
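The non-mutating behaviour described for replace can be illustrated with a standard-library frozen dataclass. This is only an analogy: the MeanStd class below is hypothetical, and dataclasses.replace is a stand-in for the library's own replace method, whose implementation is not shown on this page.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class MeanStd:
    """Hypothetical stand-in for a (mean, std) container."""
    mean: float
    std: float


old = MeanStd(mean=0.0, std=1.0)
# Build a new object with only `mean` changed; `old` is left untouched.
new = replace(old, mean=2.5)
```

Fields that are not named in the call keep their previous values.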

class arlbench.core.running_statistics.RunningStatisticsState(mean, std, count, summed_variance)

Bases: NestedMeanStd

Full state of running statistics computation.

replace(**updates)

Returns a new object replacing the specified fields with new values.

arlbench.core.running_statistics.denormalize(batch, mean_std)

Denormalizes values in a nested structure using the given mean/std.

Only values of inexact types (floating-point and complex) are denormalized; values of other types, such as integers, are returned unchanged. See https://numpy.org/doc/stable/_images/dtype-hierarchy.png for the NumPy type hierarchy.

Parameters:
  • batch (Array) – A nested structure containing a batch of data.

  • mean_std (NestedMeanStd) – Mean and standard deviation used for denormalization.

Return type:

Array

Returns:

Nested structure with denormalized values.
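As a rough illustration of what denormalization over a nested structure could look like, the sketch below uses plain Python dictionaries and floats instead of JAX arrays. The helper name denormalize_nest and the separate mean and std arguments are assumptions made for this example; the library function takes a single NestedMeanStd instead.

```python
def denormalize_nest(batch, mean, std):
    """Hypothetical sketch: invert (x - mean) / std on float leaves only."""
    if isinstance(batch, dict):
        # Recurse into nested dictionaries, matching keys in mean/std.
        return {k: denormalize_nest(v, mean[k], std[k]) for k, v in batch.items()}
    if isinstance(batch, float):
        # Float leaves stand in for "inexact" values here.
        return batch * std + mean
    # Non-float leaves (e.g. integer step counters) pass through unchanged.
    return batch


batch = {"obs": 0.5, "step": 3}
mean = {"obs": 2.0, "step": 0.0}
std = {"obs": 4.0, "step": 1.0}
restored = denormalize_nest(batch, mean, std)
```

Here the float leaf is mapped back to its original scale, while the integer leaf is left as it was.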

arlbench.core.running_statistics.init_state(nest)

Initializes the running statistics for the given nested structure.

Return type:

RunningStatisticsState

arlbench.core.running_statistics.normalize(batch, mean_std, max_abs_value=None)

Normalizes data using running statistics.

Return type:

Array
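The page does not describe max_abs_value. The sketch below assumes it clips the normalized result to the range [-max_abs_value, max_abs_value], which is a common convention but is not confirmed here. The helper normalize_value is hypothetical and works on a single scalar rather than a nested batch.

```python
def normalize_value(x, mean, std, max_abs_value=None):
    """Hypothetical scalar sketch of normalization with optional clipping."""
    y = (x - mean) / std
    if max_abs_value is not None:
        # Assumed behaviour: clamp the normalized value symmetrically.
        y = max(-max_abs_value, min(max_abs_value, y))
    return y


unclipped = normalize_value(10.0, mean=2.0, std=4.0)
clipped = normalize_value(10.0, mean=2.0, std=4.0, max_abs_value=1.5)
```

Clipping of this kind limits the influence of outliers on whatever consumes the normalized data.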

arlbench.core.running_statistics.update(state, batch, *, weights=None, std_min_value=1e-06, std_max_value=1000000.0, pmap_axis_name=None, validate_shapes=True)

Updates the running statistics with the given batch of data.

Note: the data batch and the state elements (mean, etc.) must have the same structure.

Note: by default, counts use int32 and the accumulated variance uses float32. Counts therefore overflow after 2^31 data points, and precision degrades after 2^24 batch updates, or even earlier if the variance updates have a large dynamic range. To improve precision, consider setting jax_enable_x64 to True; see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
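The two limits in the note can be reproduced with the standard library alone. The sketch below rounds values to float32 via struct and simulates two's-complement int32 wraparound; the helpers to_float32 and wrap_int32 are hypothetical and are not part of the library.

```python
import struct


def to_float32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def wrap_int32(n):
    """Simulate two's-complement wraparound of a 32-bit signed integer."""
    return (n + 2**31) % 2**32 - 2**31


big = float(2**24)
# float32 has a 24-bit significand, so adding 1 to 2^24 is lost in rounding.
increment_lost = to_float32(big + 1.0) == big
# One past the largest int32 value wraps around to the most negative value.
wrapped = wrap_int32((2**31 - 1) + 1)
```

With 64-bit types enabled, both thresholds move far beyond what typical training runs reach.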

Parameters:
  • state (RunningStatisticsState) – The running statistics before the update.

  • batch (Union[Array, Any, Array, Iterable[NestedSpec], Mapping[Any, NestedSpec]]) – The data to be used to update the running statistics.

  • weights (Optional[Array]) – Weights of the batch data. Should match the batch dimensions. Passing a weight of 2.0 should be equivalent to updating on the corresponding data point twice.

  • std_min_value (float) – Minimum value for the standard deviation.

  • std_max_value (float) – Maximum value for the standard deviation.

  • pmap_axis_name (Optional[str]) – Name of the pmapped axis, if any.

  • validate_shapes (bool) – If true, the shapes of all leaves of the batch will be validated. Enabled by default. Doesn’t impact performance when jitted.

Return type:

RunningStatisticsState

Returns:

Updated running statistics.
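A minimal scalar sketch of a weighted running mean/std update follows. It is not the library's implementation: ScalarState and update_scalar are hypothetical names, the initial values (zero mean, unit std) are an assumption, and plain Python floats replace JAX arrays. The update rule is a standard incremental formula that accumulates a summed variance and clamps the resulting standard deviation.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ScalarState:
    """Hypothetical scalar analogue of RunningStatisticsState."""
    mean: float = 0.0
    std: float = 1.0
    count: float = 0.0
    summed_variance: float = 0.0


def update_scalar(state, batch, weights=None,
                  std_min_value=1e-6, std_max_value=1e6):
    """Return a new state after folding a non-empty batch into the statistics."""
    if weights is None:
        weights = [1.0] * len(batch)
    count = state.count + sum(weights)
    # Differences to the old mean drive the mean update.
    diff_old = [x - state.mean for x in batch]
    mean = state.mean + sum(d * w for d, w in zip(diff_old, weights)) / count
    # Differences to the new mean complete the variance increment.
    diff_new = [x - mean for x in batch]
    summed_variance = state.summed_variance + sum(
        a * b * w for a, b, w in zip(diff_old, diff_new, weights))
    std = math.sqrt(max(summed_variance, 0.0) / count)
    # Clamp the standard deviation to the configured range.
    std = min(max(std, std_min_value), std_max_value)
    return ScalarState(mean, std, count, summed_variance)


first = update_scalar(ScalarState(), [1.0, 2.0, 3.0, 4.0])
second = update_scalar(first, [5.0, 6.0])
weighted = update_scalar(ScalarState(), [1.0, 3.0], weights=[2.0, 1.0])
repeated = update_scalar(ScalarState(), [1.0, 1.0, 3.0])
```

In this sketch, a weight of 2.0 on a data point gives the same statistics as including that point twice, matching the behaviour described for the weights parameter.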