arlbench.autorl.state_features
State features for the AutoRL environment.
Classes
GradInfo | Gradient information state feature for the AutoRL environment.
StateFeature | An abstract state feature for the AutoRL environment.
- class arlbench.autorl.state_features.GradInfo(*args, **kwargs)
Bases: StateFeature
Gradient information state feature for the AutoRL environment. It stores the gradient norm computed during training.
- static __call__(train_func, state_features)
Wraps the training function with the gradient information calculation.
- Return type:
Callable[[DQNRunnerState | PPORunnerState | SACRunnerState, PrioritisedTrajectoryBufferState, int | None, int | None, int | None], tuple[DQNState, DQNTrainingResult] | tuple[PPOState, PPOTrainingResult] | tuple[SACState, SACTrainingResult]]
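As an illustration of what such a wrapper might do, here is a minimal pure-Python sketch: it calls the wrapped training function, reads gradients from the returned result, and stores their global L2 norm in `state_features`. The `grad_info` function, the `ToyTrainingResult.grads` field and the `"grad_norm"` key are hypothetical; ARLBench's actual implementation works on JAX arrays and the runner and result types listed in the return type above, and may record additional statistics.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ToyTrainingResult:
    # Hypothetical result type: one flat list of gradient values per parameter tensor.
    grads: list[list[float]]


def grad_info(
    train_func: Callable[..., tuple[Any, ToyTrainingResult]],
    state_features: dict,
) -> Callable[..., tuple[Any, ToyTrainingResult]]:
    """Wrap train_func so every call records the global gradient norm (hypothetical sketch)."""

    def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, ToyTrainingResult]:
        state, result = train_func(*args, **kwargs)
        # Global L2 norm over all gradient tensors: sqrt of the sum of squared entries.
        squared = sum(g * g for tensor in result.grads for g in tensor)
        state_features["grad_norm"] = math.sqrt(squared)
        return state, result

    return wrapper


def toy_train(runner_state: int) -> tuple[int, ToyTrainingResult]:
    # Stand-in training step: advance a counter and return fixed gradients.
    return runner_state + 1, ToyTrainingResult(grads=[[3.0], [4.0]])


features: dict = {}
wrapped = grad_info(toy_train, features)
new_state, _ = wrapped(0)
print(new_state, features["grad_norm"])  # prints: 1 5.0
```

The wrapper keeps the signature and return value of the original training function unchanged, so callers can use it as a drop-in replacement; the feature value is a side effect written to the shared dictionary.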
- class arlbench.autorl.state_features.StateFeature(*args, **kwargs)
Bases: ABC
An abstract state feature for the AutoRL environment.
It can be wrapped around the training function to calculate the state features. We do this by overriding the __new__() method, so that calling the class with a training function returns the wrapped function instead of an instance. This lets a state feature imitate the behaviour of a plain function while keeping the advantages of a static class.
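A minimal sketch of the `__new__()` pattern, assuming that the base class's `__new__()` simply forwards its arguments to the static `__call__()` and returns the result. This is not ARLBench's source; the `CallCount` feature and the `"n_calls"` key are hypothetical and only serve to show that calling the class yields a wrapped training function rather than an instance.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable

TrainFunc = Callable[..., Any]


class StateFeature(ABC):
    """Minimal sketch of the __new__ pattern (assumed, not ARLBench's source)."""

    def __new__(cls, *args: Any, **kwargs: Any) -> TrainFunc:
        # Instead of building an instance, forward the arguments to the static
        # __call__ and return its result, so subclasses behave like functions.
        return cls.__call__(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def __call__(train_func: TrainFunc, state_features: dict) -> TrainFunc:
        """Wrap train_func with the state feature calculation."""


class CallCount(StateFeature):
    """Hypothetical feature: counts how often the training function was called."""

    @staticmethod
    def __call__(train_func: TrainFunc, state_features: dict) -> TrainFunc:
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            state_features["n_calls"] = state_features.get("n_calls", 0) + 1
            return train_func(*args, **kwargs)

        return wrapper


features: dict = {}
train = CallCount(lambda x: x * 2, features)  # a wrapped function, not an instance
print(train(3), train(4), features["n_calls"])  # prints: 6 8 2
```

Because `__new__()` returns a function instead of an instance of the class, `__init__()` is never called and no per-object state exists; anything the feature collects has to live in the `state_features` dictionary that is passed in.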
- abstract static __call__(train_func, state_features)
Wraps the training function with the state feature calculation.
- Parameters:
train_func (TrainFunc) – Training function to wrap.
state_features (dict) – Dictionary to store state features.
- Returns:
Wrapped training function.
- Return type:
TrainFunc