from __future__ import annotations
import gymnasium
import pygame
from gymnasium.core import Env
from carl.context.selection import AbstractSelector
from carl.envs.carl_env import CARLEnv
from carl.utils.types import Contexts
try:
    pygame.display.init()
except Exception:  # pragma: no cover
    # Display initialisation failed (e.g. no video device available).
    # Select SDL's "dummy" video driver so that later display initialisation
    # (e.g. when an environment renders) can run without a real display.
    # Init is not retried here. Catching Exception instead of a bare
    # `except:` keeps this best-effort while letting KeyboardInterrupt /
    # SystemExit propagate during import.
    import os  # pragma: no cover

    os.environ["SDL_VIDEODRIVER"] = "dummy"  # pragma: no cover
[docs]
class CARLGymnasiumEnv(CARLEnv):
    """CARL environment wrapping a registered gymnasium environment.

    Subclasses set ``env_name`` to a registered gymnasium id. When no ``env``
    is passed to ``__init__``, the environment is created with
    ``gymnasium.make(id=env_name, render_mode=render_mode)``.
    """

    # Registered gymnasium id; expected to be defined by each subclass.
    env_name: str
    # Render mode passed to gymnasium.make when the env is built here.
    render_mode: str = "rgb_array"

    def __init__(
        self,
        env: Env | None = None,
        contexts: Contexts | None = None,
        obs_context_features: list[str]
        | None = None,  # list the context features which should be added to the state
        obs_context_as_dict: bool = True,
        context_selector: AbstractSelector | type[AbstractSelector] | None = None,
        context_selector_kwargs: dict | None = None,
        **kwargs,
    ) -> None:
        """
        CARL Gymnasium Environment.

        Parameters
        ----------
        env : Env | None
            Gymnasium environment, the default is None.
            If None, instantiate the env with gymnasium's make function using
            `self.env_name` (defined in each child class) and
            `self.render_mode`.
        contexts : Contexts | None, optional
            Context set, by default None. If it is None, we build the
            context set with the default context.
        obs_context_features : list[str] | None, optional
            Context features which should be included in the observation, by default None.
            If they are None, add all context features.
        obs_context_as_dict : bool, optional
            Forwarded unchanged to ``CARLEnv.__init__``, by default True.
            Presumably selects whether the context features in the observation
            are given as a dict -- confirm against ``CARLEnv``.
        context_selector : AbstractSelector | type[AbstractSelector] | None, optional
            The context selector (class), after each reset selects a new context to use.
            If None, use a round robin selector.
        context_selector_kwargs : dict | None, optional
            Optional keyword arguments for the context selector, by default None.
            Only used when `context_selector` is not None.
        **kwargs
            Additional keyword arguments forwarded to ``CARLEnv.__init__``.

        Attributes
        ----------
        env_name: str
            The registered gymnasium environment name.
        """
        if env is None:
            env = gymnasium.make(id=self.env_name, render_mode=self.render_mode)
        super().__init__(
            env=env,
            contexts=contexts,
            obs_context_features=obs_context_features,
            obs_context_as_dict=obs_context_as_dict,
            context_selector=context_selector,
            context_selector_kwargs=context_selector_kwargs,
            **kwargs,
        )

    def _update_context(self) -> None:
        """Apply the current context to the underlying gymnasium environment.

        Every key/value pair in ``self.context`` is written as an attribute of
        ``self.env.unwrapped``. Going through ``unwrapped`` bypasses gymnasium
        wrappers so that the base environment object receives the values.
        This presumably relies on the context keys matching the base env's
        attribute names (e.g. physics parameters). Verify for each subclass.
        """
        for k, v in self.context.items():
            setattr(self.env.unwrapped, k, v)