"""Wrapper for the state tracking."""
from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sb
from gymnasium import Wrapper, spaces
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
# Apply seaborn's "darkgrid" look at import time; seaborn applies this to
# matplotlib's global style settings, so it affects every figure in the process.
sb.set_style(style="darkgrid")
# Plain-list copy of seaborn's active color palette. It is not used by
# StateTrackingWrapper itself; presumably kept for importers of this module.
current_palette = [*sb.color_palette()]
class StateTrackingWrapper(Wrapper):
    """Wrapper to track state changes over time.

    Includes interval mode that returns states in lists of len(interval)
    instead of one long list.

    Attribute access is split between the wrapper and the wrapped env:
    names listed in ``_WRAPPER_ATTRIBUTES`` live on the wrapper, every other
    name is read from and written to ``self.env``.
    """

    # Attributes stored on the wrapper itself. __setattr__ and
    # __getattribute__ share this one set so they cannot drift apart; any
    # name missing here is silently forwarded to (and stored on) the env.
    _WRAPPER_ATTRIBUTES = frozenset(
        {
            "state_interval",
            "overall_states",
            "state_intervals",
            "current_states",
            "state_type",
            "state_description",
            "env",
            "episode_states",
            "get_states",
            "step",
            "reset",
            "render_state_tracking",
            "logger",
            "_record_state",
        }
    )

    def __init__(self, env, state_interval=None, logger=None):
        """Initialize wrapper.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        state_interval : int
            If not None (and non-zero), states are additionally grouped into
            lists of this length so per-interval means can be plotted
        logger : dacbench.logger.ModuleLogger
            Logger that every state returned by step() is written to
        """
        super().__init__(env)
        # NOTE(review): gymnasium's Wrapper.__init__ assigns private slots
        # (e.g. ``_observation_space``); they are not whitelisted, so they
        # are forwarded to ``env`` and may overwrite values held by a
        # gymnasium wrapper passed as ``env`` — confirm for the installed
        # gymnasium version.
        self.state_interval = state_interval
        self.overall_states = []
        if self.state_interval:
            self.state_intervals = []
            self.current_states = []
        self.episode_states = None
        self.state_type = type(env.observation_space)
        self.logger = logger
        # Always defined on the wrapper so step() never has to fall back to
        # the wrapped env (e.g. when a logger is attached after __init__).
        self.state_description = None
        if self.logger is not None:
            benchmark_info = getattr(env, "benchmark_info", None)
            if benchmark_info is not None:
                self.state_description = benchmark_info.get(
                    "state_description", None
                )

    def __setattr__(self, name, value):
        """Set attribute in wrapper if available and in env if not.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to
        """
        # type(self) avoids routing this lookup through __getattribute__.
        if name in type(self)._WRAPPER_ATTRIBUTES:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)

    def __getattribute__(self, name):
        """Get attribute value of wrapper if available and of env if not.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns:
        -------
        value
            Value of given name
        """
        if name in type(self)._WRAPPER_ATTRIBUTES:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)

    def _record_state(self, state):
        """Append ``state`` to the full history and to the current interval.

        When the current interval already holds ``state_interval`` states it
        is archived in ``state_intervals`` and a new interval is started with
        ``state``.
        """
        self.overall_states.append(state)
        if self.state_interval:
            if len(self.current_states) < self.state_interval:
                self.current_states.append(state)
            else:
                self.state_intervals.append(self.current_states)
                self.current_states = [state]

    def reset(self, seed=None, options=None):
        """Reset environment and record starting state.

        Parameters
        ----------
        seed : int, optional
            Seed forwarded to the wrapped environment's reset
        options : dict, optional
            Options forwarded to the wrapped environment's reset; None is
            replaced by an empty dict

        Returns:
        -------
        np.array, dict
            state, info
        """
        if options is None:
            options = {}
        state, info = self.env.reset(seed=seed, options=options)
        self._record_state(state)
        return state, info

    def step(self, action):
        """Execute environment step and record state.

        Parameters
        ----------
        action : int
            action to execute

        Returns:
        -------
        np.array, float, bool, bool, dict
            state, reward, terminated, truncated, info
        """
        state, reward, terminated, truncated, info = self.env.step(action)
        self._record_state(state)
        if self.logger is not None:
            self.logger.log_space("state", state, self.state_description)
        return state, reward, terminated, truncated, info

    def get_states(self):
        """Get state progression.

        Returns:
        -------
        list or (list, list)
            all states, or all states and the interval-grouped states (the
            last group is the current, possibly incomplete, interval)
        """
        if self.state_interval:
            complete_intervals = [*self.state_intervals, self.current_states]
            return self.overall_states, complete_intervals
        return self.overall_states

    def render_state_tracking(self):
        """Render state progression.

        Discrete spaces are drawn as one curve. MultiDiscrete, MultiBinary and
        Box spaces get one plot per state dimension: a single plot for one
        dimension, a vertical stack for fewer than five, and a column-major
        grid with four rows otherwise. With ``state_interval`` set, the mean
        of every completed interval is overlaid.

        Returns:
        -------
        np.array
            RGB data of state tracking, uint8 array of shape
            (height, width, 3)

        Raises:
        ------
        NotImplementedError
            For Dict and Tuple observation spaces
        ValueError
            For any other unsupported observation space type
        """

        def plot_single(ax, index=None, x=False, y=False):
            # Plot dimension ``index`` (or the whole state if None) on ax.
            if x:
                ax.set_xlabel("Episode")
            if y:
                ax.set_ylabel("State")
            if index is None:
                ys = self.overall_states
            else:
                ys = [state[index] for state in self.overall_states]
            ax.plot(
                np.arange(len(self.overall_states)),
                ys,
                label="Episode state",
                color="g",
            )
            if self.state_interval:
                if index is None:
                    y_ints = self.state_intervals
                else:
                    y_ints = [
                        [state[index] for state in interval]
                        for interval in self.state_intervals
                    ]
                # Interval k covers states [k*I, (k+1)*I), so its mean is
                # drawn at the step where that interval starts.
                ax.plot(
                    np.arange(len(self.state_intervals)) * self.state_interval,
                    [np.mean(interval) for interval in y_ints],
                    label="Mean interval state",
                    color="orange",
                )
            ax.legend(loc="upper left")

        state_length_border = 5
        max_grid_rows = 4

        # state_length None means "one scalar curve" (Discrete spaces).
        if issubclass(self.state_type, spaces.Discrete):
            state_length = None
        elif issubclass(self.state_type, (spaces.Dict, spaces.Tuple)):
            raise NotImplementedError
        elif issubclass(self.state_type, spaces.MultiDiscrete):
            state_length = len(self.env.observation_space.nvec)
        elif issubclass(self.state_type, spaces.MultiBinary):
            state_length = self.env.observation_space.n
        elif issubclass(self.state_type, spaces.Box):
            state_length = len(self.env.observation_space.high)
        else:
            raise ValueError("Unknown state type")

        if state_length is None or state_length == 1:
            figure = plt.figure(figsize=(20, 20))
        elif state_length < state_length_border:
            figure, axarr = plt.subplots(state_length)
        else:
            n_rows = max_grid_rows
            n_cols = (state_length + n_rows - 1) // n_rows  # ceil division
            # squeeze=False keeps axarr 2-D even when there is one column.
            figure, axarr = plt.subplots(n_rows, n_cols, squeeze=False)

        try:
            canvas = FigureCanvas(figure)
            if state_length is not None:
                figure.suptitle("State over time")
            if state_length is None or state_length == 1:
                plot_single(figure.gca(), x=True, y=True)
            elif state_length < state_length_border:
                for i in range(state_length):
                    plot_single(axarr[i], i, x=True, y=True)
            else:
                for i in range(n_rows * n_cols):
                    # Grid is filled column by column.
                    ax = axarr[i % n_rows, i // n_rows]
                    if i >= state_length:
                        ax.set_visible(False)
                        continue
                    x = i % n_rows == n_rows - 1 or i == state_length - 1
                    y = i // n_rows == 0
                    plot_single(ax, i, x=x, y=y)
            canvas.draw()
            # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib
            # 3.8, removed in 3.10). Drop alpha and copy so the result does
            # not alias the renderer buffer released by plt.close below.
            rgba = np.asarray(canvas.buffer_rgba())
            return rgba[:, :, :3].copy()
        finally:
            # Figures were created through pyplot, which keeps them alive
            # until explicitly closed.
            plt.close(figure)