Source code for carl.envs.carl_env

from __future__ import annotations

import abc
from typing import Any, SupportsFloat, TypeVar

import inspect

import gymnasium
from gymnasium import Wrapper, spaces
from gymnasium.core import Env

from carl.context.context_space import ContextFeature, ContextSpace
from carl.context.selection import AbstractSelector, RoundRobinSelector
from carl.utils.types import Context, Contexts

ObsType = TypeVar("ObsType")


[docs] class CARLEnv(Wrapper, abc.ABC): def __init__( self, env: Env, contexts: Contexts | None = None, obs_context_features: list[str] | None = None, obs_context_as_dict: bool = True, context_selector: AbstractSelector | type[AbstractSelector] | None = None, context_selector_kwargs: dict | None = None, **kwargs, ): """Base CARL wrapper. Good to know: - The observation always is a dictionary of {"state": ..., "context": ...}. Use an observation flattening wrapper if you need a different format. - After each env reset, a new context is selected by the context selector. - The context set is always filled with defaults if missing. Parameters ---------- env : Env Environment adhering to gymnasium API. contexts : Contexts, optional The context set, by default None. obs_context_features : list[str], optional The context features which should be added to the state, by default None. If None, add all available context features. obs_context_as_dict : bool, optional Whether to pass the context as a vector or a dict in the observations. The default is True. context_selector : AbstractSelector | type[AbstractSelector] | None The context selector selecting a new context after each env reset, by default None. If None, use a round robin selector. Can be an object or class. For the latter, you can pass kwargs. context_selector_kwargs : dict, optional Keyword arguments for the context selector if it is passed as a class. Attributes ---------- base_observation_space: gymnasium.spaces.Space The observation space from the wrapped environment. obs_context_as_dict: bool, optional Whether to pass the context as a vector or a dict in the observations. The default is True. observation_space: gymnasium.spaces.Dict The observation space of the CARL environment which is a dictionary of "state" and "context". contexts: Contexts The context set. context_selector: ContextSelector. The context selector selecting a new context after each env reset. """ super().__init__(env) self.base_observation_space: gymnasium.spaces.Space = env.observation_space self.obs_context_as_dict = obs_context_as_dict if contexts is None: contexts = { 0: self.get_default_context() } # was self.get_default_context(self) before self.contexts = contexts self.context: Context | None = None # Set by `_progress_instance` if obs_context_features is None: obs_context_features = list(list(self.contexts.values())[0].keys()) self.obs_context_features = obs_context_features # Context Selector self.context_selector: type[AbstractSelector] if context_selector is None: self.context_selector = RoundRobinSelector(contexts=contexts) # type: ignore [assignment] elif isinstance(context_selector, AbstractSelector): self.context_selector = context_selector # type: ignore [assignment] elif inspect.isclass(context_selector) and issubclass( context_selector, AbstractSelector ): if context_selector_kwargs is None: context_selector_kwargs = {} _context_selector_kwargs = {"contexts": contexts} context_selector_kwargs.update(_context_selector_kwargs) self.context_selector = context_selector(**context_selector_kwargs) # type: ignore [assignment] else: raise ValueError( f"Context selector must be None or an AbstractSelector class or instance. " f"Got type {type(context_selector)}." ) self.observation_space: gymnasium.spaces.Dict = self.get_observation_space( obs_context_feature_names=self.obs_context_features ) @property def contexts(self) -> Contexts: return self._contexts @property def context_id(self): return self.context_selector.context_id @contexts.setter def contexts(self, contexts: Contexts) -> None: """Set `contexts` property For each context maybe fill with default context values. This is only necessary whenever we update the contexts, so here is the right place. Parameters ---------- contexts : Contexts Contexts to set """ context_space = self.get_context_space() contexts = {k: context_space.insert_defaults(v) for k, v in contexts.items()} self._contexts = contexts @context_id.setter def context_id(self, new_id) -> None: """Set `context_id` property This will switch the context ID of the context selector. Realistically you'll want to only use this if your selector does not automaticall progress the instances. Parameters ---------- new_id : ID to set the context to """ assert ( new_id in self.context_selector.context_ids ), "Unknown ID, this context does not exist in the context set." self.context_selector.context_id = new_id self.context_selector.context = self.context_selector.contexts[new_id] self.context = self.context_selector.context self._update_context()
[docs] def get_observation_space( self, obs_context_feature_names: list[str] | None = None ) -> gymnasium.spaces.Dict: """Get the observation space for the context. Parameters ---------- obs_context_feature_names : list[str] | None, optional Name of the context features to be included in the observation, by default None. If it is None, we add all context features. Returns ------- gymnasium.spaces.Dict Gymnasium observation space which contains the observation space of the underlying environment ("state") and for the context ("context"). """ context_space = self.get_context_space() obs_space_context = context_space.to_gymnasium_space( context_feature_names=obs_context_feature_names, as_dict=self.obs_context_as_dict, ) obs_space = spaces.Dict( { "obs": self.base_observation_space, "context": obs_space_context, } ) return obs_space
[docs] @staticmethod @abc.abstractmethod def get_context_features() -> dict[str, ContextFeature]: """Get the context features Defined per environment. Returns ------- dict[str, ContextFeature] Context feature definitions """ ...
[docs] @classmethod def get_context_space(cls) -> ContextSpace: """Get context space Returns ------- ContextSpace Context space with utility methods holding information about defaults, types, bounds, etc. """ return ContextSpace(cls.get_context_features())
[docs] @classmethod def get_default_context(cls) -> Context: """Get the default context Returns ------- Context Default context. """ default_context = cls.get_context_space().get_default_context() return default_context
def _progress_instance(self) -> None: """ Progress instance. In this case instance is a specific context. 1. Select instance with the instance_mode. If the instance_mode is random, randomly select the next instance from the set of contexts. If instance_mode is rr or roundrobin, select the next instance. Returns ------- None """ context = self.context_selector.select() # type: ignore [call-arg] self.context = context
[docs] def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None ) -> tuple[Any, dict[str, Any]]: """Reset the environment. First, we progress the instance, i.e. select a new context with the context selector. Then we update the context in the wrapped environment. Finally, we reset the underlying environment and add context information to the observation. Parameters ---------- seed : int | None, optional Seed, by default None options : dict[str, Any] | None, optional Options, by default None Returns ------- tuple[Any, dict[str, Any]] Observation, info. """ last_context_id = self.context_id self._progress_instance() if self.context_id != last_context_id: self._update_context() state, info = super().reset(seed=seed, options=options) state = self._add_context_to_state(state) info["context_id"] = self.context_id return state, info
def _add_context_to_state(self, state: Any) -> dict[str, Any]: """Add context observation to the state The state is the observation from the underlying environment and we add the context information to it. We return a dictionary of the state and context, and the context is maybe represented as a dictionary itself (controlled via `self.obs_context_as_dict`). Parameters ---------- state : Any State from the environment Returns ------- dict[str, Any] State context observation dict """ if not self.obs_context_as_dict: context = [self.context[k] for k in self.obs_context_features] else: context = { k: v for k, v in self.context.items() if k in self.obs_context_features } state_context_dict = { "obs": state, "context": context, } return state_context_dict @abc.abstractmethod def _update_context(self) -> None: """ Update the context feature values of the environment. `self._progress_instance` must be called at least once to se(lec)t a valid context. Returns ------- None """ ...
[docs] def step( self, action: Any ) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]: """Step the environment. The context is added to the observation returned by the wrapped environment. Parameters ---------- action : Any Action Returns ------- tuple[Any, SupportsFloat, bool, bool, dict[str, Any]] Observation, rewar, terminated, truncated, info. """ state, reward, terminated, truncated, info = super().step(action) state = self._add_context_to_state(state) info["context_id"] = self.context_id return state, reward, terminated, truncated, info