Skip to content

Cosine LR schedule

mighty.mighty_meta.cosine_lr_schedule #

Cosine LR Schedule with optional warm restarts.

CosineLRSchedule #

CosineLRSchedule(
    initial_lr,
    num_decay_steps,
    min_lr=0,
    restart_every=10000,
    restart_multiplier=1.2,
)

Bases: MightyMetaComponent

Cosine LR Schedule with optional warm restarts.

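:param initial_lr: Initial (maximal) learning rate.
:param num_decay_steps: Length of the cosine decay in steps.
:param min_lr: Minimal learning rate.
:param restart_every: Restart period in steps; restarts are disabled when this is not positive.
:param restart_multiplier: Factor applied on each restart to the current learning rate's offset above min_lr, giving the new maximal learning rate.
:return: None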

Source code in mighty/mighty_meta/cosine_lr_schedule.py
def __init__(
    self,
    initial_lr: float,
    num_decay_steps: int,
    min_lr: float = 0,
    restart_every: int = 10000,
    restart_multiplier: float = 1.2,
) -> None:
    """Cosine schedule initialization.

    :param initial_lr: Initial (maximal) learning rate of the schedule.
    :param num_decay_steps: Length of the cosine decay in steps.
    :param min_lr: Minimal learning rate the cosine decays towards.
    :param restart_every: Restart period in steps; a value <= 0 disables
        restarts.
    :param restart_multiplier: Factor applied on each restart to the current
        learning rate's offset above ``min_lr``; the result becomes the new
        maximal learning rate.
    :return: None
    """
    super().__init__()
    self.restart_every = restart_every
    # Number of restarts performed so far; compared against the number of
    # elapsed ``restart_every`` periods to decide when to restart.
    self.n_restarts = 0
    self.t_mult = restart_multiplier
    # eta_max is overwritten on every restart, so it tracks the *current*
    # maximal LR rather than the initial one.
    self.eta_max = initial_lr
    self.t_max = num_decay_steps
    self.eta_min = min_lr
    # Register the LR update so it runs via pre_step() before each step.
    self.pre_step_methods = [self.adapt_lr]

adapt_lr #

adapt_lr(metrics)

Adapt LR on step.

:param metrics: Dict of current metrics; "step" is read and "hp/lr" may be written.
:return: None

Source code in mighty/mighty_meta/cosine_lr_schedule.py
def adapt_lr(self, metrics):
    """Write the cosine-scheduled learning rate for the current step.

    Reads ``metrics["step"]`` and, when a value is produced, stores it under
    ``metrics["hp/lr"]``.

    * Restart: if ``restart_every > 0`` and fewer restarts have been done
      than full ``restart_every`` periods have elapsed, one restart is
      performed. ``eta_max`` becomes the current cosine value, with its part
      above ``eta_min`` scaled by ``t_mult``, and that value is emitted.
    * Otherwise, while ``step < t_max``, the cosine interpolation between
      ``eta_min`` and ``eta_max`` is emitted. From ``t_max`` on, nothing is
      written.

    NOTE(review): the cosine phase is always ``step / t_max`` of the global
    step and is not reset at a restart. The LR after a restart therefore
    follows the original curve with the new amplitude. Confirm this is the
    intended "warm restart" behaviour.

    :param metrics: Dict of current metrics; must contain ``"step"``
    :return: None
    """
    step = metrics["step"]

    def cosine_term():
        # 1 + cos(pi * step / t_max); computed lazily so t_max is only used
        # as a divisor on the paths that actually need it.
        return 1 + np.cos((step / self.t_max) * np.pi)

    restart_due = self.restart_every > 0 and self.n_restarts < np.floor(
        step / self.restart_every
    )
    if restart_due:
        self.n_restarts += 1
        self.eta_max = (
            self.eta_min
            + 0.5 * (self.eta_max - self.eta_min) * cosine_term() * self.t_mult
        )
        metrics["hp/lr"] = self.eta_max
        return

    if step < self.t_max:
        metrics["hp/lr"] = (
            self.eta_min + 0.5 * (self.eta_max - self.eta_min) * cosine_term()
        )

post_episode #

post_episode(metrics)

Execute methods at the end of an episode.

:param metrics: Current metrics dict, passed to each registered hook.
:return: None

Source code in mighty/mighty_meta/mighty_component.py
def post_episode(self, metrics):
    """Run every registered end-of-episode hook.

    Each callable in ``self.post_episode_methods`` is called in list order
    with the same ``metrics`` object.

    :param metrics: Current metrics dict, forwarded to each hook
    :return: None
    """
    hooks = self.post_episode_methods
    for hook in hooks:
        hook(metrics)

post_step #

post_step(metrics)

Execute methods after a step.

:param metrics: Current metrics dict, passed to each registered hook.
:return: None

Source code in mighty/mighty_meta/mighty_component.py
def post_step(self, metrics):
    """Run every registered post-step hook.

    Each callable in ``self.post_step_methods`` is called in list order
    with the same ``metrics`` object.

    :param metrics: Current metrics dict, forwarded to each hook
    :return: None
    """
    hooks = self.post_step_methods
    for hook in hooks:
        hook(metrics)

post_update #

post_update(metrics)

Execute methods after the update.

:param metrics: Current metrics dict, passed to each registered hook.
:return: None

Source code in mighty/mighty_meta/mighty_component.py
def post_update(self, metrics):
    """Run every registered post-update hook.

    Each callable in ``self.post_update_methods`` is called in list order
    with the same ``metrics`` object.

    :param metrics: Current metrics dict, forwarded to each hook
    :return: None
    """
    hooks = self.post_update_methods
    for hook in hooks:
        hook(metrics)

pre_episode #

pre_episode(metrics)

Execute methods before an episode.

:param metrics: Current metrics dict, passed to each registered hook.
:return: None

Source code in mighty/mighty_meta/mighty_component.py
def pre_episode(self, metrics):
    """Run every registered start-of-episode hook.

    Each callable in ``self.pre_episode_methods`` is called in list order
    with the same ``metrics`` object.

    :param metrics: Current metrics dict, forwarded to each hook
    :return: None
    """
    hooks = self.pre_episode_methods
    for hook in hooks:
        hook(metrics)

pre_step #

pre_step(metrics)

Execute methods before a step.

:param metrics: Current metrics dict, passed to each registered hook.
:return: None

Source code in mighty/mighty_meta/mighty_component.py
def pre_step(self, metrics):
    """Run every registered pre-step hook.

    Each callable in ``self.pre_step_methods`` is called in list order
    with the same ``metrics`` object.

    :param metrics: Current metrics dict, forwarded to each hook
    :return: None
    """
    hooks = self.pre_step_methods
    for hook in hooks:
        hook(metrics)

pre_update #

pre_update(metrics)

Execute methods before the update.

:param metrics: Current metrics dict, passed to each registered hook.
:return: None

Source code in mighty/mighty_meta/mighty_component.py
def pre_update(self, metrics):
    """Run every registered pre-update hook.

    Each callable in ``self.pre_update_methods`` is called in list order
    with the same ``metrics`` object.

    :param metrics: Current metrics dict, forwarded to each hook
    :return: None
    """
    hooks = self.pre_update_methods
    for hook in hooks:
        hook(metrics)