Source code for arlbench.autorl.autorl_env

"""Automated Reinforcement Learning Environment."""
from __future__ import annotations

import warnings
from collections.abc import Callable
from typing import Any

import gymnasium
import jax
import numpy as np
import pandas as pd
from ConfigSpace import Configuration, ConfigurationSpace

from arlbench.core.algorithms import (
    DQN,
    PPO,
    SAC,
    Algorithm,
    AlgorithmState,
    TrainResult,
    TrainReturnT,
)
from arlbench.core.environments import make_env
from arlbench.utils import config_space_to_gymnasium_space

from .checkpointing import Checkpointer
from .objectives import OBJECTIVES, Objective
from .state_features import STATE_FEATURES, StateFeature

ObservationT = dict[str, np.ndarray]
ObjectivesT = dict[str, float]
InfoT = dict[str, Any]

DEFAULT_AUTO_RL_CONFIG = {
    "seed": 42,
    "env_framework": "gymnax",
    "env_name": "CartPole-v1",
    "env_kwargs": {},
    "eval_env_kwargs": {},
    "n_envs": 10,
    "algorithm": "dqn",
    "cnn_policy": False,
    "deterministic_eval": True,
    "nas_config": {},
    "checkpoint": [],
    "checkpoint_name": "default_checkpoint",
    "checkpoint_dir": "/tmp",
    "objectives": ["reward_mean"],
    "optimize_objectives": "upper",
    "state_features": [],
    "n_steps": 10,
    "n_total_timesteps": 1e5,
    "n_eval_steps": 100,
    "n_eval_episodes": 10,
}
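# Illustrative usage sketch (not part of the original module): a user-provided
# config only needs to override the keys it cares about; everything else falls
# back to DEFAULT_AUTO_RL_CONFIG, and unknown keys are ignored with a warning.
# The values below are placeholders, not tuned recommendations.
#
#   env = AutoRLEnv(config={
#       "algorithm": "ppo",
#       "env_name": "CartPole-v1",
#       "n_total_timesteps": 1e5,
#       "objectives": ["reward_mean"],
#   })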


class AutoRLEnv(gymnasium.Env):
    """Automated Reinforcement Learning (gymnasium-like) Environment.

    With each reset, the algorithm state is (re-)initialized. If a checkpoint path is
    passed to reset, the agent state is initialized with the checkpointed state.

    In each step, one iteration of training is performed with the current
    hyperparameter configuration (= action).
    """

    ALGORITHMS = {"ppo": PPO, "dqn": DQN, "sac": SAC}
    _algorithm: Algorithm
    _get_obs: Callable[[], np.ndarray]
    _algorithm_state: AlgorithmState | None
    _train_result: TrainResult | None
    _hpo_config: Configuration
    _total_training_steps: int

    def __init__(self, config: dict | None = None) -> None:
        """Creates a new AutoRL environment instance.

        Args:
            config (dict | None, optional): Configuration containing keys of DEFAULT_AUTO_RL_CONFIG.
                If no configuration keys are provided, the default configuration is used. Defaults to None.
        """
        super().__init__()

        self._config = DEFAULT_AUTO_RL_CONFIG.copy()
        if config:
            for k, v in config.items():
                if k in DEFAULT_AUTO_RL_CONFIG:
                    self._config[k] = v
                else:
                    warnings.warn(
                        f"Invalid config key '{k}'. This item will be ignored."
                    )

        self._seed = int(self._config["seed"])

        self._done = True
        self._total_training_steps = 0  # timesteps across calls of step()
        self._c_step = 0  # current step
        self._c_episode = 0  # current episode

        # Environment
        self._env = make_env(
            self._config["env_framework"],
            self._config["env_name"],
            n_envs=self._config["n_envs"],
            env_kwargs=self._config["env_kwargs"],
            cnn_policy=self._config["cnn_policy"],
            seed=self._seed,
        )
        self._eval_env = make_env(
            self._config["env_framework"],
            self._config["env_name"],
            n_envs=self._config["n_envs"],
            env_kwargs=self._config["eval_env_kwargs"],
            cnn_policy=self._config["cnn_policy"],
            seed=self._seed + 1,
        )

        # Checkpointing
        self._checkpoints = []
        self._track_metrics = (
            "all" in self._config["checkpoint"]
            or "grad_info" in self._config["state_features"]
            or "loss" in self._config["checkpoint"]
        )
        self._track_trajectories = (
            "all" in self._config["checkpoint"]
            or "trajectories" in self._config["checkpoint"]
        )

        # Algorithm
        self._algorithm_cls = self.ALGORITHMS[self._config["algorithm"]]
        self._config_space = self._algorithm_cls.get_hpo_config_space()

        # Instantiate algorithm with default hyperparameter configuration
        self._nas_config = self._algorithm_cls.get_default_nas_config()
        for k, v in self._config["nas_config"].items():
            self._nas_config[k] = v

        self._hpo_config = self._algorithm_cls.get_default_hpo_config()
        self._algorithm = self._make_algorithm()
        self._algorithm_state = None
        self._train_result = None

        # Optimization objectives
        self._objectives = self._get_objectives()

        # State Features
        self._state_features = self._get_state_features()
        self._observation_space = self._get_obs_space()

    def _get_objectives(self) -> list[Objective]:
        """Maps the objectives as list of strings to a sorted list of the actual objective classes.

        Returns:
            list[Objective]: List of objective classes in the correct order to be wrapped around the train function.
        """
        if len(self._config["objectives"]) == 0:
            raise ValueError("Please select at least one optimization objective.")

        objectives = []
        cfg_objectives = list(set(self._config["objectives"]))
        for o in cfg_objectives:
            if o not in OBJECTIVES:
                raise ValueError(f"Invalid objective: {o}")
            objectives += [OBJECTIVES[o]]

        # Ensure right order of objectives, e.g. runtime is wrapped first
        objectives = sorted(objectives, key=lambda o: o[1])

        # Now we are extracting the actual classes for each objective
        # They are used to wrap the train function and compute the objective
        return [o[0] for o in objectives]

    def _get_state_features(self) -> list[StateFeature]:
        """Maps the state features as list of strings to a sorted list of the actual state feature classes.

        Returns:
            list[StateFeature]: List of state feature classes in the correct order to be wrapped around the train function.
        """
        state_features = []
        cfg_state_features = list(set(self._config["state_features"]))
        for f in cfg_state_features:
            if f not in STATE_FEATURES:
                raise ValueError(f"Invalid state feature: {f}")
            state_features += [STATE_FEATURES[f]]
        return state_features

    def _get_obs_space(self) -> gymnasium.spaces.Dict:
        """Returns the state feature space as gymnasium space.

        Returns:
            gymnasium.spaces.Dict: Gymnasium space.
        """
        obs_space = {f.KEY: f.get_state_space() for f in self._state_features}
        obs_space["steps"] = gymnasium.spaces.Box(
            low=np.array([0, 0]), high=np.array([np.inf, np.inf])
        )
        return gymnasium.spaces.Dict(obs_space)

    def _step(self) -> bool:
        """Fundamental step behaviour, handles truncation.

        Returns:
            bool: Whether the episode is done.
        """
        self._c_step += 1
        if self._c_step >= self._config["n_steps"]:
            return True
        return False

    def _train(self, **train_kw_args) -> tuple[TrainReturnT, dict, dict]:
        """Performs the RL training and returns the result as well as objectives and state features.

        Returns:
            tuple[TrainReturnT, dict, dict]: Tuple of training result, objectives, and observation (state features).
        """
        assert self._algorithm_state is not None

        objectives = {}  # objective values are stored here
        obs = {}  # state features are stored here
        train_func = self._algorithm.train

        # The objectives are wrapped first since runtime should be accurate
        for o in self._objectives:
            train_func = o(train_func, objectives, self._config["optimize_objectives"])

        # Then we wrap the state features around the training function
        obs["steps"] = np.array([self._c_step, self._total_training_steps])
        for f in self._state_features:
            train_func = f(train_func, obs)

        # Track configuration + budgets using DeepCAVE (https://github.com/automl/DeepCAVE)
        if self._config.get("deep_cave", False):
            from deepcave import Objective, Recorder

            dc_objectives = [Objective(**o.get_spec()) for o in self._objectives]

            with Recorder(self._config_space, objectives=dc_objectives) as r:
                r.start(self._hpo_config, self._config["n_total_timesteps"])
                result = train_func(*self._algorithm_state, **train_kw_args)
                r.end(costs=[objectives[o.KEY] for o in self._objectives])
        else:
            result = train_func(*self._algorithm_state, **train_kw_args)

        return result, objectives, obs

    def _make_algorithm(self) -> Algorithm:
        """Instantiates the RL algorithm given the current AutoRL config and hyperparameter configuration.

        Returns:
            Algorithm: RL algorithm instance.
        """
        return self._algorithm_cls(
            self._hpo_config,
            self._env,
            nas_config=self._nas_config,
            eval_env=self._eval_env,
            track_metrics=self._track_metrics,
            track_trajectories=self._track_trajectories,
            cnn_policy=self._config["cnn_policy"],
            deterministic_eval=self._config["deterministic_eval"],
        )
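    # Illustrative sketch (an assumption, not part of the original module): the
    # objective and state-feature entries used in _train() are treated here as
    # callables that wrap the train function and write their measurement into the
    # shared dict, roughly like this hypothetical runtime objective:
    #
    #   def runtime_objective(train_func, objectives, optimize_objectives):
    #       def wrapped(*args, **kwargs):
    #           start = time.time()
    #           result = train_func(*args, **kwargs)
    #           objectives["runtime"] = time.time() - start
    #           return result
    #       return wrapped
    #
    # The sort in _get_objectives() by the rank stored in OBJECTIVES controls the
    # wrapping order, so that e.g. a runtime measurement sits closest to the raw
    # train function and stays accurate.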
    def step(
        self,
        action: Configuration | dict,
        checkpoint_path: str | None = None,
        n_total_timesteps: int | None = None,
        n_eval_steps: int | None = None,
        n_eval_episodes: int | None = None,
        seed: int | None = None,
    ) -> tuple[ObservationT, ObjectivesT, bool, bool, InfoT]:
        """Performs one iteration of RL training.

        Args:
            action (Configuration | dict): Hyperparameter configuration to use for training.
            checkpoint_path (str | None, optional): Path of a checkpoint to load the algorithm state from before training. Defaults to None.
            n_total_timesteps (int | None, optional): Number of total training steps. Defaults to None.
            n_eval_steps (int | None, optional): Number of evaluations during training. Defaults to None.
            n_eval_episodes (int | None, optional): Number of episodes to run per evaluation during training. Defaults to None.
            seed (int | None, optional): Random seed. Defaults to None. If None, the seed of the AutoRL environment is used.

        Raises:
            ValueError: Error is raised if step() is called before reset() was called.

        Returns:
            tuple[ObservationT, ObjectivesT, bool, bool, InfoT]: State information, objectives, terminated, truncated, additional information.
        """
        if len(action.keys()) == 0:  # no action provided
            warnings.warn(
                "No agent configuration provided. Falling back to default configuration."
            )

        if self._done:
            raise ValueError("Called step() before reset().")

        # Set done if max. number of steps in DAC is reached or classic (one-step) HPO is performed
        self._done = self._step()
        info = {}

        # Apply changes to current hyperparameter configuration and reinstantiate algorithm
        if isinstance(action, dict):
            action = Configuration(self.config_space, action)
        self._hpo_config = action
        seed = seed if seed else self._seed
        self._algorithm = self._make_algorithm()

        # First, we check if there is a checkpoint to load. If not, we have to check
        # whether this is the first iteration, i.e., call of env.step(). In that case,
        # we have to initialize the algorithm state.
        # Otherwise, we are using the state from previous iteration(s)
        if checkpoint_path:
            try:
                self._algorithm_state = self._load(checkpoint_path, seed)
            except Exception as e:  # noqa: BLE001
                print(e)
                init_rng = jax.random.key(seed)
                self._algorithm_state = self._algorithm.init(init_rng)
        elif self._algorithm_state is None:
            init_rng = jax.random.key(seed)
            self._algorithm_state = self._algorithm.init(init_rng)

        # Training kwargs
        train_kw_args = {
            "n_total_timesteps": n_total_timesteps
            if n_total_timesteps
            else self._config["n_total_timesteps"],
            "n_eval_steps": n_eval_steps
            if n_eval_steps
            else self._config["n_eval_steps"],
            "n_eval_episodes": n_eval_episodes
            if n_eval_episodes
            else self._config["n_eval_episodes"],
        }

        # Perform one iteration of training
        result, objectives, obs = self._train(**train_kw_args)
        self._algorithm_state, self._train_result = result

        steps = (
            np.arange(1, train_kw_args["n_eval_steps"] + 1)
            * train_kw_args["n_total_timesteps"]
            // train_kw_args["n_eval_steps"]
        )
        returns = self._train_result.eval_rewards.mean(axis=1)
        info["train_info_df"] = pd.DataFrame({"steps": steps, "returns": returns})
        self._total_training_steps += train_kw_args["n_total_timesteps"]

        # Checkpointing
        if len(self._config["checkpoint"]) > 0:
            assert self._algorithm_state is not None
            checkpoint = self._save()
            self._checkpoints += [checkpoint]
            info["checkpoint"] = checkpoint

        return obs, objectives, False, self._done, info
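    # Illustrative call pattern (a sketch, not part of the original module):
    # per-call budgets override the AutoRL config, and the returned info dict
    # carries the evaluation curve plus, if checkpointing is enabled, the written
    # checkpoint path. The hyperparameter name below is hypothetical.
    #
    #   obs, objectives, terminated, truncated, info = env.step(
    #       {"learning_rate": 3e-4},      # hypothetical hyperparameter value
    #       n_total_timesteps=50_000,     # overrides config["n_total_timesteps"]
    #   )
    #   print(objectives["reward_mean"], info["train_info_df"].tail())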
    def reset(self) -> tuple[ObservationT, InfoT]:
        """Resets the AutoRL environment and current algorithm state.

        Returns:
            tuple[ObservationT, InfoT]: Empty observation and state information.
        """
        self._done = False
        self._c_step = 0
        self._c_episode += 1
        self._algorithm_state = None
        return {}, {}
    def _save(self, tag: str | None = None) -> str:
        """Saves the current algorithm state and training result.

        Args:
            tag (str | None, optional): Checkpoint tag. Defaults to None.

        Returns:
            str: Checkpoint path.
        """
        if self._algorithm_state is None:
            warnings.warn("Agent not initialized. Not able to save agent state.")
            return ""

        if self._train_result is None:
            warnings.warn(
                "No training performed, so there is nothing to save. Please run step() first."
            )

        return Checkpointer.save(
            self._algorithm.name,
            self._algorithm_state,
            self._config,
            self._hpo_config,
            self._done,
            self._c_episode,
            self._c_step,
            self._train_result,
            tag=tag,
        )

    def _load(self, checkpoint_path: str, seed: int) -> AlgorithmState:
        """Load the algorithm state from a checkpoint.

        Args:
            checkpoint_path (str): Path of the checkpoint to load.
            seed (int): Random seed to use for algorithm initialization.
        """
        init_rng = jax.random.PRNGKey(seed)
        algorithm_state = self._algorithm.init(init_rng)
        (
            (
                _,
                self._c_step,
                self._c_episode,
            ),
            algorithm_kw_args,
        ) = Checkpointer.load(checkpoint_path, algorithm_state)
        return self._algorithm.init(init_rng, **algorithm_kw_args)

    @property
    def action_space(self) -> gymnasium.spaces.Space:
        """Returns the hyperparameter configuration space as gymnasium space.

        Returns:
            gymnasium.spaces.Space: Hyperparameter configuration space.
        """
        return config_space_to_gymnasium_space(self._config_space)

    @property
    def config_space(self) -> ConfigurationSpace:
        """Returns the hyperparameter configuration space as ConfigSpace.

        Returns:
            ConfigurationSpace: Hyperparameter configuration space.
        """
        return self._config_space

    @property
    def observation_space(self) -> gymnasium.spaces.Space:
        """Returns a gymnasium space of state features (observations).

        Returns:
            gymnasium.spaces.Space: Gymnasium space.
        """
        return self._observation_space

    @property
    def hpo_config(self) -> Configuration:
        """Returns the current hyperparameter configuration stored in the AutoRL environment.

        Returns:
            Configuration: Hyperparameter configuration.
        """
        return self._hpo_config

    @property
    def checkpoints(self) -> list[str]:
        """Returns a list of created checkpoints for this AutoRL environment.

        Returns:
            list[str]: List of checkpoint paths.
        """
        return list(self._checkpoints)

    @property
    def objectives(self) -> list[str]:
        """Returns the configured objectives.

        Returns:
            list[str]: List of objectives.
        """
        return [o.__name__ for o in self._objectives]

    @property
    def config(self) -> dict:
        """Returns the AutoRL configuration.

        Returns:
            dict: AutoRL configuration.
        """
        return self._config.copy()
    def eval(self, num_eval_episodes: int) -> np.ndarray:
        """Evaluates the algorithm using its current training state.

        Args:
            num_eval_episodes (int): Number of evaluation episodes to run.

        Returns:
            np.ndarray: Array of evaluation returns, one per episode.
        """
        if self._algorithm is None or self._algorithm_state is None:
            raise ValueError("Agent not initialized. Call reset() first.")

        rewards = self._algorithm.eval(
            self._algorithm_state.runner_state, num_eval_episodes
        )
        return np.array(rewards)
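
# A minimal end-to-end sketch (illustrative, not part of the original module):
# run a short DAC-style loop with the default DQN/CartPole-v1 configuration,
# sampling a hyperparameter configuration per step and evaluating at the end.
# Budgets are kept deliberately small here.
if __name__ == "__main__":
    env = AutoRLEnv(config={"n_steps": 2, "n_total_timesteps": 1e4})

    obs, info = env.reset()
    truncated = False
    while not truncated:
        # Sample a configuration from the algorithm's HPO search space
        action = env.config_space.sample_configuration()
        obs, objectives, terminated, truncated, info = env.step(action)
        print(f"step objectives: {objectives}")

    # Evaluate the trained agent for a few episodes
    print(env.eval(num_eval_episodes=3).mean())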