"""Environment for sgd with toy functions."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
from dacbench import AbstractMADACEnv
if TYPE_CHECKING:
from dacbench.envs.env_utils.toy_functions import AbstractFunction
@dataclass
class ToySGDInstance:
"""Toy SGD Instance."""
function: AbstractFunction
class ToySGDEnv(AbstractMADACEnv):
"""Optimize toy functions with SGD + Momentum.
Action: [log_learning_rate, log_momentum] (log base 10)
State: Dict with entries remaining_budget, gradient, learning_rate, momentum
Reward: negative log regret of current and true function value
An instance can look as follows:
ID 0
family polynomial
order 2
low -2
high 2
coefficients [ 1.40501053 -0.59899755 1.43337392]
"""
def __init__(self, config):
"""Init env."""
super().__init__(config)
if config["batch_size"]:
self.batch_size = config["batch_size"]
self.velocity = 0
self.gradient = np.zeros(self.batch_size)
self.history = []
self.n_dim = None
self.objective_function = None
self.x_cur = None
self.f_cur = None
self.momentum = 0
self.learning_rate = None
self.rng = np.random.default_rng(self.initial_seed)
self.get_reward = config.get("reward_function", self.get_default_reward)
self.get_state = config.get("state_method", self.get_default_state)
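        # Both hooks are called with the env itself (see `self.get_reward(self)`
        # in `step`), so a custom reward can be supplied via the config,
        # e.g. (illustrative only):
        #   config["reward_function"] = lambda env: -float(np.mean(env.f_cur))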
    def step(
        self, action: float | tuple[float, float]
    ) -> tuple[dict[str, float], float, bool, bool, dict]:
"""Take one step with SGD.
Parameters
----------
action: Tuple[float, Tuple[float, float]]
If scalar, action = (log_learning_rate)
If tuple, action = (log_learning_rate, log_momentum)
Returns:
-------
Tuple[Dict[str, float], float, bool, Dict]
- state : Dict[str, float]
State with entries:
"remaining_budget", "gradient", "learning_rate", "momentum"
- reward : float
- terminated : bool
- truncated : bool
- info : Dict
"""
truncated = super().step_()
info = {}
        # parse action; entries are given on a log10 scale. A scalar action
        # only sets the learning rate and leaves the momentum unchanged.
        if np.isscalar(action):
            log_learning_rate = action
        elif len(action) == 2:
            log_learning_rate, log_momentum = action
            self.momentum = 10**log_momentum
        else:
            raise ValueError(
                "Expected a scalar (log_learning_rate) or a pair "
                f"(log_learning_rate, log_momentum), got {action!r}"
            )
self.learning_rate = 10**log_learning_rate
# SGD + Momentum update
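        # Heavy-ball update: v_{t+1} = momentum * v_t + lr * grad f(x_t),
        # then x_{t+1} = x_t - v_{t+1}. The gradient used here was computed
        # at the end of the previous step (and is zero right after reset).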
self.velocity = (
self.momentum * self.velocity + self.learning_rate * self.gradient
)
self.x_cur -= self.velocity
self.gradient = self.objective_function.deriv(self.x_cur)
# current function value
self.f_cur = self.objective_function(self.x_cur)
self.history.append(self.x_cur)
        # the env never terminates on its own; it only truncates at the cutoff
        return self.get_state(self), self.get_reward(self), False, truncated, info
def reset(self, seed=None, options=None):
"""Reset environment.
Parameters
----------
seed : int
seed
options : dict
options dict (not used)
Returns:
-------
np.array
Environment state
dict
Meta-info
"""
if options is None:
options = {}
super().reset_(seed)
self.velocity = 0
self.gradient = np.zeros(self.batch_size)
self.history = []
self.objective_function = self.instance.function
self.x_cur = self.rng.uniform(-5, 5, size=self.batch_size)
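        # each episode starts from a uniform random point in [-5, 5] per batch entry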
self.f_cur = self.objective_function(self.x_cur)
self.momentum = 0
self.learning_rate = 0
return self.get_state(self), {}
def get_default_reward(self, _):
"""Default reward: negative log regret."""
log_regret = np.log10(np.abs(self.objective_function.fmin - self.f_cur))
return -np.mean(log_regret)
def get_default_state(self, _):
"""Default state: remaining_budget, gradient, learning_rate, momentum."""
# TODO: add instance description?
remaining_budget = self.n_steps - self.c_step
return {
"remaining_budget": remaining_budget,
"gradient": self.gradient,
"learning_rate": self.learning_rate,
"momentum": self.momentum,
}
def render(self, **kwargs):
"""Render progress."""
        # local import so matplotlib stays an optional dependency
        import matplotlib.pyplot as plt
history = np.array(self.history).flatten()
X = np.linspace(1.05 * np.amin(history), 1.05 * np.amax(history), 100)
Y = self.objective_function(X)
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(X, Y, label="True")
ax.plot(
history,
self.objective_function(history),
marker="x",
color="black",
label="Observed",
)
ax.plot(
self.x_cur,
self.objective_function(self.x_cur),
marker="x",
color="red",
label="Current Optimum",
)
ax.legend()
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("instance: " + str(self.instance["coefficients"]))
plt.show()
def close(self):
"""Close env."""