"""Wrapper for action frequency."""
from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sb
from gymnasium import Wrapper, spaces
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
# Apply seaborn's "darkgrid" theme globally so the rendering helpers below
# produce consistently styled figures.
sb.set_style("darkgrid")
# Snapshot of the active seaborn color palette (list of RGB tuples),
# available for plotting code in this module.
current_palette = list(sb.color_palette())
class ActionFrequencyWrapper(Wrapper):
    """Wrapper to track action frequency.

    Includes interval mode that returns frequencies in lists of len(interval)
    instead of one long list.
    """

    def __init__(self, env, action_interval=None, logger=None):
        """Initialize wrapper.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        action_interval : int
            If not none, mean in given intervals is tracked, too
        logger : logger.ModuleLogger
            logger to write to
        """
        super().__init__(env)
        self.action_interval = action_interval
        self.overall_actions = []
        if self.action_interval:
            self.action_intervals = []
            self.current_actions = []
        # NOTE(review): "action_space_type" is not in the attribute whitelists
        # of __setattr__/__getattribute__ below, so this write is forwarded to
        # the wrapped env and later reads come back from it — confirm intended.
        self.action_space_type = type(self.env.action_space)
        self.logger = logger

    def __setattr__(self, name, value):
        """Set attribute in wrapper if available and in env if not.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to
        """
        # Only the wrapper's own bookkeeping lives on the wrapper object;
        # everything else is delegated to the wrapped environment.
        if name in [
            "action_interval",
            "overall_actions",
            "action_intervals",
            "current_actions",
            "env",
            "get_actions",
            "step",
            "render_action_tracking",
            "logger",
        ]:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)

    def __getattribute__(self, name):
        """Get attribute value of wrapper if available and of env if not.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns
        -------
        value
            Value of given name
        """
        # Mirror of __setattr__: wrapper-owned names are resolved locally,
        # all other lookups fall through to the wrapped environment.
        if name in [
            "action_interval",
            "overall_actions",
            "action_intervals",
            "current_actions",
            "env",
            "get_actions",
            "step",
            "render_action_tracking",
            "logger",
        ]:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)

    def step(self, action):
        """Execute environment step and record the taken action.

        Parameters
        ----------
        action : int
            action to execute

        Returns
        -------
        np.array, float, bool, bool, dict
            state, reward, terminated, truncated, metainfo
        """
        state, reward, terminated, truncated, info = self.env.step(action)
        log_action = action
        if isinstance(log_action, dict):
            # Flatten dict actions to a list of values; unwrap list/array
            # values to their first element so one scalar per key is stored.
            log_action = list(log_action.values())
            for i in range(len(log_action)):
                if isinstance(log_action[i], list | np.ndarray):
                    log_action[i] = log_action[i][0]
        self.overall_actions.append(log_action)
        if self.logger is not None:
            self.logger.log_space("action", action)

        if self.action_interval:
            if len(self.current_actions) < self.action_interval:
                self.current_actions.append(log_action)
            else:
                # Current interval is full: archive it and start a new one
                # seeded with this step's action.
                self.action_intervals.append(self.current_actions)
                self.current_actions = [log_action]
        return state, reward, terminated, truncated, info

    def get_actions(self):
        """Get action progression.

        Returns
        -------
        np.array or np.array, np.array
            all actions or all actions and interval sorted actions
        """
        if self.action_interval:
            # Include the (possibly incomplete) current interval at the end.
            complete_intervals = [*self.action_intervals, self.current_actions]
            return self.overall_actions, complete_intervals
        return self.overall_actions

    def render_action_tracking(self):
        """Render action progression.

        Returns
        -------
        np.array
            RGB data of action tracking
        """

        def plot_single(ax=None, index=None, x=False, y=False):
            # Draw one action dimension; uses the pyplot state machine when
            # ax is None, an explicit Axes otherwise. x/y toggle axis labels
            # so shared subplot edges are not labelled twice.
            if ax is None:
                plt.xlabel("Step")
                plt.ylabel("Action value")
            elif x and y:
                ax.set_ylabel("Action value")
                ax.set_xlabel("Step")
            elif x:
                ax.set_xlabel("Step")
            elif y:
                ax.set_ylabel("Action value")

            if index is not None:
                ys = [state[index] for state in self.overall_actions]
            else:
                ys = self.overall_actions

            if ax is None:
                p = plt.plot(
                    np.arange(len(self.overall_actions)),
                    ys,
                    label="Step actions",
                    color="g",
                )
            else:
                p = ax.plot(
                    np.arange(len(self.overall_actions)),
                    ys,
                    label="Step actions",
                    color="g",
                )

            p2 = None
            if self.action_interval:
                if index is not None:
                    y_ints = []
                    for interval in self.action_intervals:
                        y_ints.append([state[index] for state in interval])
                else:
                    y_ints = self.action_intervals
                # Only completed intervals are averaged; the in-progress
                # interval in self.current_actions is not plotted.
                if ax is None:
                    p2 = plt.plot(
                        np.arange(len(self.action_intervals)) * self.action_interval,
                        [np.mean(interval) for interval in y_ints],
                        label="Mean interval action",
                        color="orange",
                    )
                    plt.legend(loc="upper left")
                else:
                    p2 = ax.plot(
                        np.arange(len(self.action_intervals)) * self.action_interval,
                        [np.mean(interval) for interval in y_ints],
                        label="Mean interval action",
                        color="orange",
                    )
                    ax.legend(loc="upper left")
            return p, p2

        action_size_border = 5
        if self.action_space_type == spaces.Discrete:
            figure = plt.figure(figsize=(12, 6))
            canvas = FigureCanvas(figure)
            p, p2 = plot_single()
            canvas.draw()
        elif self.action_space_type in (spaces.Dict, spaces.Tuple):
            raise NotImplementedError
        elif self.action_space_type in (
            spaces.MultiDiscrete,
            spaces.MultiBinary,
            spaces.Box,
        ):
            # Determine the number of action dimensions per space type.
            if self.action_space_type == spaces.MultiDiscrete:
                action_size = len(self.env.action_space.nvec)
            elif self.action_space_type == spaces.MultiBinary:
                action_size = self.env.action_space.n
            else:
                action_size = len(self.env.action_space.high)

            if action_size == 1:
                figure = plt.figure(figsize=(12, 6))
                canvas = FigureCanvas(figure)
                p, p2 = plot_single()
            elif action_size < action_size_border:
                dim = 1
                figure, axarr = plt.subplots(action_size)
            else:
                # NOTE(review): action_size % 4 is 0 whenever action_size is
                # a multiple of 4, which would divide by zero on the next
                # line — confirm the intended grid layout for large spaces.
                dim = action_size % 4
                figure, axarr = plt.subplots(action_size % 4, action_size // dim)
            figure.suptitle("Action over time")
            canvas = FigureCanvas(figure)
            for i in range(action_size):
                if action_size == 1:
                    # Single dimension was already drawn by plot_single above.
                    continue
                x = False
                if i % dim == dim - 1:
                    x = True
                if action_size < action_size_border:
                    p, p2 = plot_single(axarr[i], i, y=True, x=x)
                else:
                    y = i % action_size // dim == 0
                    p, p2 = plot_single(axarr[i % dim, i // dim], i, x=x, y=y)
            canvas.draw()
        width, height = figure.get_size_inches() * figure.get_dpi()
        # np.frombuffer replaces np.fromstring, which is deprecated for
        # binary data (and removed for this use in NumPy 2.x).
        return np.frombuffer(canvas.tostring_rgb(), dtype="uint8").reshape(
            int(height), int(width), 3
        )