# Source code for dacbench.wrappers.state_tracking_wrapper

"""Wrapper for the state tracking."""

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sb
from gymnasium import Wrapper, spaces
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# Module import side effect: switch seaborn to its "darkgrid" theme, then
# keep seaborn's default colour palette around as a plain Python list.
sb.set_style("darkgrid")
current_palette = [*sb.color_palette()]


# Attribute names stored on the wrapper itself. Reads and writes of any other
# name are forwarded to the wrapped environment by ``__getattribute__`` and
# ``__setattr__``; one shared set keeps the two methods in sync.
_WRAPPER_ATTRIBUTES = frozenset(
    {
        "state_interval",
        "overall_states",
        "state_intervals",
        "current_states",
        "state_type",
        "state_description",
        "env",
        "episode_states",
        "get_states",
        "step",
        "reset",
        "render_state_tracking",
        "logger",
        "_track_state",
    }
)


class StateTrackingWrapper(Wrapper):
    """Wrapper to track state changes over time.

    Includes an interval mode that additionally groups the states into lists
    of length ``state_interval`` instead of only keeping one long list.
    """

    def __init__(self, env, state_interval=None, logger=None):
        """Initialize wrapper.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        state_interval : int
            If not none, mean in given intervals is tracked, too
        logger : dacbench.logger.ModuleLogger
            logger to write to

        """
        super().__init__(env)
        self.state_interval = state_interval
        self.overall_states = []
        if self.state_interval:
            self.state_intervals = []
            self.current_states = []
        self.episode_states = None
        self.state_type = type(env.observation_space)
        self.logger = logger
        # Stored on the wrapper (it is whitelisted) instead of leaking onto
        # the wrapped env; defaults to None when no logger is attached.
        self.state_description = None
        if self.logger is not None:
            benchmark_info = getattr(env, "benchmark_info", None)
            self.state_description = (
                benchmark_info.get("state_description", None)
                if benchmark_info is not None
                else None
            )

    def __setattr__(self, name, value):
        """Set attribute in wrapper if available and in env if not.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to

        """
        if name in _WRAPPER_ATTRIBUTES:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)

    def __getattribute__(self, name):
        """Get attribute value of wrapper if available and of env if not.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns
        -------
        value
            Value of given name

        """
        # Uses the module-level set, not an attribute, to avoid recursion.
        if name in _WRAPPER_ATTRIBUTES:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)

    def _track_state(self, state):
        """Append ``state`` to the full history and, in interval mode, to the current interval."""
        self.overall_states.append(state)
        if self.state_interval:
            if len(self.current_states) < self.state_interval:
                self.current_states.append(state)
            else:
                # Current interval is full: archive it and start a new one.
                self.state_intervals.append(self.current_states)
                self.current_states = [state]

    def reset(self, seed=None, options=None):
        """Reset environment and record starting state.

        Parameters
        ----------
        seed : int, optional
            Seed forwarded to the wrapped environment's ``reset``
        options : dict, optional
            Options forwarded to the wrapped environment's ``reset``;
            an empty dict is passed when None

        Returns
        -------
        np.array, {}
            state, info

        """
        if options is None:
            options = {}
        state, info = self.env.reset(seed=seed, options=options)
        self._track_state(state)
        return state, info

    def step(self, action):
        """Execute environment step and record state.

        Parameters
        ----------
        action : int
            action to execute

        Returns
        -------
        np.array, float, bool, bool, dict
            state, reward, terminated, truncated, metainfo

        """
        state, reward, terminated, truncated, info = self.env.step(action)
        self._track_state(state)
        if self.logger is not None:
            self.logger.log_space("state", state, self.state_description)
        return state, reward, terminated, truncated, info

    def get_states(self):
        """Get state progression.

        Returns
        -------
        list or (list, list)
            All recorded states; in interval mode additionally the list of
            state intervals, including the current (possibly partial) one.

        """
        if self.state_interval:
            complete_intervals = [*self.state_intervals, self.current_states]
            return self.overall_states, complete_intervals
        return self.overall_states

    def render_state_tracking(self):
        """Render state progression.

        Returns
        -------
        np.array
            RGB image of the state tracking plot, shape (height, width, 3),
            dtype uint8

        Raises
        ------
        NotImplementedError
            For Dict and Tuple observation spaces
        ValueError
            For any other unsupported observation space type

        """

        def plot_single(ax, index=None, x=False, y=False):
            # Plot the state history (or its ``index``-th component) on ``ax``
            # plus, in interval mode, the mean of every completed interval.
            if x:
                ax.set_xlabel("Episode")
            if y:
                ax.set_ylabel("State")
            if index is not None:
                ys = [state[index] for state in self.overall_states]
            else:
                ys = self.overall_states
            ax.plot(
                np.arange(len(self.overall_states)),
                ys,
                label="Episode state",
                color="g",
            )
            if self.state_interval:
                if index is not None:
                    y_ints = [
                        [state[index] for state in interval]
                        for interval in self.state_intervals
                    ]
                else:
                    y_ints = self.state_intervals
                ax.plot(
                    np.arange(len(self.state_intervals)) * self.state_interval,
                    [np.mean(interval) for interval in y_ints],
                    label="Mean interval state",
                    color="orange",
                )
            ax.legend(loc="upper left")

        # Below this many components the subplots form a single column;
        # otherwise they are laid out column-major in a grid of max_rows rows.
        state_length_border = 5
        max_rows = 4

        figure = None
        try:
            if issubclass(self.state_type, spaces.Discrete):
                figure, ax = plt.subplots(figsize=(20, 20))
                plot_single(ax, x=True, y=True)
            elif issubclass(self.state_type, (spaces.Dict, spaces.Tuple)):
                raise NotImplementedError
            elif issubclass(
                self.state_type, (spaces.MultiDiscrete, spaces.MultiBinary, spaces.Box)
            ):
                observation_space = self.env.observation_space
                if issubclass(self.state_type, spaces.MultiDiscrete):
                    state_length = len(observation_space.nvec)
                elif issubclass(self.state_type, spaces.MultiBinary):
                    state_length = observation_space.n
                else:
                    state_length = len(observation_space.high)

                if state_length == 1:
                    figure, ax = plt.subplots(figsize=(20, 20))
                    plot_single(ax, x=True, y=True)
                else:
                    n_rows = (
                        state_length
                        if state_length < state_length_border
                        else max_rows
                    )
                    n_cols = (state_length + n_rows - 1) // n_rows
                    # squeeze=False always yields a 2-D axes array, so the
                    # [row, col] indexing below works for every layout.
                    figure, axarr = plt.subplots(n_rows, n_cols, squeeze=False)
                    for i in range(state_length):
                        row, col = i % n_rows, i // n_rows
                        # Label the x axis only on the lowest plot of a column.
                        x = row == n_rows - 1 or i == state_length - 1
                        plot_single(axarr[row, col], i, x=x, y=col == 0)
                    # Hide grid cells that have no state component to show.
                    for j in range(state_length, n_rows * n_cols):
                        axarr[j % n_rows, j // n_rows].set_visible(False)
                figure.suptitle("State over time")
            else:
                raise ValueError("Unknown state type")

            # Render with Agg regardless of the active pyplot backend.
            canvas = FigureCanvas(figure)
            canvas.draw()
            # buffer_rgba is (height, width, 4); drop alpha and copy so the
            # result is a contiguous array independent of the canvas.
            return np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            if figure is not None:
                plt.close(figure)