"""SGD environment."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import torch
from dacbench import AbstractMADACEnv
from dacbench.envs.env_utils import sgd_utils
from dacbench.envs.env_utils.sgd_utils import random_torchvision_loader
def test(
    model,
    loss_function,
    loader,
    batch_size,
    batch_percentage: float = 1.0,
    device="cpu",
):
    """Evaluate the given `model` using `loss_function`.

    `batch_percentage` defines the fraction of the data to evaluate on;
    by default the whole dataset is used. The loss function is expected to
    use `reduction="none"` so that per-sample losses can be concatenated.

    Returns:
        test_losses: Validation loss per data point
        test_accuracies: Accuracy per evaluated batch
    """
    # Number of batches to evaluate
    nmb_sets = batch_percentage * (len(loader.dataset) / batch_size)
    model.eval()
    test_losses = []
    test_accuracies = []
    i = 0

    with torch.no_grad():
        for data, target in loader:
            d_data, d_target = data.to(device), target.to(device)
            output = model(d_data)
            _, preds = output.max(dim=1)
            test_losses.append(loss_function(output, d_target))
            # Compare predictions against the targets on the same device
            test_accuracies.append(torch.sum(preds == d_target) / len(target))
            i += 1
            if i >= nmb_sets:
                break
    return torch.cat(test_losses), torch.tensor(test_accuracies)
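
# A minimal usage sketch for `test`; the model and loader names are
# hypothetical. Note the reduction="none", since per-sample losses are
# concatenated via `torch.cat`:
#
#     losses, accuracies = test(
#         my_model,
#         torch.nn.CrossEntropyLoss(reduction="none"),
#         validation_loader,
#         batch_size=64,
#         batch_percentage=0.1,  # evaluate on roughly 10% of the data
#     )
#     print(losses.mean().item(), accuracies.mean().item())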
def forward_backward(model, loss_function, loader, device="cpu"):
    """Do one forward and one backward pass of `model` for `loss_function`.

    Gradients are accumulated but no optimizer step is taken here; the
    environment applies the update with the next action.

    Returns:
        loss: Mean mini-batch training loss
    """
    model.train()
    # Draws a batch from a fresh iterator (relies on the loader shuffling)
    (data, target) = next(iter(loader))
    data, target = data.to(device), target.to(device)
    output = model(data)
    loss = loss_function(output, target)
    loss.mean().backward()
    return loss.mean().detach()
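
# Usage sketch for `forward_backward` (hypothetical names): gradients are
# left on the model and applied later by the optimizer in `SGDEnv.step`:
#
#     loss = forward_backward(my_model, loss_fn, train_loader, device="cpu")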
def run_epoch(model, loss_function, loader, optimizer, device="cpu"):
    """Run a single epoch of training for the given `model` with `loss_function`."""
    last_loss = None
    running_loss = 0
    model.train()
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, target)
        loss.mean().backward()
        optimizer.step()
        last_loss = loss
        running_loss += last_loss.mean().item()
    return last_loss.mean().detach(), running_loss / len(loader)
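
# Usage sketch for `run_epoch` (hypothetical names): one full pass over the
# loader, returning the mean loss of the last batch and the epoch average:
#
#     last_loss, avg_loss = run_epoch(my_model, loss_fn, train_loader, opt)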
@dataclass
class SGDInstance:
    """SGD Instance."""

    model: torch.nn.Module
    optimizer_type: torch.optim.Optimizer
    optimizer_params: dict
    dataset_path: str
    dataset_name: str
    batch_size: int
    fraction_of_dataset: float
    train_validation_ratio: float
    seed: int
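
# Example of a hand-built instance; all values here are illustrative only.
# `reset` calls `self.instance.model()`, so `model` is assumed to be a
# callable that constructs a torch.nn.Module:
#
#     instance = SGDInstance(
#         model=my_model_constructor,
#         optimizer_type=torch.optim.AdamW,
#         optimizer_params={"lr": 1e-3},
#         dataset_path="data",
#         dataset_name="MNIST",
#         batch_size=64,
#         fraction_of_dataset=1.0,
#         train_validation_ratio=0.8,
#         seed=0,
#     )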
class SGDEnv(AbstractMADACEnv):
    """The SGD DAC environment implements the problem of dynamically configuring
    the learning rate hyperparameter of a neural network optimizer
    (more specifically, torch.optim.AdamW) for a supervised learning task.

    While training, the model is evaluated after every epoch.
    Actions correspond to learning-rate values in [0, +inf).
    For the observation space, see the `observation_space` method docstring.
    For the instance space, see the `SGDInstance` class docstring.

    Reward:
        negative loss of the model on the instance's test_loader if done
        crash_penalty of the instance if crashed
        0 otherwise
    """
    metadata = {"render_modes": ["human"]}  # noqa: RUF012

    def __init__(self, config):
        """Init env."""
        super().__init__(config)
        torch.manual_seed(seed=config.get("seed", 0))
        self.epoch_mode = config.get("epoch_mode", True)
        self.device = config.get("device")

        self.learning_rate = None
        self.crash_penalty = config.get("crash_penalty")
        self.loss_function = config.loss_function(**config.loss_function_kwargs)
        self.use_generator = config.get("use_instance_generator")

        # Use the default reward function if no specific function is given
        self.get_reward = config.get("reward_function", self.get_default_reward)

        # Use the default state function if no specific function is given
        self.get_state = config.get("state_method", self.get_default_state)
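
    # Both hooks can be overridden through the config. Custom functions
    # receive the environment itself, since the defaults are called as
    # `self.get_reward(self)`. A hypothetical sketch:
    #
    #     config["reward_function"] = lambda env: -env.validation_loss
    #     env = SGDEnv(config)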
    def step(self, action: float):
        """Update the parameters of the neural network using the given learning rate,
        in the direction specified by AdamW, and if not done (crashed/cutoff reached),
        perform another forward/backward pass (the update happens only in the next step).
        """
        truncated = super().step_()
        info = {}
        if isinstance(action, float):
            action = [action]

        # Write the action into the optimizer's hyperparameters
        for g in self.optimizer.param_groups:
            g["lr"] = action[0]
            if len(action) > 1:
                g["betas"] = (action[1], 0.999)

        if self.epoch_mode:
            self.loss, self.average_loss = run_epoch(
                self.model,
                self.loss_function,
                self.train_loader,
                self.optimizer,
                self.device,
            )
        else:
            train_args = [
                self.model,
                self.loss_function,
                self.train_loader,
                self.device,
            ]
            # Apply the previously accumulated gradients with the new
            # hyperparameters, then compute fresh gradients for the next update
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.loss = forward_backward(*train_args)

        crashed = (
            not torch.isfinite(self.loss).any()
            or not torch.isfinite(
                torch.nn.utils.parameters_to_vector(self.model.parameters())
            ).any()
        )
        # Move to CPU before converting so this also works on CUDA tensors
        self.loss = self.loss.cpu().numpy().item()

        if crashed:
            self._done = True
            return (
                self.get_state(self),
                self.crash_penalty,
                False,
                True,
                info,
            )

        self._done = truncated

        if (
            self.n_steps % len(self.train_loader) == 0 or self._done
        ):  # Calculate validation loss at the end of an epoch
            batch_percentage = 1.0
        else:
            batch_percentage = 0.1

        val_args = [
            self.model,
            self.loss_function,
            self.validation_loader,
            self.instance.batch_size,
            batch_percentage,
            self.device,
        ]
        validation_loss, validation_accuracy = test(*val_args)

        self.validation_loss = validation_loss.mean().detach().cpu().numpy()
        self.validation_accuracy = validation_accuracy.mean().detach().cpu().numpy()
        if (
            self.min_validation_loss is None
            or self.validation_loss <= self.min_validation_loss
        ):
            self.min_validation_loss = self.validation_loss

        if self._done:
            val_args = [
                self.model,
                self.loss_function,
                self.test_loader,
                self.instance.batch_size,
                1.0,
                self.device,
            ]
            self.test_losses, self.test_accuracies = test(*val_args)

        reward = self.get_reward(self)
        return self.get_state(self), reward, False, truncated, info
    def reset(self, seed=None, options=None):
        """Initialize the neural network, data loaders, etc. for the given/random next task.

        Also performs a single forward/backward pass,
        without yet updating the neural network parameters.
        """
        if options is None:
            options = {}
        super().reset_(seed)

        rng = np.random.RandomState(self.initial_seed)

        # Get loaders for the instance
        self.datasets, loaders = random_torchvision_loader(
            self.instance.seed,
            self.instance.dataset_path,
            self.instance.dataset_name,
            self.instance.batch_size,
            self.instance.fraction_of_dataset,
            self.instance.train_validation_ratio,
            dataset_config=self.config.dataset_config,
        )
        self.train_loader, self.validation_loader, self.test_loader = loaders

        if self.use_generator:
            (
                self.model,
                self.optimizer_params,
                self.instance.batch_size,
                self.crash_penalty,
            ) = sgd_utils.random_instance(rng, self.datasets)
        else:
            self.model = self.instance.model()
            self.optimizer_params = self.instance.optimizer_params

        self.learning_rate = None
        self.optimizer_type = self.instance.optimizer_type
        self.info = {}
        self._done = False

        self.model.to(self.device)
        self.optimizer: torch.optim.Optimizer = torch.optim.AdamW(
            **self.instance.optimizer_params, params=self.model.parameters()
        )
        self.loss = 0
        self.test_losses = None
        self.validation_loss = 0
        self.validation_accuracy = 0
        self.min_validation_loss = None

        if self.epoch_mode:
            self.average_loss = 0

        return self.get_state(self), {}
    def get_default_reward(self, _) -> float:
        """The default reward function: the negative mean test loss,
        or 0.0 if no test losses have been computed yet.

        Args:
            _ (_type_): Empty parameter, which can be used when overriding

        Returns:
            float: The calculated reward
        """
        if self.test_losses is not None:
            reward = self.test_losses.sum().item() / len(self.test_loader.dataset)
        else:
            reward = 0.0
        return -reward
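
    # Worked example: if the per-sample test losses sum to 230.0 over a test
    # set of 1000 points, the returned reward is -230.0 / 1000 = -0.23.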
    def get_default_state(self, _) -> dict:
        """Default state function.

        Args:
            _ (_type_): Empty parameter, which can be used when overriding

        Returns:
            dict: The current state
        """
        state = {
            "step": self.c_step,
            "loss": self.loss,
            "validation_loss": self.validation_loss,
            "validation_accuracy": self.validation_accuracy,
            "done": self._done,
        }
        if self.epoch_mode:
            state["average_loss"] = self.average_loss
        if self._done and self.test_losses is not None:
            state["test_losses"] = self.test_losses
            state["test_accuracies"] = self.test_accuracies
        return state
    def render(self, mode="human"):
        """Render progress."""
        if mode == "human":
            epoch = 1 + self.c_step // len(self.train_loader)
            epoch_cutoff = self.n_steps // len(self.train_loader)
            batch = 1 + self.c_step % len(self.train_loader)
            print(
                f"prev_lr {self.optimizer.param_groups[0]['lr'] if self.n_steps > 0 else None}, "  # noqa: E501
                f"epoch {epoch}/{epoch_cutoff}, "
                f"batch {batch}/{len(self.train_loader)}, "
                f"batch_loss {self.loss}, "
                f"val_loss {self.validation_loss}"
            )
        else:
            raise NotImplementedError
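

# A minimal end-to-end sketch, assuming DACBench's SGDBenchmark provides a
# default configuration for this environment: run one episode with a constant
# learning rate of 1e-3 and report the final reward.
if __name__ == "__main__":
    from dacbench.benchmarks import SGDBenchmark

    env = SGDBenchmark().get_environment()
    state, info = env.reset()
    terminated = truncated = False
    while not (terminated or truncated):
        state, reward, terminated, truncated, info = env.step(1e-3)
    print("final reward:", reward)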