# Source code for dacbench.wrappers.episode_time_tracker

"""Wrapper to track time."""

from __future__ import annotations

import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sb
from gymnasium import Wrapper
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# Give every figure drawn by this module seaborn's "darkgrid" look, then keep
# a snapshot of the active colour palette as a plain list of colours.
sb.set_style("darkgrid")
current_palette = [*sb.color_palette()]


class EpisodeTimeWrapper(Wrapper):
    """Wrapper to track time spent per episode.

    Includes interval mode that returns times in lists of len(interval)
    instead of one long list.
    """

    # Names resolved on the wrapper itself. Every other attribute read or
    # write is forwarded to the wrapped environment (see ``__getattribute__``
    # and ``__setattr__``). Kept in a single place so both dunders agree.
    _OWN_ATTRIBUTES = frozenset(
        {
            "time_interval",
            "overall_times",
            "time_intervals",
            "current_times",
            "env",
            "get_times",
            "step",
            "render_step_time",
            "render_episode_time",
            "reset",
            "episode",
            "all_steps",
            "current_step_interval",
            "step_intervals",
            "logger",
            # private helpers must not be forwarded to the env either
            "_to_array",
            "_figure_to_rgb",
        }
    )

    def __init__(self, env, time_interval=None, logger=None):
        """Initialize wrapper.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        time_interval : int
            If not none, mean in given intervals is tracked, too
        logger : dacbench.logger.ModuleLogger
            logger to write to

        """
        super().__init__(env)
        self.time_interval = time_interval
        # Per-step durations over the whole run.
        self.all_steps = []
        if self.time_interval:
            # Completed step intervals and the one currently being filled.
            self.step_intervals = []
            self.current_step_interval = []
        # One list of step durations per finished episode.
        self.overall_times = []
        # Step durations of the episode in progress.
        self.episode = []
        if self.time_interval:
            # Completed episode intervals and the one currently being filled.
            self.time_intervals = []
            self.current_times = []
        self.logger = logger

    def __setattr__(self, name, value):
        """Set attribute in wrapper if available and in env if not.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to

        """
        # type(self) avoids recursing through __getattribute__.
        if name in type(self)._OWN_ATTRIBUTES:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)

    def __getattribute__(self, name):
        """Get attribute value of wrapper if available and of env if not.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns
        -------
        value
            Value of given name

        """
        # type(self) reads the class directly, so this does not recurse.
        if name in type(self)._OWN_ATTRIBUTES:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)

    def step(self, action):
        """Execute environment step and record time.

        Parameters
        ----------
        action : int
            action to execute

        Returns
        -------
        np.array, float, bool, bool, dict
            state, reward, terminated, truncated, metainfo

        """
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so durations cannot go negative when the system clock is adjusted.
        start = time.perf_counter()
        state, reward, terminated, truncated, info = self.env.step(action)
        duration = time.perf_counter() - start

        self.episode.append(duration)
        self.all_steps.append(duration)
        if self.logger is not None:
            self.logger.log("step_duration", duration)

        if self.time_interval:
            if len(self.current_step_interval) < self.time_interval:
                self.current_step_interval.append(duration)
            else:
                # Interval full: archive it and open a new one with this step.
                self.step_intervals.append(self.current_step_interval)
                self.current_step_interval = [duration]

        if terminated or truncated:
            self.overall_times.append(self.episode)
            if self.logger is not None:
                self.logger.log("episode_duration", sum(self.episode))
            if self.time_interval:
                if len(self.current_times) < self.time_interval:
                    self.current_times.append(self.episode)
                else:
                    # Interval full: archive it and open a new one that starts
                    # with this episode, so the episode is not lost.
                    self.time_intervals.append(self.current_times)
                    self.current_times = [self.episode]
            self.episode = []

        return state, reward, terminated, truncated, info

    @staticmethod
    def _to_array(values):
        """Convert a list of per-episode lists to a NumPy array.

        Equal-length entries give a regular 2-D array. Ragged entries give a
        1-D object array holding one list per episode, because NumPy >= 1.24
        raises ``ValueError`` when asked to build an array from ragged lists.
        """
        try:
            return np.array(values)
        except ValueError:
            ragged = np.empty(len(values), dtype=object)
            for index, item in enumerate(values):
                ragged[index] = item
            return ragged

    def get_times(self):
        """Get times.

        Returns
        -------
        np.array or np.array, np.array
            all times or all times and interval sorted times

        """
        if self.time_interval:
            # Include the still-open intervals so no recorded time is omitted.
            complete_intervals = [*self.time_intervals, self.current_times]
            complete_step_intervals = [
                *self.step_intervals,
                self.current_step_interval,
            ]
            return (
                self.overall_times,
                self.all_steps,
                complete_intervals,
                complete_step_intervals,
            )
        return self._to_array(self.overall_times), np.array(self.all_steps)

    @staticmethod
    def _figure_to_rgb(figure):
        """Rasterise ``figure`` with Agg and return a uint8 (height, width, 3) array."""
        canvas = FigureCanvas(figure)
        canvas.draw()
        # buffer_rgba() replaces the removed tostring_rgb(); its shape is the
        # real pixel size, so no manual reshape from inches * dpi is needed.
        # Copy so the result does not alias the canvas buffer once closed.
        return np.asarray(canvas.buffer_rgba())[:, :, :3].copy()

    def render_step_time(self):
        """Render step times.

        Returns
        -------
        np.ndarray
            uint8 RGB image of shape (height, width, 3).

        """
        figure = plt.figure(figsize=(12, 6))
        try:
            axes = figure.add_subplot(111)
            axes.set_title("Time per Step")
            axes.set_xlabel("Step")
            axes.set_ylabel("Time (s)")
            axes.plot(
                np.arange(len(self.all_steps)),
                self.all_steps,
                label="Step time",
                color="g",
            )
            if self.time_interval:
                current = self.current_step_interval
                interval_means = [
                    np.mean(interval) for interval in self.step_intervals
                ]
                # nan (a gap in the plot) without np.mean's empty-slice warning.
                interval_means.append(np.mean(current) if current else np.nan)
                axes.plot(
                    np.arange(len(self.step_intervals) + 2) * self.time_interval,
                    [interval_means[0], *interval_means],
                    label="Mean interval time",
                    color="orange",
                )
            axes.legend(loc="upper right")
            return self._figure_to_rgb(figure)
        finally:
            # Release the pyplot-managed figure; otherwise every call leaks one.
            plt.close(figure)

    def render_episode_time(self):
        """Render episode times.

        Returns
        -------
        np.ndarray
            uint8 RGB image of shape (height, width, 3).

        """
        figure = plt.figure(figsize=(12, 6))
        try:
            axes = figure.add_subplot(111)
            axes.set_title("Time per Episode")
            axes.set_xlabel("Episode")
            axes.set_ylabel("Time (s)")
            axes.plot(
                np.arange(len(self.overall_times)),
                [sum(episode) for episode in self.overall_times],
                label="Episode time",
                color="g",
            )
            if self.time_interval:
                # Mean total episode time per completed interval.
                interval_sums = [
                    np.mean([sum(episode) for episode in interval])
                    for interval in self.time_intervals
                ]
                current = [sum(episode) for episode in self.current_times]
                # nan (a gap in the plot) without np.mean's empty-slice warning.
                interval_sums.append(np.mean(current) if current else np.nan)
                axes.plot(
                    np.arange(len(self.time_intervals) + 2) * self.time_interval,
                    [interval_sums[0], *interval_sums],
                    label="Mean interval time",
                    color="orange",
                )
            axes.legend(loc="upper right")
            return self._figure_to_rgb(figure)
        finally:
            # Release the pyplot-managed figure; otherwise every call leaks one.
            plt.close(figure)