"""Benchmark for SGD."""
from __future__ import annotations
import math
from pathlib import Path
import ConfigSpace as CS # noqa: N817
import numpy as np
import pandas as pd
from gymnasium import spaces
from torch import nn
from dacbench.abstract_benchmark import AbstractBenchmark, objdict
from dacbench.envs import SGDEnv, SGDInstance
from dacbench.envs.env_utils import sgd_utils
# Hyperparameters exposed as actions of the benchmark.
LR = CS.Float("learning_rate", bounds=(0.0, 0.05))

# Value used for momentum-like adaptation. The Adam optimizer has no real
# momentum term, so "beta1" is changed instead.
# NOTE: only used when the "use_momentum" entry of the config is true.
MOMENTUM = CS.Float("momentum", bounds=(0.0, 1.0))

# Default configuration space: learning rate first, then momentum.
DEFAULT_CFG_SPACE = CS.ConfigurationSpace()
DEFAULT_CFG_SPACE.add(LR)
DEFAULT_CFG_SPACE.add(MOMENTUM)
def __default_loss_function(**kwargs):
    """Build the default loss criterion.

    Parameters
    ----------
    **kwargs
        Extra keyword arguments forwarded unchanged to ``nn.NLLLoss``.

    Returns
    -------
    nn.NLLLoss
        Loss module created with ``reduction="none"``.
    """
    criterion = nn.NLLLoss(reduction="none", **kwargs)
    return criterion
# Human-readable metadata describing the benchmark; stored in the default
# config under the "benchmark_info" key.
INFO = {
    "identifier": "LR",
    "name": "Learning Rate Adaption for Neural Networks",
    "reward": "Negative Log Differential Validation Loss",
    # One label per state feature.
    "state_description": ["Step", "Loss", "Validation Loss", "Crashed"],
    # One label per action dimension.
    "action_description": ["Learning Rate", "Momentum"],
}
# Default benchmark configuration. SGDBenchmark copies it when no config is
# supplied and uses it to back-fill keys missing from a user-supplied config.
SGD_DEFAULTS = objdict(
    {
        "config_space": DEFAULT_CFG_SPACE,
        "observation_space_class": "Dict",
        "observation_space_type": None,
        "observation_space_args": [
            {
                "step": spaces.Box(0, np.inf, shape=(1,)),
                "loss": spaces.Box(0, np.inf, shape=(1,)),
                "validationLoss": spaces.Box(0, np.inf, shape=(1,)),
                # NOTE(review): Discrete(1) admits only the value 0; if this
                # is meant to be a 0/1 crash flag it would need Discrete(2).
                # Left unchanged so the observation space stays the same.
                "crashed": spaces.Discrete(1),
            }
        ],
        "reward_range": [-1_000_000_000, 1_000_000_000],
        "device": "cpu",
        "use_instance_generator": False,
        "cutoff": 100.0,
        "loss_function": __default_loss_function,
        "loss_function_kwargs": {},
        # Optional keys, not set by default:
        #   "reward_function" - can be set to replace the default function
        #   "state_method"    - can be set to replace the default function
        "use_momentum": False,
        "seed": 0,
        "crash_penalty": -100.0,
        "multi_agent": False,
        "instance_set_path": "sgd_cifar10_variations_train.csv",
        "benchmark_info": INFO,
        "epoch_mode": True,
        "local_model_path": False,
        "dataset_config": None,
    }
)
class SGDBenchmark(AbstractBenchmark):
    """Benchmark with default configuration & relevant functions for SGD."""

    def __init__(self, config_path=None, config=None):
        """Initialize SGD Benchmark.

        Parameters
        ----------
        config_path : str, optional
            Path to config file.
        config : optional
            Configuration forwarded to ``AbstractBenchmark`` together with
            ``config_path``.
        """
        super().__init__(config_path, config)
        if not self.config:
            self.config = objdict(SGD_DEFAULTS.copy())

        # Back-fill every default that the supplied config does not define.
        for key in SGD_DEFAULTS:
            if key not in self.config:
                self.config[key] = SGD_DEFAULTS[key]

    def get_environment(self):
        """Return SGDEnv env with current configuration.

        The instance set is read first if the config does not hold one yet;
        the test set is read only when a ``test_set_path`` is configured.
        Every function in ``self.wrap_funcs`` is applied to the environment
        in order.

        Returns
        -------
        SGDEnv
            SGD environment
        """
        if "instance_set" not in self.config:
            self.read_instance_set()

        # Read test set if path is specified
        if "test_set" not in self.config and "test_set_path" in self.config:
            self.read_instance_set(test=True)

        env = SGDEnv(self.config)
        for func in self.wrap_funcs:
            env = func(env)
        return env

    @staticmethod
    def _find_instance_file(file_name):
        """Return the first candidate location of ``file_name`` that is a file.

        Candidates are tried in this order: the path exactly as given, the
        path relative to this module's directory, and the path inside the
        ``../instance_sets/sgd`` directory next to this module's directory.

        Raises
        ------
        FileNotFoundError
            If none of the candidates is an existing file.
        """
        module_dir = Path(__file__).resolve().parent
        candidates = (
            Path(file_name),
            module_dir / file_name,
            module_dir / "../instance_sets/sgd" / file_name,
        )
        for candidate in candidates:
            # is_file() for every candidate: a directory with a matching name
            # must not be picked up and then fail inside read_csv.
            if candidate.is_file():
                return candidate
        raise FileNotFoundError(
            "Instance set file not found at "
            f"{candidates[0]}, {candidates[1]} or {candidates[2]}"
        )

    def read_instance_set(self, test=False):
        """Read the instances listed in a CSV file into the config.

        The CSV is expected to provide the columns ``dataset``,
        ``model_type``, ``model_kwargs``, ``optimizer``,
        ``optimizer_params``, ``batch_size``, ``fraction_of_dataset``,
        ``train_validation_ratio`` and ``seed``. One ``SGDInstance`` is
        created per row and stored under the row index.

        Parameters
        ----------
        test : bool
            If true, read the file named by ``test_set_path`` into
            ``config["test_set"]``; otherwise read ``instance_set_path``
            into ``config["instance_set"]``.

        Raises
        ------
        FileNotFoundError
            If the configured file cannot be found in any search location.
        """
        if test:
            path_key, keyword = "test_set_path", "test_set"
        else:
            path_key, keyword = "instance_set_path", "instance_set"

        path = self._find_instance_file(getattr(self.config, path_key))
        module_dir = Path(__file__).resolve().parent

        instances = {}
        for index, row in pd.read_csv(path).iterrows():
            # Only the part before the first underscore names the dataset;
            # split() returns the whole string when there is no underscore.
            dataset_name = row["dataset"].split("_")[0]

            model_constructor = sgd_utils.get_model_constructor(
                row["model_type"],
                row["model_kwargs"].split("-"),
                self.config.local_model_path,
            )

            # Empty CSV cells are read as NaN. pd.isna also accepts strings
            # and None, whereas math.isnan raises TypeError on non-numbers.
            # NOTE(review): non-missing values are forwarded unchanged --
            # confirm the format SGDInstance expects for them.
            optimizer_params = row["optimizer_params"]
            if pd.isna(optimizer_params):
                optimizer_params = {}

            instances[index] = SGDInstance(
                model=model_constructor,
                optimizer_type=row["optimizer"],
                optimizer_params=optimizer_params,
                dataset_path=module_dir,
                dataset_name=dataset_name,
                batch_size=row["batch_size"],
                fraction_of_dataset=row["fraction_of_dataset"],
                train_validation_ratio=row["train_validation_ratio"],
                seed=int(row["seed"]),
            )

        # Publish only after the whole file was parsed, so a failed read does
        # not leave an empty set behind that would block a later retry.
        self.config[keyword] = instances

    def get_benchmark(self, instance_set_path=None, seed=0):
        """Get benchmark from the LTO paper.

        The current config is replaced by a copy of ``SGD_DEFAULTS`` and the
        functions in ``self.wrap_funcs`` are not applied to the result.

        Parameters
        ----------
        instance_set_path : str, optional
            Instance set file to use instead of the default one.
        seed : int
            Environment seed

        Returns
        -------
        env : SGDEnv
            SGD environment
        """
        self.config = objdict(SGD_DEFAULTS.copy())
        if instance_set_path is not None:
            self.config["instance_set_path"] = instance_set_path
        self.config.seed = seed
        self.read_instance_set()
        return SGDEnv(self.config)