"""Benchmark for SGD."""
from __future__ import annotations
import math
from pathlib import Path
import ConfigSpace as CS  # noqa: N817
import numpy as np
import pandas as pd
from gymnasium import spaces
from torch import nn
from dacbench.abstract_benchmark import AbstractBenchmark, objdict
from dacbench.envs import SGDEnv, SGDInstance
from dacbench.envs.env_utils import sgd_utils
# Configuration space holding the two hyperparameters defined below; it is
# used as the "config_space" entry of SGD_DEFAULTS.
DEFAULT_CFG_SPACE = CS.ConfigurationSpace()
# Learning rate, a float hyperparameter with bounds (0.0, 0.05).
LR = CS.Float(name="learning_rate", bounds=(0.0, 0.05))
# Momentum-like value, a float hyperparameter with bounds (0.0, 1.0).
# The Adam optimizer has no real momentum term, so for Adam "beta1" is
# changed instead.
MOMENTUM = CS.Float(
    name="momentum", bounds=(0.0, 1.0)
)  # NOTE: only used when "use_momentum" is true in the config
# Registration order: learning rate first, then momentum.
DEFAULT_CFG_SPACE.add(LR)
DEFAULT_CFG_SPACE.add(MOMENTUM)
def __default_loss_function(**kwargs):
    """Create the default loss module: an un-reduced negative log-likelihood.

    All keyword arguments are forwarded unchanged to ``nn.NLLLoss``. The
    ``reduction`` argument is fixed to ``"none"``, so supplying it again in
    ``kwargs`` raises a ``TypeError`` (duplicate keyword argument).
    """
    loss_module = nn.NLLLoss(reduction="none", **kwargs)
    return loss_module
# Human-readable metadata about this benchmark; stored under the
# "benchmark_info" key of SGD_DEFAULTS.
INFO = {
    "identifier": "LR",
    "name": "Learning Rate Adaption for Neural Networks",
    "reward": "Negative Log Differential Validation Loss",
    # One label per observation entry, in the same order as the keys of the
    # observation space dict in SGD_DEFAULTS
    # (step, loss, validationLoss, crashed).
    "state_description": [
        "Step",
        "Loss",
        "Validation Loss",
        "Crashed",
    ],
    # One label per hyperparameter in DEFAULT_CFG_SPACE, in registration order.
    "action_description": ["Learning Rate", "Momentum"],
}
# Default benchmark configuration. SGDBenchmark copies it when no config is
# supplied and uses it to fill in any keys missing from a user-supplied config.
SGD_DEFAULTS = objdict(
    {
        "config_space": DEFAULT_CFG_SPACE,
        # Observation space description: a "Dict" space built from the single
        # mapping in "observation_space_args".
        "observation_space_class": "Dict",
        "observation_space_type": None,
        "observation_space_args": [
            {
                "step": spaces.Box(low=0, high=np.inf, shape=(1,)),
                "loss": spaces.Box(0, np.inf, shape=(1,)),
                "validationLoss": spaces.Box(low=0, high=np.inf, shape=(1,)),
                # NOTE(review): Discrete(1) admits only the value 0; confirm
                # whether a two-valued crashed flag (Discrete(2)) was intended.
                "crashed": spaces.Discrete(1),
            }
        ],
        "reward_range": [-(10**9), (10**9)],
        "device": "cpu",
        "use_instance_generator": False,
        "cutoff": 1e2,
        # Factory returning the loss module; presumably invoked by the
        # environment with "loss_function_kwargs" - TODO confirm in SGDEnv.
        "loss_function": __default_loss_function,
        "loss_function_kwargs": {},
        # "reward_function":,    # Can be set, to replace the default function
        # "state_method":,       # Can be set, to replace the default function
        # See the note on MOMENTUM: the momentum hyperparameter is only used
        # when this flag is true.
        "use_momentum": False,
        "seed": 0,
        "crash_penalty": -100.0,
        "multi_agent": False,
        # CSV file describing the training instances; resolved against several
        # locations by SGDBenchmark.read_instance_set.
        "instance_set_path": "sgd_cifar10_variations_train.csv",
        "benchmark_info": INFO,
        "epoch_mode": True,
        # Forwarded to sgd_utils.get_model_constructor when instances are read.
        "local_model_path": False,
        "dataset_config": None,
    }
)
class SGDBenchmark(AbstractBenchmark):
    """Benchmark with default configuration & relevant functions for SGD."""

    def __init__(self, config_path=None, config=None):
        """Initialize SGD Benchmark.

        Parameters
        ----------
        config_path : str
            Path to config file (optional)
        config : objdict
            Configuration to use (optional); forwarded to the base class.
            Keys missing from it are filled in from ``SGD_DEFAULTS``.
        """
        super().__init__(config_path, config)
        if not self.config:
            self.config = objdict(SGD_DEFAULTS.copy())

        # Fill in any default the supplied configuration does not define.
        for key in SGD_DEFAULTS:
            if key not in self.config:
                self.config[key] = SGD_DEFAULTS[key]

    def get_environment(self):
        """Return SGDEnv env with current configuration.

        Reads the instance set (and the test set, if a ``test_set_path`` is
        configured) when they are not already present in the config, then
        applies every function in ``self.wrap_funcs`` to the environment.

        Returns
        -------
        SGDEnv
            SGD environment
        """
        if "instance_set" not in self.config:
            self.read_instance_set()

        # Read test set if path is specified
        if "test_set" not in self.config and "test_set_path" in self.config:
            self.read_instance_set(test=True)

        env = SGDEnv(self.config)
        for func in self.wrap_funcs:
            env = func(env)
        return env

    def read_instance_set(self, test=False):
        """Read the instance CSV named in the config into a dict of instances.

        The configured path is looked up, in this order, (1) as given (absolute
        or relative to the working directory), (2) relative to this module's
        directory and (3) inside ``../instance_sets/sgd`` relative to this
        module's directory. The first candidate that is an existing file wins.

        Each CSV row must provide the columns ``dataset``, ``model_type``,
        ``model_kwargs``, ``optimizer``, ``optimizer_params``, ``batch_size``,
        ``fraction_of_dataset``, ``train_validation_ratio`` and ``seed``. The
        resulting ``SGDInstance`` objects are stored, keyed by row index, under
        ``self.config["test_set"]`` or ``self.config["instance_set"]``.

        Parameters
        ----------
        test : bool
            If True, read ``test_set_path`` into ``test_set``; otherwise read
            ``instance_set_path`` into ``instance_set``.

        Raises
        ------
        FileNotFoundError
            If none of the candidate locations is an existing file.
        """
        if test:
            path_key, keyword = "test_set_path", "test_set"
        else:
            path_key, keyword = "instance_set_path", "instance_set"
        configured_path = getattr(self.config, path_key)

        module_dir = Path(__file__).resolve().parent
        absolute_path = Path(configured_path)
        relative_path = module_dir / configured_path
        dacbench_path = module_dir / "../instance_sets/sgd" / configured_path

        # Require a regular file for every candidate so that a directory of
        # the same name cannot be selected and then fail inside read_csv.
        candidates = (absolute_path, relative_path, dacbench_path)
        path = next((cand for cand in candidates if cand.is_file()), None)
        if path is None:
            raise FileNotFoundError(
                f"Instance set file not found at {absolute_path}, "
                f"{relative_path} or {dacbench_path}"
            )

        instances = {}
        for index, row in pd.read_csv(path).iterrows():
            # Only the part before the first underscore names the dataset;
            # a value without an underscore is used unchanged.
            dataset_name = row["dataset"].split("_")[0]

            model_constructor = sgd_utils.get_model_constructor(
                row["model_type"],
                row["model_kwargs"].split("-"),
                self.config.local_model_path,
            )

            # An empty CSV cell is read as NaN. pd.isna also accepts strings
            # and None, whereas math.isnan raises TypeError on non-numbers.
            optimizer_params = row["optimizer_params"]
            if pd.isna(optimizer_params):
                optimizer_params = {}

            instances[index] = SGDInstance(
                model=model_constructor,
                optimizer_type=row["optimizer"],
                optimizer_params=optimizer_params,
                dataset_path=module_dir,
                dataset_name=dataset_name,
                batch_size=row["batch_size"],
                fraction_of_dataset=row["fraction_of_dataset"],
                train_validation_ratio=row["train_validation_ratio"],
                seed=int(row["seed"]),
            )

        # Store the set only once it has been read completely, so a failed
        # read cannot leave a partial set that get_environment would reuse.
        self.config[keyword] = instances

    def get_benchmark(self, instance_set_path=None, seed=0):
        """Get benchmark from the LTO paper.

        Resets the configuration to ``SGD_DEFAULTS`` before building the
        environment, discarding any previously set configuration.

        Parameters
        ----------
        instance_set_path : str
            Path to the instance set CSV (optional); the default from
            ``SGD_DEFAULTS`` is used when omitted.
        seed : int
            Environment seed

        Returns
        -------
        env : SGDEnv
            SGD environment
        """
        self.config = objdict(SGD_DEFAULTS.copy())
        if instance_set_path is not None:
            self.config["instance_set_path"] = instance_set_path
        self.config.seed = seed
        self.read_instance_set()
        return SGDEnv(self.config)