arlbench.autorl.state_features
State features for the AutoRL environment.
Classes
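| Class | Description |
| --- | --- |
| GradInfo(*args, **kwargs) | Gradient information state feature for the AutoRL environment. |
| LossInfo(*args, **kwargs) | Loss information state feature for the AutoRL environment. |
| PredictionInfo(*args, **kwargs) | Prediction information state feature for the AutoRL environment. |
| StateFeature(*args, **kwargs) | An abstract state feature for the AutoRL environment. |
| WeightInfo(*args, **kwargs) | Weight information state feature for the AutoRL environment. |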
- class arlbench.autorl.state_features.GradInfo(*args, **kwargs)
- Bases: StateFeature
- Gradient information state feature for the AutoRL environment. It contains the grad norm during training.
- static __call__(train_func, state_features)
- Wraps the training function with the gradient information calculation.
- Return type:
- Callable[[DQNRunnerState | PPORunnerState | SACRunnerState, PrioritisedTrajectoryBufferState, int | None, int | None, int | None], tuple[DQNState, DQNTrainingResult] | tuple[PPOState, PPOTrainingResult] | tuple[SACState, SACTrainingResult]]
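As a rough illustration of what a gradient wrapper of this kind might do, the sketch below records a global L2 gradient norm after every call of the wrapped function. It is not ARLBench's implementation: the function name wrap_with_grad_norm, the (state, result) return contract and the result["grads"] layout (plain lists of floats) are hypothetical stand-ins for the real JAX runner states and training results.

```python
import math
from typing import Any, Callable

# Hypothetical contract for this sketch only: train_func returns (state, result),
# where result["grads"] is a list of flattened gradient arrays (lists of floats).
TrainFunc = Callable[..., tuple[Any, dict]]


def wrap_with_grad_norm(train_func: TrainFunc, state_features: dict) -> TrainFunc:
    """Return a wrapper that stores the global gradient L2 norm in state_features."""

    def wrapped(*args, **kwargs):
        state, result = train_func(*args, **kwargs)
        # Global norm: square root of the sum of squares over every gradient entry.
        squared = sum(g * g for grads in result["grads"] for g in grads)
        state_features["grad_norm"] = math.sqrt(squared)
        return state, result

    return wrapped


def dummy_train(step: int):
    # Stand-in training step with fixed gradients for two parameter arrays.
    return step + 1, {"grads": [[3.0, 0.0], [4.0]]}


features: dict = {}
train = wrap_with_grad_norm(dummy_train, features)
new_state, _ = train(0)
print(new_state, features["grad_norm"])  # 1 5.0
```

The wrapper leaves the training function's inputs and outputs untouched and only writes into the shared dictionary, which is what allows several such features to be stacked around the same training function.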
 
 
- class arlbench.autorl.state_features.LossInfo(*args, **kwargs)
- Bases: StateFeature
- Loss information state feature for the AutoRL environment. It contains the loss mean and variance.
- static __call__(train_func, state_features)
- Wraps the training function with the loss information calculation.
- Return type:
- Callable[[DQNRunnerState | PPORunnerState | SACRunnerState, PrioritisedTrajectoryBufferState, int | None, int | None, int | None], tuple[DQNState, DQNTrainingResult] | tuple[PPOState, PPOTrainingResult] | tuple[SACState, SACTrainingResult]]
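A minimal sketch of recording a loss mean and variance around a training call follows, assuming a hypothetical training function that returns (state, result) with result["loss"] holding per-minibatch loss values. The name wrap_with_loss_stats and the state_features keys are illustrative only and do not come from ARLBench.

```python
import statistics
from typing import Any, Callable

# Hypothetical contract for this sketch only: train_func returns (state, result),
# where result["loss"] is a list of per-minibatch loss values.
TrainFunc = Callable[..., tuple[Any, dict]]


def wrap_with_loss_stats(train_func: TrainFunc, state_features: dict) -> TrainFunc:
    """Return a wrapper that stores the mean and variance of the losses."""

    def wrapped(*args, **kwargs):
        state, result = train_func(*args, **kwargs)
        losses = result["loss"]
        state_features["loss_mean"] = statistics.fmean(losses)
        # Population variance: the losses of this call are treated as the full set.
        state_features["loss_var"] = statistics.pvariance(losses)
        return state, result

    return wrapped


def dummy_train(step: int):
    # Stand-in training step reporting two minibatch losses.
    return step + 1, {"loss": [1.0, 3.0]}


features: dict = {}
train = wrap_with_loss_stats(dummy_train, features)
train(0)
print(features)  # {'loss_mean': 2.0, 'loss_var': 1.0}
```

The population variance is used here because the losses of one call are the whole set being summarised; statistics.variance would give the sample estimate instead.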
 
 
- class arlbench.autorl.state_features.PredictionInfo(*args, **kwargs)
- Bases: StateFeature
- Prediction information state feature for the AutoRL environment. It contains the predicted values and log probs.
- static __call__(train_func, state_features)
- Wraps the training function with the prediction information calculation.
- Return type:
- Callable[[DQNRunnerState | PPORunnerState | SACRunnerState, PrioritisedTrajectoryBufferState, int | None, int | None, int | None], tuple[DQNState, DQNTrainingResult] | tuple[PPOState, PPOTrainingResult] | tuple[SACState, SACTrainingResult]]
 
 
- class arlbench.autorl.state_features.StateFeature(*args, **kwargs)
- Bases: ABC
- An abstract state feature for the AutoRL environment.
- It can be wrapped around the training function to calculate the state features. We do this by overriding the __new__() method, so that calling the class behaves like calling a plain function while keeping the advantages of a static class.
- abstract static __call__(train_func, state_features)
- Wraps the training function with the state feature calculation.
- Parameters:
- train_func (TrainFunc) – Training function to wrap.
- state_features (dict) – Dictionary to store state features.

- Returns:
- Wrapped training function.
- Return type:
- TrainFunc
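The __new__() override can be sketched in plain Python as follows. This is a minimal sketch under assumptions, not the ARLBench source: StateFeatureSketch and the CallCounter feature are hypothetical names. Because __new__() returns the result of the static __call__() instead of an instance, writing Feature(train_func, state_features) directly yields the wrapped training function.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable

TrainFunc = Callable[..., Any]


class StateFeatureSketch(ABC):
    """Hypothetical base class: calling it returns a wrapped function, not an instance."""

    def __new__(cls, train_func: TrainFunc, state_features: dict) -> TrainFunc:  # type: ignore[misc]
        # Delegate to the static __call__; since no instance of cls is returned,
        # __init__ is never run and the caller receives the wrapper itself.
        return cls.__call__(train_func, state_features)

    @staticmethod
    @abstractmethod
    def __call__(train_func: TrainFunc, state_features: dict) -> TrainFunc:
        """Wrap train_func so that it writes into state_features."""


class CallCounter(StateFeatureSketch):
    """Hypothetical feature that counts how often the training function ran."""

    @staticmethod
    def __call__(train_func: TrainFunc, state_features: dict) -> TrainFunc:
        def wrapped(*args, **kwargs):
            state_features["num_calls"] = state_features.get("num_calls", 0) + 1
            return train_func(*args, **kwargs)

        return wrapped


features: dict = {}
train = CallCounter(lambda x: x * 2, features)
print(train(3), train(4), features["num_calls"])  # 6 8 2
```

Since abstractmethod must be the innermost decorator when combined with staticmethod, it sits directly above the function definition.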
 
 
- class arlbench.autorl.state_features.WeightInfo(*args, **kwargs)
- Bases: StateFeature
- Weight information state feature for the AutoRL environment. It contains statistics about the weights during training.
- static __call__(train_func, state_features)
- Wraps the training function with the weight information calculation.
- Return type:
- Callable[[DQNRunnerState | PPORunnerState | SACRunnerState, PrioritisedTrajectoryBufferState, int | None, int | None, int | None], tuple[DQNState, DQNTrainingResult] | tuple[PPOState, PPOTrainingResult] | tuple[SACState, SACTrainingResult]]
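One way weight statistics could be gathered is sketched below, assuming a hypothetical training function whose returned state holds state["params"] as a list of flattened parameter arrays. The name wrap_with_weight_stats and the chosen statistics (mean and population standard deviation) are illustrative assumptions; the exact statistics WeightInfo records are not listed here.

```python
import statistics
from typing import Any, Callable

# Hypothetical contract for this sketch only: train_func returns (state, result),
# where state["params"] is a list of flattened parameter arrays (lists of floats).
TrainFunc = Callable[..., tuple[Any, dict]]


def wrap_with_weight_stats(train_func: TrainFunc, state_features: dict) -> TrainFunc:
    """Return a wrapper that stores summary statistics of the updated weights."""

    def wrapped(*args, **kwargs):
        state, result = train_func(*args, **kwargs)
        # Flatten every parameter array of the updated state into one list.
        weights = [w for param in state["params"] for w in param]
        state_features["weight_mean"] = statistics.fmean(weights)
        state_features["weight_std"] = statistics.pstdev(weights)
        return state, result

    return wrapped


def dummy_train(state: dict):
    # Stand-in training step that returns fixed "updated" parameters.
    return {"params": [[1.0, 3.0], [1.0, 3.0]]}, {}


features: dict = {}
train = wrap_with_weight_stats(dummy_train, features)
train({"params": []})
print(features)  # {'weight_mean': 2.0, 'weight_std': 1.0}
```

The statistics are taken from the state returned by the training call, so they describe the weights after the update rather than before it.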