Source code for dacbench.wrappers.performance_tracking_wrapper
"""Wrapper for performance tracking."""
from __future__ import annotations
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sb
from gymnasium import Wrapper
# Select seaborn's "darkgrid" style. This is a module-level statement, so it
# runs once as a side effect of importing this module.
sb.set_style("darkgrid")
# Plain-list copy of the palette seaborn reports at import time.
# NOTE(review): not referenced by the wrapper code in this module - confirm
# whether anything else imports it before removing.
current_palette = list(sb.color_palette())
[docs]
class PerformanceTrackingWrapper(Wrapper):
    """Wrapper to track episode performance.

    The reward of every step is summed up per episode; when an episode ends
    (terminated or truncated) the sum is stored. Optionally the episode
    results are additionally grouped into lists of ``performance_interval``
    episodes and/or grouped per environment instance.
    """

    # Names that are stored on / looked up from the wrapper object itself.
    # Every other attribute access is forwarded to the wrapped environment.
    # Kept in one place so that __setattr__ and __getattribute__ cannot drift
    # apart.
    _WRAPPER_ATTRIBUTES = frozenset(
        {
            "performance_interval",
            "track_instances",
            "overall_performance",
            "performance_intervals",
            "current_performance",
            "env",
            "get_performance",
            "step",
            "instance_performances",
            "episode_performance",
            "render_performance",
            "render_instance_performance",
            "logger",
        }
    )

    def __init__(
        self,
        env,
        performance_interval: int | None = None,
        track_instance_performance: bool = True,
        logger=None,
    ):
        """Initialize wrapper.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        performance_interval : int
            If not None (and not 0), episode performances are additionally
            collected in lists of this many episodes
        track_instance_performance : bool
            Indicates whether to track per-instance performance
        logger : dacbench.logger.ModuleLogger
            logger to write to
        """
        super().__init__(env)
        self.performance_interval = performance_interval
        self.overall_performance = []
        self.episode_performance = 0
        # The containers are created unconditionally so that switching a
        # tracking mode on after construction, or calling a render method
        # while a mode is disabled, does not fail with AttributeError.
        self.performance_intervals = []
        self.current_performance = []
        self.track_instances = track_instance_performance
        self.instance_performances = defaultdict(list)
        self.logger = logger

    def __setattr__(self, name, value):
        """Set attribute in wrapper if available and in env if not.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to
        """
        # The class is referenced by name on purpose: ``self.<anything>``
        # would be routed through the overridden __getattribute__.
        if name in PerformanceTrackingWrapper._WRAPPER_ATTRIBUTES:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)

    def __getattribute__(self, name):
        """Get attribute value of wrapper if available and of env if not.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns
        -------
        value
            Value of given name
        """
        if name in PerformanceTrackingWrapper._WRAPPER_ATTRIBUTES:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)

    def step(self, action):
        """Execute environment step and record performance.

        Parameters
        ----------
        action : int
            action to execute

        Returns
        -------
        tuple
            state, reward, terminated, truncated, info - passed through
            unchanged from the wrapped environment's ``step``
        """
        state, reward, terminated, truncated, info = self.env.step(action)
        self.episode_performance += reward

        if terminated or truncated:
            episode_total = self.episode_performance
            self.overall_performance.append(episode_total)

            if self.logger is not None:
                self.logger.log(
                    "overall_performance",
                    episode_total,
                )

            if self.performance_interval:
                # Close the running interval once it holds
                # ``performance_interval`` episodes, then start a new one.
                if len(self.current_performance) >= self.performance_interval:
                    self.performance_intervals.append(self.current_performance)
                    self.current_performance = []
                self.current_performance.append(episode_total)

            if self.track_instances:
                # The instance is identified by the concatenated string form
                # of its attribute values.
                key = "".join(str(e) for e in self.env.instance.__dict__.values())
                self.instance_performances[key].append(episode_total)

            # Start accumulating the next episode from zero.
            self.episode_performance = 0

        return state, reward, terminated, truncated, info

    def get_performance(self):
        """Get the recorded performance.

        Returns
        -------
        list or tuple
            ``overall_performance`` alone if neither interval nor instance
            tracking is active. Otherwise a tuple that starts with
            ``overall_performance``, followed by the list of intervals
            (finished intervals plus the currently running one) if
            ``performance_interval`` is set, followed by the per-instance
            dict if instance tracking is active.
        """
        if self.performance_interval:
            complete_intervals = [*self.performance_intervals, self.current_performance]
            if self.track_instances:
                return (
                    self.overall_performance,
                    complete_intervals,
                    self.instance_performances,
                )
            return self.overall_performance, complete_intervals
        if self.track_instances:
            return self.overall_performance, self.instance_performances
        return self.overall_performance

    def render_performance(self):
        """Plot performance."""
        plt.figure(figsize=(12, 6))
        # NOTE(review): only every second recorded episode (indices 1, 3, ...)
        # is plotted, against a 0..n//2 axis. Kept as in the original -
        # confirm whether skipping the even-indexed episodes is intended.
        plt.plot(
            np.arange(len(self.overall_performance) // 2),
            self.overall_performance[1::2],
        )
        plt.title("Mean Performance per episode")
        plt.xlabel("Episode")
        plt.ylabel("Reward")
        plt.show()

    def render_instance_performance(self):
        """Plot mean performance for each instance."""
        plt.figure(figsize=(12, 6))
        plt.title("Mean Performance per Instance")
        plt.ylabel("Mean reward")
        plt.xlabel("Instance")
        ax = plt.subplot(111)
        # Bars are labelled by running index rather than by the (long)
        # instance key.
        for index, rewards in enumerate(self.instance_performances.values()):
            ax.bar(str(index), np.mean(rewards))
        plt.show()