Source code for dacbench.wrappers.observation_wrapper

"""Wrapper that converts observation spaces to spaces.Box."""

from __future__ import annotations

import numpy as np
from gymnasium import Wrapper, spaces


[docs] class ObservationWrapper(Wrapper): """Wrapper convert observations spaces to spaces.Box for convenience. Currently only supports Dict -> Box """ def __init__(self, env): """Initialize wrapper. Parameters ---------- env : gym.Env Environment to wrap compute_optimal : function Function to compute optimal policy """ super().__init__(env) obs_sample = self.flatten(self.env.observation_space.sample()) size = len(obs_sample) self.observation_space = spaces.Box( low=-np.inf * np.ones(size), high=np.inf * np.ones(size) )
[docs] def __setattr__(self, name, value): """Set attribute in wrapper if available and in env if not. Parameters ---------- name : str Attribute to set value Value to set attribute to """ if name in ["observation_space", "step", "env", "flatten", "reset"]: object.__setattr__(self, name, value) else: setattr(self.env, name, value)
[docs] def __getattribute__(self, name): """Get attribute value of wrapper if available and of env if not. Parameters ---------- name : str Attribute to get Returns: ------- value Value of given name """ if name in ["observation_space", "step", "env", "flatten", "reset"]: return object.__getattribute__(self, name) return getattr(self.env, name)
[docs] def step(self, action): """Execute environment step and record distance. Parameters ---------- action : int action to execute Returns: ------- np.array, float, bool, bool, dict state, reward, terminated, truncated, metainfo """ state, reward, terminated, truncated, info = self.env.step(action) state = self.flatten(state) return state, reward, terminated, truncated, info
[docs] def reset(self, seed=None, options=None): """Execute environment step and record distance. Returns: ------- np.array, dict state, info """ if options is None: options = {} state, info = self.env.reset(seed, options) state = self.flatten(state) return state, info
[docs] def flatten(self, state_dict): """Flatten dict to list.""" keys = sorted(state_dict.keys()) values = [] for k in keys: if isinstance(state_dict[k], np.ndarray): for s in state_dict[k]: values.append(s) else: values.append(state_dict[k]) return np.array(values).astype(np.float32)