"""Wrapper for instance sampling."""
from __future__ import annotations
import numpy as np
from gymnasium import Wrapper
from scipy.stats import norm
class InstanceSamplingWrapper(Wrapper):
    """Wrapper to sample a new instance at a given time point.

    Instances can either be sampled using a given method
    or a distribution inferred from a given collection of instances.

    Attribute access on this wrapper is redirected: only the names listed in
    ``_LOCAL_ATTRIBUTES`` are stored on and read from the wrapper itself;
    every other attribute is read from / written to the wrapped environment.
    """

    # Names that live on the wrapper itself. The reset bookkeeping is kept
    # here as well so it cannot overwrite attributes of the wrapped env.
    _LOCAL_ATTRIBUTES = frozenset(
        {
            "sampling_function",
            "env",
            "fit_dist",
            "reset",
            "reset_interval",
            "reset_tracker",
        }
    )

    def __init__(self, env, sampling_function=None, instances=None, reset_interval=0):
        """Initialize wrapper.

        Either sampling_function or instances must be given. If both are
        given, sampling_function takes precedence.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap
        sampling_function : function
            Function to sample instances from (called without arguments)
        instances : list or dict
            Instances to infer distribution from
        reset_interval : int
            additional episodes for which to keep an instance

        Raises
        ------
        ValueError
            If neither a sampling function nor a non-empty instance
            collection is given

        """
        super().__init__(env)
        if sampling_function is not None:
            self.sampling_function = sampling_function
        elif instances is not None and len(instances) > 0:
            self.sampling_function = self.fit_dist(instances)
        else:
            raise ValueError("No distribution to sample from given")
        self.reset_interval = reset_interval
        # Number of resets since an instance was last sampled.
        self.reset_tracker = 0

    def __setattr__(self, name, value):
        """Set attribute in wrapper if available and in env if not.

        Parameters
        ----------
        name : str
            Attribute to set
        value
            Value to set attribute to

        """
        # type(self) is used instead of self.<attr> because instance attribute
        # lookups are themselves redirected by __getattribute__.
        if name in type(self)._LOCAL_ATTRIBUTES:
            object.__setattr__(self, name, value)
        else:
            setattr(self.env, name, value)

    def __getattribute__(self, name):
        """Get attribute value of wrapper if available and of env if not.

        Parameters
        ----------
        name : str
            Attribute to get

        Returns
        -------
        value
            Value of given name

        """
        if name in type(self)._LOCAL_ATTRIBUTES:
            return object.__getattribute__(self, name)
        return getattr(self.env, name)

    def reset(self, seed=None, options=None):
        """Reset environment and use sampled instance for training.

        A new instance is sampled once ``reset_interval`` resets have passed
        since the last sampling; with the default interval of 0 this happens
        on every reset. Until the first sampling, the wrapped environment's
        own instance handling is left untouched.

        Parameters
        ----------
        seed : int, optional
            Seed forwarded to the wrapped environment's reset
        options : dict, optional
            Options forwarded to the wrapped environment's reset
            (an empty dict is passed if None)

        Returns
        -------
        value
            Whatever the wrapped environment's ``reset`` returns

        """
        if options is None:
            options = {}
        if self.reset_tracker >= self.reset_interval:
            instance = self.sampling_function()
            self.env.use_next_instance(instance=instance)
            # Start counting the episodes this instance is kept for.
            self.reset_tracker = 0
        else:
            self.reset_tracker += 1
        return self.env.reset(seed=seed, options=options)

    def fit_dist(self, instances):
        """Approximate instance distribution in given instance set.

        A separate normal distribution is fitted to every attribute found in
        the ``__dict__`` of the first instance; all instances are expected to
        provide the same attributes.

        Parameters
        ----------
        instances : list or dict
            Instance set. For a dict, the values are used as instances.

        Returns
        -------
        method
            Sampling method for new instances. It takes no arguments and
            returns a list with one sampled value per fitted attribute.

        Raises
        ------
        ValueError
            If the instance set is empty

        """
        if isinstance(instances, dict):
            pool = list(instances.values())
        else:
            pool = list(instances)
        if not pool:
            raise ValueError("Cannot fit a distribution to an empty instance set")

        # One (loc, scale) pair per attribute, as returned by norm.fit.
        dists = []
        for key in vars(pool[0]):
            component = [vars(inst)[key] for inst in pool]
            dists.append(norm.fit(component))

        # Created once and shared by all calls of the returned sampler.
        rng = np.random.default_rng()

        def sample():
            return [rng.normal(loc, scale) for loc, scale in dists]

        return sample