"""Wrapper to track time."""
from __future__ import annotations
import time
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sb
from gymnasium import Wrapper
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
# Select seaborn's "darkgrid" style; this runs as a side effect of importing the module.
sb.set_style("darkgrid")
# Materialize seaborn's current color palette as a list.
# NOTE(review): not referenced anywhere in the visible part of this module.
current_palette = list(sb.color_palette())
class EpisodeTimeWrapper(Wrapper):
    """Wrapper to track time spent per step and per episode.

    Every call to :meth:`step` is timed. The durations are stored as a flat
    list of all steps and as one list of step durations per finished episode.

    Includes an interval mode (``time_interval`` set) that additionally groups
    the step durations and the finished episodes into lists of
    ``time_interval`` entries instead of one long list.

    Attribute access is split: the names in ``_WRAPPER_ATTRIBUTES`` live on the
    wrapper itself, every other attribute is read from / written to the
    wrapped environment.
    """

    # Names that are stored on / resolved from the wrapper. Anything else is
    # forwarded to ``self.env`` by ``__setattr__`` and ``__getattribute__``.
    # Helpers that are not listed here must be reached through the class
    # (``EpisodeTimeWrapper.<name>``), never through ``self``.
    _WRAPPER_ATTRIBUTES = frozenset(
        {
            "time_interval",
            "overall_times",
            "time_intervals",
            "current_times",
            "env",
            "get_times",
            "step",
            "render_step_time",
            "render_episode_time",
            "reset",
            "episode",
            "all_steps",
            "current_step_interval",
            "step_intervals",
            "logger",
        }
    )

    def __init__(self, env, time_interval=None, logger=None):
        """Initialize wrapper.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        time_interval : int
            If not none (and not 0), times are additionally grouped into
            intervals of this many steps / episodes
        logger : dacbench.logger.ModuleLogger
            logger to write to; only its ``log(key, value)`` method is used
        """
        super().__init__(env)
        self.time_interval = time_interval
        self.logger = logger

        # Durations of every step since construction, in call order.
        self.all_steps = []
        # One list of step durations per finished episode.
        self.overall_times = []
        # Step durations of the episode that is currently running.
        self.episode = []

        if self.time_interval:
            # Completed / in-progress groups of ``time_interval`` step durations.
            self.step_intervals = []
            self.current_step_interval = []
            # Completed / in-progress groups of ``time_interval`` episodes.
            self.time_intervals = []
            self.current_times = []

    def __setattr__(self, name, value):
        """Set attribute in wrapper if available and in env if not.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to
        """
        if name in EpisodeTimeWrapper._WRAPPER_ATTRIBUTES:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)

    def __getattribute__(self, name):
        """Get attribute value of wrapper if available and of env if not.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns
        -------
        value
            Value of given name
        """
        if name in EpisodeTimeWrapper._WRAPPER_ATTRIBUTES:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)

    def step(self, action):
        """Execute environment step and record time.

        Parameters
        ----------
        action : int
            action to execute

        Returns
        -------
        np.array, float, bool, bool, dict
            state, reward, terminated, truncated, metainfo
        """
        # perf_counter is monotonic and high resolution, so the measured
        # duration cannot be distorted by system clock adjustments.
        start = time.perf_counter()
        state, reward, terminated, truncated, info = self.env.step(action)
        duration = time.perf_counter() - start

        self.episode.append(duration)
        self.all_steps.append(duration)
        if self.logger is not None:
            self.logger.log("step_duration", duration)

        if self.time_interval:
            if len(self.current_step_interval) < self.time_interval:
                self.current_step_interval.append(duration)
            else:
                # Interval is full: archive it and open the next one with
                # the duration that was just measured.
                self.step_intervals.append(self.current_step_interval)
                self.current_step_interval = [duration]

        if terminated or truncated:
            self.overall_times.append(self.episode)
            if self.logger is not None:
                self.logger.log("episode_duration", sum(self.episode))
            if self.time_interval:
                if len(self.current_times) < self.time_interval:
                    self.current_times.append(self.episode)
                else:
                    # Interval is full: archive it and open the next one with
                    # the episode that just finished, so that no episode is
                    # missing from the interval statistics.
                    self.time_intervals.append(self.current_times)
                    self.current_times = [self.episode]
            self.episode = []

        return state, reward, terminated, truncated, info

    def get_times(self):
        """Get times.

        Returns
        -------
        np.array, np.array or list, list, list, list
            Without ``time_interval``: per-episode step times and all step
            times as arrays. With ``time_interval``: per-episode step times,
            all step times, episodes grouped by interval and step times
            grouped by interval (all as lists; the last group is the one
            currently being filled).
        """
        if self.time_interval:
            complete_intervals = [*self.time_intervals, self.current_times]
            complete_step_intervals = [*self.step_intervals, self.current_step_interval]
            return (
                self.overall_times,
                self.all_steps,
                complete_intervals,
                complete_step_intervals,
            )

        try:
            episode_times = np.array(self.overall_times)
        except ValueError:
            # Episodes of different lengths cannot form a rectangular array
            # (recent NumPy raises instead of guessing); return a 1-D object
            # array holding one list of step times per episode instead.
            episode_times = np.empty(len(self.overall_times), dtype=object)
            for index, episode in enumerate(self.overall_times):
                episode_times[index] = episode
        return episode_times, np.array(self.all_steps)

    @staticmethod
    def _render_times(title, xlabel, label, values, interval_means=None, interval_length=None):
        """Plot a time series (and optional interval means) into an RGB image.

        Parameters
        ----------
        title : str
            Plot title
        xlabel : str
            Label of the x axis
        label : str
            Legend label of the main series
        values : list
            Times in seconds, plotted against their index
        interval_means : list, optional
            Mean time per interval; drawn as a second line if given
        interval_length : int, optional
            Number of x units per interval; required with ``interval_means``

        Returns
        -------
        np.array
            uint8 image of shape (height, width, 3)
        """
        figure = plt.figure(figsize=(12, 6))
        try:
            canvas = FigureCanvas(figure)
            axes = figure.add_subplot(111)
            axes.set_title(title)
            axes.set_xlabel(xlabel)
            axes.set_ylabel("Time (s)")
            axes.plot(np.arange(len(values)), values, label=label, color="g")
            if interval_means is not None:
                # The first mean is repeated so the line starts at x = 0 and
                # reaches each interval's mean at that interval's end.
                axes.plot(
                    np.arange(len(interval_means) + 1) * interval_length,
                    [interval_means[0], *interval_means],
                    label="Mean interval time",
                    color="orange",
                )
            axes.legend(loc="upper right")
            canvas.draw()
            # Drop the alpha channel and copy, because the canvas buffer does
            # not outlive the figure that is closed below.
            return np.asarray(canvas.buffer_rgba())[..., :3].copy()
        finally:
            # Always release the figure, otherwise pyplot keeps every
            # rendered figure alive.
            plt.close(figure)

    def render_step_time(self):
        """Render step times.

        Returns
        -------
        np.array
            uint8 RGB image of shape (height, width, 3)
        """
        interval_means = None
        if self.time_interval:
            interval_means = [np.mean(interval) for interval in self.step_intervals]
            # The interval in progress is empty before the first step; use
            # NaN instead of taking the mean of an empty list.
            interval_means.append(
                np.mean(self.current_step_interval)
                if self.current_step_interval
                else np.nan
            )
        return EpisodeTimeWrapper._render_times(
            "Time per Step",
            "Step",
            "Step time",
            self.all_steps,
            interval_means,
            self.time_interval,
        )

    def render_episode_time(self):
        """Render episode times.

        Returns
        -------
        np.array
            uint8 RGB image of shape (height, width, 3)
        """
        episode_totals = [sum(episode) for episode in self.overall_times]
        interval_means = None
        if self.time_interval:
            interval_means = [
                np.mean([sum(episode) for episode in interval])
                for interval in self.time_intervals
            ]
            # No episode has finished in the current interval yet -> NaN
            # instead of the mean of an empty list.
            current_totals = [sum(episode) for episode in self.current_times]
            interval_means.append(np.mean(current_totals) if current_totals else np.nan)
        return EpisodeTimeWrapper._render_times(
            "Time per Episode",
            "Episode",
            "Episode time",
            episode_totals,
            interval_means,
            self.time_interval,
        )