arlbench.autorl.state_features

State features for the AutoRL environment.

Classes

GradInfo(*args, **kwargs)

Gradient information state feature for the AutoRL environment.

StateFeature(*args, **kwargs)

An abstract state feature for the AutoRL environment.

class arlbench.autorl.state_features.GradInfo(*args, **kwargs)

Bases: StateFeature

Gradient information state feature for the AutoRL environment. It records the gradient norm during training.

static __call__(train_func, state_features)

Wraps the training function so that gradient information is computed during training.

Return type:

Callable[[DQNRunnerState | PPORunnerState | SACRunnerState, PrioritisedTrajectoryBufferState, int | None, int | None, int | None], tuple[DQNState, DQNTrainingResult] | tuple[PPOState, PPOTrainingResult] | tuple[SACState, SACTrainingResult]]
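To illustrate what such a wrapper does, here is a minimal, framework-free sketch. Everything in it is hypothetical: TrainingResult, its grads field, with_grad_info, dummy_train and the "grad_info" key are stand-ins rather than ARLBench's API, and gradients are plain lists of floats instead of JAX pytrees. The wrapper calls the original training function, computes the global L2 norm over all gradient entries, and stores it in the shared state_features dict before returning the original outputs unchanged.

```python
import math
from typing import Any, Callable, NamedTuple


class TrainingResult(NamedTuple):
    """Hypothetical stand-in for a DQN/PPO/SAC training result."""
    grads: list[list[float]]  # gradients of each parameter array, flattened


def global_grad_norm(grads: list[list[float]]) -> float:
    # L2 norm over all gradient entries, treated as one long vector.
    return math.sqrt(sum(g * g for leaf in grads for g in leaf))


def with_grad_info(train_func: Callable, state_features: dict) -> Callable:
    """Wrap train_func so that each call records the gradient norm."""

    def wrapped(*args: Any, **kwargs: Any) -> tuple[Any, TrainingResult]:
        state, result = train_func(*args, **kwargs)
        # Hypothetical key; the real feature may store more than the norm.
        state_features["grad_info"] = global_grad_norm(result.grads)
        return state, result

    return wrapped


def dummy_train(step: int) -> tuple[int, TrainingResult]:
    # Pretend training step: returns a new state and fixed gradients.
    return step + 1, TrainingResult(grads=[[3.0], [4.0, 0.0]])


features: dict = {}
train = with_grad_info(dummy_train, features)
new_state, _ = train(0)
print(new_state, features["grad_info"])  # 1 5.0
```

In a JAX code base, optax.global_norm computes the same quantity directly over a gradient pytree.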

static get_state_space()

Returns the state space of this feature.

Return type:

Space

class arlbench.autorl.state_features.StateFeature(*args, **kwargs)

Bases: ABC

An abstract state feature for the AutoRL environment.

It can be wrapped around the training function to calculate the state features. We do this by overriding the __new__() method, so that calling the class returns the wrapped training function instead of an instance. This lets a state feature behave like a plain function while keeping the advantages of a static class.
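The __new__() pattern can be sketched in isolation as follows. The class names StateFeatureSketch and CallCounter, and the train function, are hypothetical; the sketch only assumes the documented interface, a static __call__(train_func, state_features) that returns a wrapped training function. Because __new__() returns that function instead of an instance, "constructing" a feature yields something callable that can replace the original training function.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable


class StateFeatureSketch(ABC):
    """Minimal imitation of the documented pattern (hypothetical name)."""

    def __new__(cls, *args: Any, **kwargs: Any) -> Callable:
        # Forward the constructor arguments to the static __call__ and
        # return its result (the wrapped function), not an instance.
        return cls.__call__(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def __call__(train_func: Callable, state_features: dict) -> Callable:
        """Return a wrapped version of train_func."""


class CallCounter(StateFeatureSketch):
    """Hypothetical feature that counts how often training ran."""

    @staticmethod
    def __call__(train_func: Callable, state_features: dict) -> Callable:
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            state_features["num_calls"] = state_features.get("num_calls", 0) + 1
            return train_func(*args, **kwargs)

        return wrapped


def train(x: int) -> int:
    return x * 2


features: dict = {}
wrapped_train = CallCounter(train, features)  # a function, not an instance
print(wrapped_train(3), features)  # 6 {'num_calls': 1}
```

One consequence of this design, at least in the sketch, is that no instance is ever kept and __init__ never runs, so any state a feature produces has to live in the state_features dict or in the closure of the wrapped function.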

abstract static __call__(train_func, state_features)

Wraps the training function with the state feature calculation.

Parameters:
  • train_func (TrainFunc) – Training function to wrap.

  • state_features (dict) – Dictionary to store state features.

Returns:

Wrapped training function.

Return type:

TrainFunc

static __new__(cls, *args, **kwargs)

Creates a new instance of this state feature and directly wraps the training function.

This is done by first creating an object and subsequently calling self.__call__().

Returns:

Wrapped training function.

Return type:

TrainFunc

abstract static get_state_space()

Returns a dictionary containing the specification of the state feature.

Returns:

Specification.

Return type:

dict