"""Wrapper for instance sampling."""
from __future__ import annotations

from collections.abc import Mapping

import numpy as np
from gymnasium import Wrapper
from scipy.stats import norm


class InstanceSamplingWrapper(Wrapper):
    """Wrapper to sample a new instance at a given time point.

    Instances can either be sampled using a given method
    or a distribution inferred from a given collection of instances.

    Only the attributes named in ``_WRAPPER_ATTRIBUTES`` are stored on the
    wrapper itself; every other attribute read or write is forwarded to the
    wrapped environment.
    """

    # Single source of truth for the names that live on the wrapper.
    # It is looked up through ``type(self)`` because instance attribute access
    # is redirected to the wrapped env by ``__getattribute__``.
    _WRAPPER_ATTRIBUTES = frozenset(
        {
            "sampling_function",
            "env",
            "fit_dist",
            "reset",
            "reset_interval",
            "reset_tracker",
        }
    )

    def __init__(self, env, sampling_function=None, instances=None, reset_interval=0):
        """Initialize wrapper.

        Either sampling_function or instances must be given.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        sampling_function : function
            Function to sample instances from; takes precedence over
            ``instances`` when both are given
        instances : list or dict
            Instances to infer distribution from
        reset_interval : int
            additional episodes for which to keep an instance

        Raises
        ------
        ValueError
            If neither a sampling function nor a non-empty instance
            collection is given.
        """
        super().__init__(env)
        if sampling_function is not None:
            self.sampling_function = sampling_function
        elif instances is not None and len(instances) > 0:
            self.sampling_function = self.fit_dist(instances)
        else:
            raise ValueError("No distribution to sample from given")
        self.reset_interval = reset_interval
        # Start "due" so that the very first reset already samples an instance.
        self.reset_tracker = reset_interval

    def __setattr__(self, name, value):
        """Set attribute in wrapper if available and in env if not.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to
        """
        if name in type(self)._WRAPPER_ATTRIBUTES:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)

    def __getattribute__(self, name):
        """Get attribute value of wrapper if available and of env if not.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns
        -------
        value
            Value of given name
        """
        if name in type(self)._WRAPPER_ATTRIBUTES:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)

    def reset(self, seed=None, options=None):
        """Reset environment, switching to a freshly sampled instance when due.

        A new instance is sampled on the first reset and afterwards once the
        current one has been kept for ``reset_interval`` additional resets.
        With ``reset_interval == 0`` every reset samples a new instance.

        Parameters
        ----------
        seed : optional
            Passed through to the wrapped environment's ``reset``
        options : dict, optional
            Passed through to the wrapped environment's ``reset``;
            ``None`` is replaced by an empty dict

        Returns
        -------
        Whatever the wrapped environment's ``reset`` returns.
        """
        if options is None:
            options = {}
        if self.reset_tracker >= self.reset_interval:
            instance = self.sampling_function()
            self.env.use_next_instance(instance=instance)
            # The new instance has not been kept for any extra episode yet.
            self.reset_tracker = 0
        else:
            self.reset_tracker += 1
        return self.env.reset(seed=seed, options=options)

    def fit_dist(self, instances):
        """Approximate instance distribution in given instance set.

        A normal distribution is fitted independently to every attribute
        found in the ``__dict__`` of the first instance.

        Parameters
        ----------
        instances : list or dict
            instance set; for a mapping its values are used as the instances

        Returns
        -------
        method
            sampling method for new instances; each call returns a list with
            one drawn value per fitted attribute, in attribute order

        Raises
        ------
        ValueError
            If the instance set is empty.
        """
        if isinstance(instances, Mapping):
            instance_list = list(instances.values())
        else:
            instance_list = list(instances)
        if not instance_list:
            raise ValueError("Cannot fit a distribution to an empty instance set")

        # norm.fit returns the (loc, scale) estimate for one attribute.
        dists = [
            norm.fit([inst.__dict__[key] for inst in instance_list])
            for key in instance_list[0].__dict__
        ]
        # One generator per fitted sampler instead of a new one on every draw.
        rng = np.random.default_rng()

        def sample():
            return [rng.normal(loc, scale) for loc, scale in dists]

        return sample