The SGD Deep Learning Benchmark

Task: control the learning rate in deep learning
Cost: log differential validation loss
Number of hyperparameters to control: one or two floats
State Information: predictive change variance, predictive change variance uncertainty, loss variance, loss variance uncertainty, current learning rate, training loss, validation loss, step, alignment, crashed
Noise Level: fairly large
Instance space: dataset, network architecture, optimizer

Built on top of PyTorch, this benchmark allows for dynamic learning rate control in deep learning. At each step until the cutoff, i.e. after each epoch, the DAC controller provides a new learning rate value to the network. Success is measured by decreasing validation loss.
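The controller only has to map what it observes to one learning rate per step (epoch). As a hedged illustration, the sketch below uses a hypothetical hand-crafted controller, PlateauHalvingController (not part of DACBench), that halves the learning rate whenever the validation loss fails to improve. It is plain Python and does not use the benchmark's API, so validation losses are fed in directly rather than read from the environment state.

```python
from typing import List


class PlateauHalvingController:
    """Hypothetical DAC controller: halves the learning rate whenever the
    validation loss does not improve on the best value seen so far."""

    def __init__(self, initial_lr: float = 1e-3, factor: float = 0.5, min_lr: float = 1e-6) -> None:
        self.lr = initial_lr
        self.factor = factor
        self.min_lr = min_lr
        self.best_loss = float("inf")

    def act(self, validation_loss: float) -> float:
        # Called once per step (i.e. after each epoch) with the latest validation loss.
        if validation_loss < self.best_loss:
            self.best_loss = validation_loss
        else:
            # No improvement: decay the learning rate, but never below min_lr.
            self.lr = max(self.lr * self.factor, self.min_lr)
        return self.lr


def run(controller: PlateauHalvingController, losses: List[float]) -> List[float]:
    # Feed a fixed sequence of validation losses and record the chosen learning rates.
    return [controller.act(loss) for loss in losses]
```

For example, run(PlateauHalvingController(initial_lr=0.1), [1.0, 0.8, 0.9, 0.7, 0.75]) yields 0.1, 0.1, 0.05, 0.05, 0.025. A learned DAC policy would replace act() while keeping the same one-learning-rate-per-epoch interface.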

This is a very flexible benchmark, as in principle all kinds of classification datasets and PyTorch compatible architectures can be included in training. The underlying task is not easy, however, so we recommend starting with small networks and datasets and building up to harder tasks.

Benchmark for SGD.

class dacbench.benchmarks.sgd_benchmark.SGDBenchmark(config_path=None, config=None)

Bases: AbstractBenchmark

Benchmark with default configuration & relevant functions for SGD.

get_benchmark(instance_set_path=None, seed=0)

Get benchmark from the LTO paper.

Parameters:
  • seed (int) – Environment seed

Returns:

env (SGDEnv) – SGD environment

get_environment()

Return SGDEnv env with current configuration.

Returns:

SGDEnv

SGD environment

read_instance_set(test=False)

Read path of instances from config into list.

SGD environment.

class dacbench.envs.sgd.SGDEnv(config)

Bases: AbstractMADACEnv

The SGD DAC Environment implements the problem of dynamically configuring the learning rate hyperparameter of a neural network optimizer (more specifically, torch.optim.AdamW) for a supervised learning task. While training, the model is evaluated after every epoch.

Actions correspond to learning rate values in [0, +∞). For the observation space, see the observation_space method docstring; for the instance space, see the SGDInstance class docstring.

Reward:

  • the negative loss of the model on the instance's test_loader, if done
  • the instance's crash_penalty, if crashed
  • 0 otherwise
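The reward rule and the action range above can be restated as two small functions. This is a minimal sketch, not DACBench's get_default_reward: default_reward and is_valid_action are hypothetical names, and checking crashed before done is an assumption, since the docstring does not say which case takes precedence when both hold.

```python
import math


def default_reward(done: bool, crashed: bool, test_loss: float, crash_penalty: float) -> float:
    """Hypothetical re-statement of the reward rule listed above."""
    # Assumption: the crash check comes first, since a crashed run also ends the episode.
    if crashed:
        return crash_penalty
    if done:
        # Lower test loss gives a higher (less negative) reward.
        return -test_loss
    # No reward at intermediate steps.
    return 0.0


def is_valid_action(lr: float) -> bool:
    # Actions are learning rates in [0, +inf): finite and non-negative.
    return math.isfinite(lr) and lr >= 0.0
```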

get_default_reward(_) → float

The default reward function.

Parameters:

_ (_type_) – Empty parameter, which can be used when overriding

Returns:

The calculated reward

Return type:

float

get_default_state(_) → dict

Default state function.

Parameters:

_ (_type_) – Empty parameter, which can be used when overriding

Returns:

The current state

Return type:

dict

render(mode='human')

Render progress.

reset(seed=None, options=None)

Initialize the neural network, data loaders, etc. for the given (or a random) next task. Also perform a single forward/backward pass without yet updating the neural network parameters.

step(action: float)

Update the parameters of the neural network using the given learning rate (the action), in the direction specified by AdamW. If the episode is not done (i.e. it has neither crashed nor reached the cutoff), also perform another forward/backward pass, whose gradients are applied only in the next step.
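The timing of reset and step is easy to misread: the gradient applied in a step is the one computed at the end of the previous reset or step. The hypothetical ToyDelayedUpdateEnv below reproduces only that timing on a one-weight model with loss (w - 3)^2. Plain gradient descent stands in for AdamW, and treating a non-finite weight as a crash is an assumption, not necessarily how SGDEnv detects crashes.

```python
import math
from typing import Tuple


class ToyDelayedUpdateEnv:
    """Hypothetical toy environment with the reset/step timing described above.

    The "network" is a single weight w with loss (w - 3)^2, and plain gradient
    descent stands in for AdamW.
    """

    def __init__(self, cutoff: int = 5, w0: float = 0.0) -> None:
        self.cutoff = cutoff
        self.w0 = w0

    def _forward_backward(self) -> None:
        # Forward pass (loss) and backward pass (gradient); w itself is not changed.
        diff = self.w - 3.0
        self.loss = diff * diff
        self.grad = 2.0 * diff

    def reset(self) -> float:
        self.w = self.w0
        self.t = 0
        self.crashed = False
        # The gradient is computed now but only applied in the first step().
        self._forward_backward()
        return self.w

    def step(self, action: float) -> Tuple[float, bool]:
        # Apply the gradient from the previous reset()/step() with learning rate `action`.
        self.w -= action * self.grad
        self.t += 1
        # Assumption: a non-finite weight counts as a crash.
        self.crashed = not math.isfinite(self.w)
        done = self.crashed or self.t >= self.cutoff
        if not done:
            # Prepare the next gradient; it is applied only in the next step().
            self._forward_backward()
        return self.w, done
```

After reset(), w is still 0.0 but the gradient -6.0 is already stored; step(0.5) then moves w to 3.0 and computes the (zero) gradient for the following step.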

class dacbench.envs.sgd.SGDInstance(model: Module, optimizer_type: Optimizer, optimizer_params: dict, dataset_path: str, dataset_name: str, batch_size: int, fraction_of_dataset: float, train_validation_ratio: float, seed: int)

Bases: object

SGD Instance.

dacbench.envs.sgd.forward_backward(model, loss_function, loader, device='cpu')

Perform a forward and a backward pass of the given model with loss_function.

Returns:

loss – Mini-batch training loss per data point

dacbench.envs.sgd.run_epoch(model, loss_function, loader, optimizer, device='cpu')

Run a single epoch of training for the given model with loss_function.
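To make the division of labor between forward_backward and run_epoch concrete, here is a pure-Python sketch on the model y = w * x with squared-error loss instead of a PyTorch module. The names toy_forward_backward and toy_run_epoch, the loader format (a list of (inputs, targets) pairs) and the plain SGD update are illustrative assumptions; the real helpers operate on torch models, loaders and optimizers.

```python
from typing import List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]  # (inputs, targets)


def toy_forward_backward(w: float, batch: Batch) -> Tuple[float, float]:
    """Forward and backward pass of y = w * x with squared-error loss.

    Returns the mini-batch loss averaged per data point and its gradient w.r.t. w.
    """
    xs, ys = batch
    n = len(xs)
    # Forward pass: mean squared error over the mini-batch.
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    # Backward pass: derivative of the mean squared error with respect to w.
    grad = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return loss, grad


def toy_run_epoch(w: float, loader: List[Batch], lr: float) -> Tuple[float, List[float]]:
    """One epoch: a forward/backward pass and a parameter update per mini-batch."""
    losses = []
    for batch in loader:
        loss, grad = toy_forward_backward(w, batch)
        w -= lr * grad  # plain SGD update stands in for the real optimizer
        losses.append(loss)
    return w, losses
```

With loader = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])], w = 0.0 and lr = 0.1, the per-batch losses are 10.0 and 9.0 and the weight ends near 2.8.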

dacbench.envs.sgd.test(model, loss_function, loader, batch_size, batch_percentage: float = 1.0, device='cpu')

Evaluate the given model with loss_function.

batch_percentage defines what share of the data is used for evaluation. If it is not given, the whole dataset is used (default 1.0).

Returns:

test_losses – Batch validation loss per data point
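The batch_percentage argument of test can be read as selecting a leading share of the evaluation batches. The hypothetical toy_test below sketches that reading in pure Python on the model y = w * x. Rounding up with math.ceil (so at least one batch is always evaluated) and applying the share at batch granularity are assumptions, and the real function's batch_size argument is not modeled.

```python
import math
from typing import List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]  # (inputs, targets)


def toy_test(w: float, loader: List[Batch], batch_percentage: float = 1.0) -> List[float]:
    """Evaluate y = w * x on the first share of batches given by batch_percentage."""
    if not 0.0 < batch_percentage <= 1.0:
        raise ValueError("batch_percentage must be in (0, 1]")
    # Assumption: round up so that at least one batch is always evaluated.
    n_batches = math.ceil(batch_percentage * len(loader))
    losses = []
    for xs, ys in loader[:n_batches]:
        # Mean squared error per data point in this batch; no gradients are needed.
        losses.append(sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs))
    return losses
```

With four batches, batch_percentage=0.5 evaluates the first two and batch_percentage=0.1 still evaluates one.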