arlbench.autorl.state_features
State features for the AutoRL environment.
Classes
GradInfo | Gradient information state feature for the AutoRL environment.
LossInfo | Loss information state feature for the AutoRL environment.
PredictionInfo | Prediction information state feature for the AutoRL environment.
StateFeature | An abstract state feature for the AutoRL environment.
WeightInfo | Weight information state feature for the AutoRL environment.
- class arlbench.autorl.state_features.GradInfo(*args, **kwargs)
Bases: StateFeature
Gradient information state feature for the AutoRL environment. It contains the grad norm during training.
- static __call__(train_func, state_features)
Wraps the training function with the gradient information calculation.
- Return type:
Callable[[DQNRunnerState | PPORunnerState | SACRunnerState, PrioritisedTrajectoryBufferState, int | None, int | None, int | None], tuple[DQNState, DQNTrainingResult] | tuple[PPOState, PPOTrainingResult] | tuple[SACState, SACTrainingResult]]
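As a rough illustration of what a gradient norm feature could record, the sketch below computes a global L2 norm over a nested gradient structure. The source does not state which norm GradInfo uses, so the L2 choice, the helper name `global_l2_norm`, and the nested-dict layout are all assumptions; plain Python lists stand in for the arrays a real training run would produce.

```python
import math
from collections.abc import Mapping, Sequence


def global_l2_norm(grads) -> float:
    """Return the L2 norm over every scalar leaf of a nested structure.

    Hypothetical helper for illustration; not part of arlbench.
    """

    def squared_sum(node) -> float:
        # Recurse through dicts and lists until a scalar leaf is reached.
        if isinstance(node, Mapping):
            return sum(squared_sum(child) for child in node.values())
        if isinstance(node, Sequence):
            return sum(squared_sum(child) for child in node)
        return float(node) ** 2

    return math.sqrt(squared_sum(grads))


# Toy gradients for a single dense layer (assumed layout).
example_grads = {"dense": {"kernel": [[3.0, 0.0], [0.0, 0.0]], "bias": [4.0]}}
grad_norm = global_l2_norm(example_grads)
```

A single scalar like this is cheap to store in the `state_features` dictionary after each training call.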
- class arlbench.autorl.state_features.LossInfo(*args, **kwargs)
Bases: StateFeature
Loss information state feature for the AutoRL environment. It contains the loss mean and variance.
- static __call__(train_func, state_features)
Wraps the training function with the loss information calculation.
- Return type:
Callable[[DQNRunnerState | PPORunnerState | SACRunnerState, PrioritisedTrajectoryBufferState, int | None, int | None, int | None], tuple[DQNState, DQNTrainingResult] | tuple[PPOState, PPOTrainingResult] | tuple[SACState, SACTrainingResult]]
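To make "loss mean and variance" concrete, here is a minimal sketch that summarises a list of per-update losses. The helper name `loss_statistics` is hypothetical, and using the population variance is an assumption: the source does not say which variance estimator LossInfo applies.

```python
from statistics import fmean, pvariance


def loss_statistics(losses):
    """Return (mean, population variance) of a non-empty sequence of losses.

    Hypothetical helper for illustration; not part of arlbench.
    """
    values = [float(v) for v in losses]
    mean = fmean(values)
    # Reuse the mean so the data is not averaged twice.
    return mean, pvariance(values, mu=mean)


# Toy losses collected over four updates.
loss_mean, loss_var = loss_statistics([1.0, 2.0, 3.0, 4.0])
```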
- class arlbench.autorl.state_features.PredictionInfo(*args, **kwargs)
Bases: StateFeature
Prediction information state feature for the AutoRL environment. It contains the predicted values and log probs.
- static __call__(train_func, state_features)
Wraps the training function with the prediction information calculation.
- Return type:
Callable[[DQNRunnerState | PPORunnerState | SACRunnerState, PrioritisedTrajectoryBufferState, int | None, int | None, int | None], tuple[DQNState, DQNTrainingResult] | tuple[PPOState, PPOTrainingResult] | tuple[SACState, SACTrainingResult]]
- class arlbench.autorl.state_features.StateFeature(*args, **kwargs)
Bases: ABC
An abstract state feature for the AutoRL environment.
It can be wrapped around the training function to calculate the state features. We do this by overriding the __new__() function. This allows us to imitate the behaviour of a basic function while keeping the advantages of a static class.
- abstract static __call__(train_func, state_features)
Wraps the training function with the state feature calculation.
- Parameters:
train_func (TrainFunc) – Training function to wrap.
state_features (dict) – Dictionary to store state features.
- Returns:
Wrapped training function.
- Return type:
TrainFunc
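The wrapping mechanism described above can be sketched as follows. This is a minimal illustration of the pattern under stated assumptions, not the arlbench implementation: `SketchFeature`, `CallCounter`, `dummy_train`, and the simplified `TrainFunc` alias are hypothetical names introduced only for this example.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable

# Simplified stand-in for the real TrainFunc type (assumption).
TrainFunc = Callable[..., tuple[Any, Any]]


class SketchFeature(ABC):
    """Hypothetical base class that behaves like a function when called."""

    def __new__(cls, train_func: TrainFunc, state_features: dict) -> TrainFunc:
        # "Instantiating" the class returns the wrapped function instead of
        # an instance, so no object of this class is ever created.
        return cls.__call__(train_func, state_features)

    @staticmethod
    @abstractmethod
    def __call__(train_func: TrainFunc, state_features: dict) -> TrainFunc:
        """Return train_func wrapped with the feature calculation."""


class CallCounter(SketchFeature):
    """Toy feature: counts how often the training function is called."""

    @staticmethod
    def __call__(train_func, state_features):
        def wrapped(*args, **kwargs):
            result = train_func(*args, **kwargs)
            state_features["calls"] = state_features.get("calls", 0) + 1
            return result

        return wrapped


def dummy_train(loss):
    # Stand-in for a real training function returning (state, result).
    return "state", {"loss": loss}


features = {}
train = CallCounter(dummy_train, features)  # returns a function
train(1.0)
last_output = train(2.0)
```

After the two calls, `features` holds the values written by the wrapper, while `train` still returns exactly what `dummy_train` returns.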
- class arlbench.autorl.state_features.WeightInfo(*args, **kwargs)
Bases: StateFeature
Weight information state feature for the AutoRL environment. It contains statistics about the weights during training.
- static __call__(train_func, state_features)
Wraps the training function with the weight information calculation.
- Return type:
Callable[[DQNRunnerState | PPORunnerState | SACRunnerState, PrioritisedTrajectoryBufferState, int | None, int | None, int | None], tuple[DQNState, DQNTrainingResult] | tuple[PPOState, PPOTrainingResult] | tuple[SACState, SACTrainingResult]]
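Since each feature on this page takes a training function and returns a training function, several wrappers can presumably be stacked around the same function while writing into one shared dictionary. The sketch below shows that composition with two hypothetical plain-function wrappers (`record_loss`, `count_calls`) and a dummy training function; none of these names come from arlbench, and the result layout with a `"loss"` key is an assumption.

```python
from functools import reduce


def record_loss(train_func, state_features):
    """Hypothetical wrapper: stores the most recent loss."""

    def wrapped(*args, **kwargs):
        state, result = train_func(*args, **kwargs)
        state_features["last_loss"] = result["loss"]
        return state, result

    return wrapped


def count_calls(train_func, state_features):
    """Hypothetical wrapper: counts training calls."""

    def wrapped(*args, **kwargs):
        state_features["calls"] = state_features.get("calls", 0) + 1
        return train_func(*args, **kwargs)

    return wrapped


def dummy_train(step):
    # Stand-in for a real training function returning (state, result).
    return f"state-{step}", {"loss": 1.0 / (step + 1)}


shared = {}
# Apply each wrapper in turn; every layer writes into the same dictionary.
stacked = reduce(
    lambda fn, wrap: wrap(fn, shared), [record_loss, count_calls], dummy_train
)
outputs = [stacked(step) for step in range(3)]
```

Each wrapper only touches its own keys, so the order of stacking does not change what ends up in `shared` here.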