# Source code for heat.optim.utils

"""
Utility functions for the heat optimizers
"""

import math
import torch

from typing import Optional, Dict


# Names exported by ``from <this module> import *``.
__all__ = [
    "DetectMetricPlateau",
]


class DetectMetricPlateau(object):
    r"""
    Determine when a metric has stopped improving.

    Each call to :meth:`test_if_improving` records one epoch's metric value. If no
    improvement is seen for more than `patience` consecutive epochs, a plateau is reported.
    After a plateau, an optional cooldown period ignores bad epochs.
    Adapted from `torch.optim.lr_scheduler.ReduceLROnPlateau
    <https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau>`_.

    Parameters
    ----------
    mode: str, optional
        One of `min`, `max`. In `min` mode, the monitored quantity is determined to have
        plateaued when it stops decreasing. In `max` mode, it is determined to have
        plateaued when it stops increasing.
        Default: 'min'.
    patience: int, optional
        Number of epochs with no improvement to tolerate before reporting a plateau.
        For example, if `patience = 2`, the first 2 epochs with no improvement are ignored
        and a plateau is reported after the 3rd consecutive epoch without improvement.
        Default: 10.
    threshold: float, optional
        Threshold for measuring the new optimum, so that only significant changes count.
        Default: 1e-4.
    threshold_mode: str, optional
        One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * (1 + threshold) in
        `max` mode or best * (1 - threshold) in `min` mode. For a negative `best` the two
        factors are swapped, so the bound always lies beyond `best` in the improving
        direction. In `abs` mode, dynamic_threshold = best + threshold in `max` mode or
        best - threshold in `min` mode.
        Default: 'rel'.
    cooldown: int, optional
        Number of epochs to wait before resuming normal operation after a plateau has been
        detected. Bad epochs during the cooldown are not counted.
        Default: 0.

    Raises
    ------
    ValueError
        If `mode` is not one of `min`, `max` or `threshold_mode` is not one of `rel`, `abs`.
    """

    # Attributes saved by `get_state` and restored by `set_state`, in this order.
    _STATE_KEYS = (
        "patience",
        "cooldown",
        "cooldown_counter",
        "mode",
        "threshold",
        "threshold_mode",
        "best",
        "num_bad_epochs",
        "mode_worse",
        "last_epoch",
    )

    def __init__(
        self,
        mode: Optional[str] = "min",
        patience: Optional[int] = 10,
        threshold: Optional[float] = 1e-4,
        threshold_mode: Optional[str] = "rel",
        cooldown: Optional[int] = 0,
    ):  # noqa: D107
        self.patience = patience
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.last_epoch = 0
        # validates `mode` / `threshold_mode` and sets `mode_worse`
        self._init_is_better(mode=mode, threshold=threshold, threshold_mode=threshold_mode)
        # start `best` at the worst possible value, so any finite first metric is an improvement
        self.reset()

    def get_state(self) -> Dict:
        """
        Get a dictionary of the class parameters. This is useful for checkpointing.

        Returns
        -------
        Dictionary with the detector's configuration and current counters, in the format
        accepted by :meth:`set_state`.
        """
        return {key: getattr(self, key) for key in self._STATE_KEYS}

    def set_state(self, dic: Dict) -> None:
        """
        Load a dictionary with the status of the class. Typically used in checkpointing.

        Parameters
        ----------
        dic: Dictionary
            contains the values to be set as the class parameters, e.g. the output of
            :meth:`get_state`

        Raises
        ------
        KeyError
            If one of the expected keys is missing from `dic`.
        """
        for key in self._STATE_KEYS:
            setattr(self, key, dic[key])

    def reset(self) -> None:
        """
        Reset the best value to the worst value for the chosen mode, and reset the
        bad-epoch and cooldown counters.
        """
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def test_if_improving(self, metrics: torch.Tensor) -> bool:
        """
        Record one epoch's metric value and test whether the metric has plateaued.

        If the metric is better than the threshold-adjusted best value, it becomes the new
        best and the bad-epoch counter is reset. Otherwise the counter is incremented.
        Bad epochs are ignored while in cooldown.

        Parameters
        ----------
        metrics: torch.Tensor
            the metric value to test; anything accepted by ``float()``, e.g. a zero-dim
            Tensor or a Python number

        Returns
        -------
        True if a plateau was detected in this epoch (more than `patience` consecutive
        epochs without improvement), False otherwise. When True is returned, the cooldown
        period starts and the bad-epoch counter is reset.
        """
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        self.last_epoch += 1

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
            return True
        return False

    @property
    def in_cooldown(self) -> bool:
        """
        Test if the class is in the cool down period.
        """
        return self.cooldown_counter > 0

    def is_better(self, a: float, best: float) -> bool:
        """
        Test if the given value is better than the current best value.
        The best value is adjusted with the threshold.

        Parameters
        ----------
        a: float
            the metric value
        best: float
            the current best value for the metric

        Returns
        -------
        boolean indicating if the metric is improving
        """
        if self.mode == "min" and self.threshold_mode == "rel":
            # for a negative `best`, (1 - threshold) would move the bound towards zero and
            # make it easier to beat, so scale away from zero instead
            comp = best * (1.0 - self.threshold) if best >= 0 else best * (1.0 + self.threshold)
            return a < comp
        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold
        elif self.mode == "max" and self.threshold_mode == "rel":
            # mirror of the `min` case: for a negative `best`, (1 + threshold) would lower the
            # bound below `best` and accept worse values as improvements
            comp = best * (1.0 + self.threshold) if best >= 0 else best * (1.0 - self.threshold)
            return a > comp
        else:  # mode == 'max' and threshold_mode == 'abs'
            return a > best + self.threshold

    def _init_is_better(self, mode: str, threshold: float, threshold_mode: str) -> None:
        """
        Validate the comparison settings used by :meth:`is_better` and set the worst
        possible value for the chosen mode.

        Parameters
        ----------
        mode: str
            `min` or `max`
        threshold: float
            threshold used to adjust the best value
        threshold_mode: str
            `rel` or `abs`

        Raises
        ------
        ValueError
            If `mode` or `threshold_mode` is not one of the accepted values.
        """
        # f-strings so that non-str inputs (e.g. None) still produce a ValueError
        if mode not in {"min", "max"}:
            raise ValueError(f"mode {mode} is unknown!")
        if threshold_mode not in {"rel", "abs"}:
            raise ValueError(f"threshold mode {threshold_mode} is unknown!")
        # the worst possible value for this mode; `reset` uses it as the starting `best`
        self.mode_worse = math.inf if mode == "min" else -math.inf
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode