# Source code for dacbench.abstract_benchmark

"""Abstract Benchmark."""

from __future__ import annotations

import json
from abc import ABC, abstractmethod
from functools import partial
from types import FunctionType

import numpy as np
from ConfigSpace import ConfigurationSpace
from gymnasium import spaces

from dacbench import wrappers


[docs] class AbstractBenchmark(ABC): """Abstract template for benchmark classes.""" def __init__(self, config_path=None, config: objdict = None): """Initialize benchmark class. Parameters ---------- config_path : str Path to load configuration from (if read from file) config : objdict Object dict containing the config """ if config is not None and config_path is not None: raise ValueError("Both path to config and config where provided") self.wrap_funcs = [] if config_path: self.config_path = config_path self.read_config_file(self.config_path) elif config: self.load_config(config) else: self.config = None
[docs] def get_config(self): """Return current configuration. Returns: -------- dict: Current config """ return self.config
[docs] def serialize_config(self): """Save configuration to json. Parameters ---------- path : str File to save config to """ conf = self.config.copy() if "observation_space_type" in self.config: conf["observation_space_type"] = f"{self.config['observation_space_type']}" if isinstance(conf["observation_space_args"][0], dict): conf["observation_space_args"] = self.jsonify_dict_space( conf["observation_space_args"][0] ) elif "observation_space" in self.config: conf["observation_space"] = self.space_to_list(conf["observation_space"]) if "action_space" in self.config: conf["action_space"] = self.space_to_list(conf["action_space"]) if "config_space" in self.config: conf["config_space"] = self.config.config_space.to_serialized_dict() conf = AbstractBenchmark.__stringify_functions(conf) for k in self.config: if isinstance(self.config[k], list | np.ndarray): if isinstance(self.config[k][0], np.ndarray): conf[k] = list(map(list, conf[k])) for i in range(len(conf[k])): if ( not isinstance(conf[k][i][0], float) and np.inf not in conf[k][i] and -np.inf not in conf[k][i] ): conf[k][i] = list(map(int, conf[k][i])) elif isinstance(conf[k], np.ndarray): conf[k] = conf[k].tolist() conf["wrappers"] = self.jsonify_wrappers() # can be recovered from instance_set_path, and could contain function that # are not serializable if "instance_set" in conf: del conf["instance_set"] return conf
[docs] @classmethod def from_json(cls, json_config): """Get config from json dict.""" config = objdict(json.loads(json_config)) if "config_space" in config: configuration_space = ConfigurationSpace.from_serialized_dict( config["config_space"] ) config.config_space = configuration_space return cls(config=config)
[docs] def to_json(self): """Write config to json.""" conf = self.serialize_config() return json.dumps(conf)
[docs] def save_config(self, path): """Write config to path.""" conf = self.serialize_config() with open(path, "w") as fp: json.dump(conf, fp, default=lambda o: "not serializable")
[docs] def jsonify_wrappers(self): """Write wrapper description to list. Returns: -------- list """ wrappers = [] for func in self.wrap_funcs: args = func.args arg_descriptions = [] contains_func = False func_dict = {} for i in range(len(args)): if callable(args[i]): contains_func = True func_dict[f"{args[i]}"] = [args[i].__module__, args[i].__name__] arg_descriptions.append(["function", f"{args[i]}"]) # elif isinstance(args[i], ModuleLogger): # pass else: arg_descriptions.append({args[i]}) function = func.func.__name__ if contains_func: wrappers.append([function, arg_descriptions, func_dict]) else: wrappers.append([function, arg_descriptions]) return wrappers
[docs] def dejson_wrappers(self, wrapper_list): """Load wrapper from list. Parameters ---------- wrapper_list : list wrapper description to load """ for i in range(len(wrapper_list)): import importlib func = getattr(wrappers, wrapper_list[i][0]) arg_descriptions = wrapper_list[i][1] args = [] for a in arg_descriptions: if a[0] == "function": module = importlib.import_module(wrapper_list[i][2][a[1]][0]) name = wrapper_list[i][2][a[1]][0] func = getattr(module, name) args.append(func) # elif a[0] == "logger": # pass else: args.append(a) self.wrap_funcs.append(partial(func, *args))
@staticmethod def __import_from(module: str, name: str): """Imports the class / function / ... with name from module. Parameters ---------- module : str module to import from name : str name to import Returns: ------- the imported object """ module = __import__(module, fromlist=[name]) return getattr(module, name)
[docs] @classmethod def class_to_str(cls): """Get string name from class.""" return cls.__module__, cls.__name__
@staticmethod def __decorate_config_with_functions(conf: dict): """Replaced the stringified functions with the callable objects. Parameters ---------- conf : config to parse """ for key, value in { k: v for k, v in conf.items() if isinstance(v, list) and len(v) == 3 and v[0] == "function" }.items(): _, module_name, function_name = value conf[key] = AbstractBenchmark.__import_from(module_name, function_name) return conf @staticmethod def __stringify_functions(conf: dict) -> dict: """Replaced all callables in the config with a triple ('function', module_name, function_name). Parameters ---------- conf : dict config to parse Returns: ------- modified dict """ for key, _value in { k: v for k, v in conf.items() if isinstance(v, FunctionType) }.items(): conf[key] = ["function", conf[key].__module__, conf[key].__name__] return conf
[docs] def space_to_list(self, space): """Make list from gym space. Parameters ---------- space: gym.spaces.Space space to parse """ res = [] if isinstance(space, spaces.Box): res.append("Box") res.append([space.low.tolist(), space.high.tolist()]) res.append("numpy.float32") elif isinstance(space, spaces.Discrete): res.append("Discrete") res.append([space.n]) elif isinstance(space, spaces.Dict): res.append("Dict") res.append(self.jsonify_dict_space(space.spaces)) elif isinstance(space, spaces.MultiDiscrete): res.append("MultiDiscrete") res.append([space.nvec]) elif isinstance(space, spaces.MultiBinary): res.append("MultiBinary") res.append([space.n]) return res
[docs] def list_to_space(self, space_list): """Make gym space from list. Parameters ---------- space_list: list list to space-ify """ if space_list[0] == "Dict": args = self.dictify_json(space_list[1]) space = getattr(spaces, space_list[0])(args) elif len(space_list) == 2: space = getattr(spaces, space_list[0])(*space_list[1]) else: typestring = space_list[2].split(".")[1] dt = getattr(np, typestring) args = [np.array(arg) for arg in space_list[1]] space = getattr(spaces, space_list[0])(*args, dtype=dt) return space
[docs] def jsonify_dict_space(self, dict_space): """Gym spaces to json dict. Parameters ---------- dict_space : dict space dict """ keys = [] types = [] arguments = [] for k in dict_space: keys.append(k) value = dict_space[k] if not isinstance(value, spaces.Box | spaces.Discrete): raise ValueError( f"Only Dict spaces made up of Box spaces or discrete spaces are " f"supported but got {type(value)}" ) if isinstance(value, spaces.Box): types.append("box") low = value.low.astype(float).tolist() high = value.high.astype(float).tolist() arguments.append([low, high]) if isinstance(value, spaces.Discrete): types.append("discrete") n = int(value.n) arguments.append([n]) return [keys, types, arguments]
[docs] def dictify_json(self, dict_list): """Json to dict structure for gym spaces. Parameters ---------- dict_list: list list of dicts """ dict_space = {} keys, types, args = dict_list for k, space_type, args_ in zip(keys, types, args, strict=False): if space_type == "box": prepared_args = map(np.array, args_) dict_space[k] = spaces.Box(*prepared_args, dtype=np.float32) elif space_type == "discrete": dict_space[k] = spaces.Discrete(*args_) else: raise TypeError( f"Currently only Discrete and Box spaces are allowed in Dict " f"spaces, got {space_type}" ) return dict_space
[docs] def load_config(self, config: objdict): """Load config. Parameters ---------- config: objdict config to load """ self.config = config if "observation_space_type" in self.config: # noqa: SIM102 # Types have to be numpy dtype (for gym spaces)s if isinstance(self.config["observation_space_type"], str): if self.config["observation_space_type"] == "None": self.config["observation_space_type"] = None else: typestring = self.config["observation_space_type"].split(" ")[1][ :-2 ] typestring = typestring.split(".")[1] self.config["observation_space_type"] = getattr(np, typestring) if "observation_space" in self.config: self.config["observation_space"] = self.list_to_space( self.config["observation_space"] ) elif "observation_space_class" in config: # noqa: SIM102 if config.observation_space_class == "Dict": self.config["observation_space_args"] = [ self.dictify_json(self.config["observation_space_args"]) ] if "action_space" in self.config: self.config["action_space"] = self.list_to_space( self.config["action_space"] ) if "wrappers" in self.config: self.dejson_wrappers(self.config["wrappers"]) del self.config["wrappers"] self.config = AbstractBenchmark.__decorate_config_with_functions(self.config) for k in self.config: if isinstance(self.config[k], list): if isinstance(self.config[k][0], list): map(np.array, self.config[k]) self.config[k] = np.array(self.config[k])
[docs] def read_config_file(self, path): """Read configuration from file. Parameters ---------- path : str Path to config file """ with open(path) as fp: config = objdict(json.load(fp)) self.load_config(config)
[docs] @abstractmethod def get_environment(self): """Make benchmark environment. Returns: -------- gym.Env: Benchmark environment """ raise NotImplementedError
[docs] def set_seed(self, seed): """Set environment seed. Parameters ---------- seed : int New seed """ self.config["seed"] = seed
[docs] def set_action_space(self, kind, args): """Change action space. Parameters ---------- kind : str Name of action space class args: list List of arguments to pass to action space class """ self.config["action_space"] = kind self.config["action_space_args"] = args
[docs] def set_observation_space(self, kind, args, data_type): """Change observation_space. Parameters ---------- kind : str Name of observation space class args : list List of arguments to pass to observation space class data_type : type Data type of observation space """ self.config["observation_space"] = kind self.config["observation_space_args"] = args self.config["observation_space_type"] = data_type
[docs] def register_wrapper(self, wrap_func): """Register wrapper. Parameters ---------- wrap_func : function wrapper init function """ if isinstance(wrap_func, list): self.wrap_funcs.append(*wrap_func) else: self.wrap_funcs.append(wrap_func)
[docs] def __eq__(self, other): """Check for equality.""" return isinstance(self, type(other)) and self.config == other.config
# This code is taken from https://goodcode.io/articles/python-dict-object/
class objdict(dict):  # noqa: N801
    """Modified dict to make config changes more flexible.

    Keys can also be read, written and deleted as attributes
    (``d.seed`` is ``d["seed"]``).
    """

    def __getattr__(self, name):
        """Return ``self[name]``; raise ``AttributeError`` if the key is missing."""
        if name in self:
            return self[name]
        raise AttributeError("No such attribute: " + name)

    def __setattr__(self, name, value):
        """Set ``self[name] = value``."""
        self[name] = value

    def __delattr__(self, name):
        """Delete key ``name``; raise ``AttributeError`` if it is missing."""
        if name in self:
            del self[name]
        else:
            raise AttributeError("No such attribute: " + name)

    def copy(self):
        """Return a shallow copy that is again an ``objdict``.

        Built from the mapping itself (not ``**kwargs``) so non-string keys
        are supported.
        """
        return objdict(self)

    def __eq__(self, other):
        """Check for equality, comparing numpy arrays element-wise.

        ``other`` must be a dict with the same keys. A value pair is compared
        with ``np.array_equal`` if either side is an ``ndarray``, and with
        ``==`` otherwise.
        """
        if not isinstance(other, dict):
            return False
        if set(other.keys()) != set(self.keys()):
            return False
        for key, value in self.items():
            other_value = other[key]
            if isinstance(value, np.ndarray) or isinstance(other_value, np.ndarray):
                if not np.array_equal(value, other_value):
                    return False
            elif not other_value == value:
                return False
        return True

    def __ne__(self, other):
        """Check for inequality."""
        return not self == other