"""Wrapper for process tracking."""
from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
from gymnasium import Wrapper
class PolicyProgressWrapper(Wrapper):
"""Wrapper to track progress towards optimal policy.
Can only be used if a way to obtain the optimal policy
given an instance can be obtained.
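
    The ``compute_optimal`` callable is expected to map a benchmark instance
    to the optimal sequence of actions for one episode, roughly along these
    lines (``instance.optimum`` and ``instance.n_steps`` are hypothetical
    attributes, used only to illustrate the expected signature)::

        def compute_optimal(instance):
            return [instance.optimum] * instance.n_steps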
"""

    def __init__(self, env, compute_optimal):
        """Initialize wrapper.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        compute_optimal : function
            Function that computes the optimal policy for a given instance
        """
        super().__init__(env)
        self.compute_optimal = compute_optimal
        self.episode = []
        self.policy_progress = []

    def __setattr__(self, name, value):
        """Set attribute in wrapper if available and in env if not.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to
        """
        if name in [
            "compute_optimal",
            "env",
            "episode",
            "policy_progress",
            "render_policy_progress",
        ]:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)

    def __getattribute__(self, name):
        """Get attribute value of wrapper if available and of env if not.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns
        -------
        value
            Value of given name
        """
        if name in [
            "step",
            "compute_optimal",
            "env",
            "episode",
            "policy_progress",
            "render_policy_progress",
        ]:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)

    def step(self, action):
        """Execute environment step and record distance.

        Parameters
        ----------
        action : int
            Action to execute

        Returns
        -------
        np.array, float, bool, bool, dict
            state, reward, terminated, truncated, metainfo
        """
        state, reward, terminated, truncated, info = self.env.step(action)
        if isinstance(action, dict):
            # Flatten dict actions into a list of scalar values
            action = list(action.values())
            for i in range(len(action)):
                if isinstance(action[i], list | np.ndarray):
                    action[i] = action[i][0]
        self.episode.append(action)
        if terminated or truncated:
            optimal = self.compute_optimal(self.env.instance)
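            # Example: with an optimal action sequence [1, 2, 3] and a recorded
            # episode [1, 3, 3], the distance stored below is
            # sqrt(0**2 + 1**2 + 0**2) = 1.0.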
            self.policy_progress.append(
                np.linalg.norm(np.array(optimal) - np.array(self.episode))
            )
            self.episode = []
        return state, reward, terminated, truncated, info

    def render_policy_progress(self):
        """Plot the distance to the optimal policy for each finished episode."""
        plt.figure(figsize=(12, 6))
        plt.plot(np.arange(len(self.policy_progress)), self.policy_progress)
        plt.title("Policy progress over time")
        plt.xlabel("Episode")
        plt.ylabel("Distance to optimal policy")
        plt.show()
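

# Minimal, self-contained usage sketch: ``_ConstantEnv`` and ``_constant_optimal``
# are made-up stand-ins for a real benchmark environment and its solver, used
# only to illustrate the intended flow (wrap, run episodes, render progress).
if __name__ == "__main__":
    import gymnasium as gym
    from gymnasium import spaces

    class _ConstantEnv(gym.Env):
        """Toy environment whose optimal policy is to always play action 1."""

        def __init__(self, episode_length=5):
            self.instance = episode_length  # stands in for a benchmark instance
            self.observation_space = spaces.Discrete(1)
            self.action_space = spaces.Discrete(2)
            self._t = 0

        def reset(self, *, seed=None, options=None):
            super().reset(seed=seed)
            self._t = 0
            return 0, {}

        def step(self, action):
            self._t += 1
            reward = float(action == 1)
            terminated = self._t >= self.instance
            return 0, reward, terminated, False, {}

    def _constant_optimal(instance):
        # The optimal policy plays action 1 at every one of the ``instance`` steps.
        return [1] * instance

    env = PolicyProgressWrapper(_ConstantEnv(), _constant_optimal)
    for _ in range(10):
        env.reset()
        terminated = truncated = False
        while not (terminated or truncated):
            _, _, terminated, truncated, _ = env.step(env.action_space.sample())
    env.render_policy_progress()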