Source code for dacbench.wrappers.policy_progress_wrapper

"""Wrapper for process tracking."""

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
from gymnasium import Wrapper


class PolicyProgressWrapper(Wrapper):
    """Wrapper to track progress towards the optimal policy.

    Can only be used if the optimal policy for a given instance can be computed.
    """

    def __init__(self, env, compute_optimal):
        """Initialize wrapper.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        compute_optimal : function
            Function that computes the optimal policy for an instance
        """
        super().__init__(env)
        self.compute_optimal = compute_optimal
        self.episode = []
        self.policy_progress = []
    def __setattr__(self, name, value):
        """Set attribute on the wrapper if it belongs to the wrapper, on the env otherwise.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to
        """
        if name in [
            "compute_optimal",
            "env",
            "episode",
            "policy_progress",
            "render_policy_progress",
        ]:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)
    def __getattribute__(self, name):
        """Get attribute value from the wrapper if available, from the env otherwise.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns
        -------
        value
            Value of given name
        """
        if name in [
            "step",
            "compute_optimal",
            "env",
            "episode",
            "policy_progress",
            "render_policy_progress",
        ]:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)
    def step(self, action):
        """Execute environment step and record distance to the optimal policy.

        Parameters
        ----------
        action : int
            Action to execute

        Returns
        -------
        np.array, float, bool, bool, dict
            state, reward, terminated, truncated, metainfo
        """
        state, reward, terminated, truncated, info = self.env.step(action)
        if isinstance(action, dict):
            action = list(action.values())
        if isinstance(action, list | np.ndarray):
            # Unwrap one level of nesting so each recorded entry is a scalar
            for i in range(len(action)):
                if isinstance(action[i], list | np.ndarray):
                    action[i] = action[i][0]
        self.episode.append(action)
        if terminated or truncated:
            # Distance between this episode's action sequence and the optimal policy
            optimal = self.compute_optimal(self.env.instance)
            self.policy_progress.append(
                np.linalg.norm(np.array(optimal) - np.array(self.episode))
            )
            self.episode = []
        return state, reward, terminated, truncated, info
    def render_policy_progress(self):
        """Plot the recorded distance to the optimal policy per episode."""
        plt.figure(figsize=(12, 6))
        plt.plot(np.arange(len(self.policy_progress)), self.policy_progress)
        plt.title("Policy progress over time")
        plt.xlabel("Episode")
        plt.ylabel("Distance to optimal policy")
        plt.show()
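
A minimal usage sketch (not part of the dacbench source, and assuming PolicyProgressWrapper is importable from dacbench.wrappers): ToyDACEnv and optimal_for_instance below are hypothetical stand-ins for a benchmark environment that exposes an instance attribute and for the compute_optimal callback, which here simply returns the per-step target actions.

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from dacbench.wrappers import PolicyProgressWrapper


class ToyDACEnv(gym.Env):
    """Hypothetical toy env: the instance is the optimal per-step action sequence."""

    def __init__(self):
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(0, 1, shape=(1,))
        self.instance = [1, 2, 3, 2, 1]
        self._t = 0

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._t = 0
        return np.zeros(1, dtype=np.float32), {}

    def step(self, action):
        # Reward is the negative distance to the target action for this step
        reward = -abs(int(action) - self.instance[self._t])
        self._t += 1
        terminated = self._t >= len(self.instance)
        return np.zeros(1, dtype=np.float32), reward, terminated, False, {}


def optimal_for_instance(instance):
    # For this toy problem the optimal policy is the target sequence itself.
    return list(instance)


env = PolicyProgressWrapper(ToyDACEnv(), optimal_for_instance)
for _ in range(10):
    env.reset()
    terminated = truncated = False
    while not (terminated or truncated):
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())

print(env.policy_progress)    # one distance value per finished episode
env.render_policy_progress()  # plots the distances over episodes

After each finished episode the wrapper stores the Euclidean norm of the difference between the episode's action sequence and the optimal one, so a policy_progress curve trending towards zero indicates the agent is approaching the optimal policy.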