# Source code for dacbench.wrappers.performance_tracking_wrapper
"""Wrapper for performance tracking."""
from __future__ import annotations
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sb
from gymnasium import Wrapper
# Select seaborn's "darkgrid" style once, when this module is imported.
sb.set_style("darkgrid")
# Keep the active seaborn color palette around as a plain Python list.
current_palette = [*sb.color_palette()]
# [docs] (documentation link marker left over from the HTML source listing)
class PerformanceTrackingWrapper(Wrapper):
    """Wrapper to track episode performance.

    The reward of every step is accumulated per episode; when an episode
    ends, the sum is stored in ``overall_performance``.

    Includes an interval mode that additionally groups the episode
    performances into lists of ``performance_interval`` entries instead of
    one long list, and an instance mode that groups them per environment
    instance.
    """

    # Names that live on the wrapper itself. ``__setattr__`` and
    # ``__getattribute__`` forward every other name to the wrapped
    # environment, so any new wrapper-level attribute or method has to be
    # added here to be reachable through ``self``.
    _LOCAL_ATTRIBUTES = frozenset(
        {
            "performance_interval",
            "track_instances",
            "overall_performance",
            "performance_intervals",
            "current_performance",
            "env",
            "get_performance",
            "step",
            "instance_performances",
            "episode_performance",
            "render_performance",
            "render_instance_performance",
            "logger",
        }
    )

    def __init__(
        self,
        env,
        performance_interval=None,
        track_instance_performance=True,
        logger=None,
    ):
        """Initialize wrapper.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        performance_interval : int or None
            If set (and non-zero), episode performances are additionally
            grouped into lists of this many episodes
        track_instance_performance : bool
            Indicates whether to track per-instance performance
        logger : dacbench.logger.ModuleLogger
            logger to write to; nothing is logged if None
        """
        super().__init__(env)
        self.performance_interval = performance_interval
        self.overall_performance = []
        self.episode_performance = 0
        # The interval containers only exist when interval mode is active.
        if self.performance_interval:
            self.performance_intervals = []
            self.current_performance = []
        self.track_instances = track_instance_performance
        # Likewise, the per-instance mapping only exists when requested.
        if self.track_instances:
            self.instance_performances = defaultdict(list)
        self.logger = logger

    def __setattr__(self, name, value):
        """Set attribute in wrapper if available and in env if not.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to
        """
        # ``type(self)`` is used because instance attribute access is
        # redirected to the environment by ``__getattribute__``.
        if name in type(self)._LOCAL_ATTRIBUTES:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)

    def __getattribute__(self, name):
        """Get attribute value of wrapper if available and of env if not.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns
        -------
        value
            Value of given name
        """
        if name in type(self)._LOCAL_ATTRIBUTES:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)

    def step(self, action):
        """Execute environment step and record performance.

        Parameters
        ----------
        action : int
            action to execute

        Returns
        -------
        state, reward, terminated, truncated, info
            The five values returned by the wrapped environment's ``step``,
            passed through unchanged
        """
        state, reward, terminated, truncated, info = self.env.step(action)
        self.episode_performance += reward

        if terminated or truncated:
            episode_total = self.episode_performance
            self.overall_performance.append(episode_total)

            if self.logger is not None:
                self.logger.log("overall_performance", episode_total)

            if self.performance_interval:
                if len(self.current_performance) < self.performance_interval:
                    self.current_performance.append(episode_total)
                else:
                    # Current interval is full: archive it and open a new
                    # one that starts with this episode.
                    self.performance_intervals.append(self.current_performance)
                    self.current_performance = [episode_total]

            if self.track_instances:
                # The instance is identified by the concatenated string
                # forms of its attribute values.
                key = "".join(str(e) for e in self.env.instance.__dict__.values())
                self.instance_performances[key].append(episode_total)

            # Start accumulating from zero for the next episode.
            self.episode_performance = 0

        return state, reward, terminated, truncated, info

    def get_performance(self):
        """Get the tracked performance.

        Returns
        -------
        list or tuple
            Only the list of per-episode performances if neither interval
            nor instance tracking is active. Otherwise a tuple that starts
            with that list, followed by the list of intervals (finished
            intervals plus the one currently being filled) if interval mode
            is active, followed by the per-instance mapping if instance
            tracking is active.
        """
        tracked = [self.overall_performance]
        if self.performance_interval:
            tracked.append([*self.performance_intervals, self.current_performance])
        if self.track_instances:
            tracked.append(self.instance_performances)
        # A lone result is returned bare rather than as a 1-tuple.
        if len(tracked) == 1:
            return tracked[0]
        return tuple(tracked)

    def render_performance(self):
        """Plot performance."""
        plt.figure(figsize=(12, 6))
        # NOTE(review): only every second entry (indices 1, 3, 5, ...) is
        # plotted; confirm that skipping the others is intentional.
        plt.plot(
            np.arange(len(self.overall_performance) // 2),
            self.overall_performance[1::2],
        )
        plt.title("Mean Performance per episode")
        plt.xlabel("Episode")
        plt.ylabel("Reward")
        plt.show()

    def render_instance_performance(self):
        """Plot mean performance for each instance."""
        plt.figure(figsize=(12, 6))
        plt.title("Mean Performance per Instance")
        plt.ylabel("Mean reward")
        plt.xlabel("Instance")
        ax = plt.subplot(111)
        # Bars are labelled with a running index instead of the (long)
        # instance key.
        for index, performances in enumerate(self.instance_performances.values()):
            ax.bar(str(index), np.mean(performances))
        plt.show()