:mod:`heat.optim.utils`
=======================

.. py:module:: heat.optim.utils

.. autoapi-nested-parse::

   Utility functions for the heat optimizers.



Module Contents
---------------

.. py:class:: DetectMetricPlateau(mode: Optional[str] = 'min', patience: Optional[int] = 10, threshold: Optional[float] = 0.0001, threshold_mode: Optional[str] = 'rel', cooldown: Optional[int] = 0)

   Bases: :class:`object`

   Determine when a metric has stopped improving.

   This class reads a metric quantity. If no improvement is seen for a `patience` number of epochs, it reports a plateau, which can be used to trigger an action such as reducing the learning rate.
   Adapted from `torch.optim.lr_scheduler.ReduceLROnPlateau `_.

   :param mode: One of `min`, `max`. In `min` mode, the monitored quantity is determined to have plateaued when it stops decreasing. In `max` mode, it is determined to have plateaued when it stops increasing. Default: 'min'.
   :type mode: str, optional
   :param patience: Number of epochs with no improvement to wait before determining that there is a plateau. For example, if `patience = 2`, the first 2 epochs with no improvement are ignored, and a plateau is only reported after the 3rd such epoch if the metric still has not improved by then. Default: 10.
   :type patience: int, optional
   :param threshold: Threshold for measuring the new optimum, so that only significant changes count as improvements. Default: 1e-4.
   :type threshold: float, optional
   :param threshold_mode: One of `rel`, `abs`. In `rel` mode, ``dynamic_threshold = best * (1 + threshold)`` in `max` mode or ``best * (1 - threshold)`` in `min` mode. In `abs` mode, ``dynamic_threshold = best + threshold`` in `max` mode or ``best - threshold`` in `min` mode. Default: 'rel'.
   :type threshold_mode: str, optional
   :param cooldown: Number of epochs to wait before resuming normal operation after a plateau has been detected (for example, after the learning rate has been reduced). Default: 0.
   :type cooldown: int, optional

   .. attribute:: patience
      :annotation: = 10

   .. attribute:: cooldown
      :annotation: = 0

   .. attribute:: cooldown_counter
      :annotation: = 0

   .. attribute:: mode
      :annotation: = 'min'

   .. attribute:: threshold
      :annotation: = 0.0001

   .. attribute:: threshold_mode
      :annotation: = 'rel'

   .. attribute:: best
      :annotation: = None

   .. attribute:: num_bad_epochs
      :annotation: = None

   .. attribute:: mode_worse
      :annotation: = None

   .. attribute:: last_epoch
      :annotation: = 0

   .. role:: raw-html(raw)
      :format: html

   .. method:: get_state() -> Dict

      Get a dictionary of the class parameters. This is useful for checkpointing.


   .. method:: set_state(dic: Dict) -> None

      Load the status of the class from a dictionary. Typically used when restoring from a checkpoint.

      :param dic: contains the values to be set as the class parameters
      :type dic: Dictionary


   .. method:: reset() -> None

      Reset the `num_bad_epochs` counter and the cooldown counter.


   .. method:: test_if_improving(metrics: torch.Tensor) -> bool

      Test whether the metric(s) are improving. If the metrics are better than the threshold-adjusted best value, they are stored as the best value for future tests.

      :param metrics: the metrics to test
      :type metrics: torch.Tensor
      :returns: True if the metrics are better than the best value, False otherwise
      :rtype: bool


   .. method:: is_better(a: float, best: float) -> bool

      Test whether the given value is better than the current best value. The best value is adjusted with the threshold.

      :param a: the metric value
      :type a: float
      :param best: the current best value for the metric
      :type best: float
      :returns: whether the metric is improving
      :rtype: bool


   .. method:: _init_is_better(mode: str, threshold: float, threshold_mode: str) -> None

      Initialize the `is_better` comparison used by later tests.
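The threshold formulas and the `patience`/`cooldown` rules documented for :class:`DetectMetricPlateau` can be illustrated with a small, dependency-free sketch. It is not heat's implementation. The class name ``PlateauSketch``, its ``step`` method, and the convention that ``step`` returns ``True`` once a plateau is detected are all assumptions made for illustration. The comparison and counter logic mirrors the documented behaviour of PyTorch's ``ReduceLROnPlateau``, from which the heat class is adapted.

```python
import math


class PlateauSketch:
    """Hypothetical stand-alone sketch of plateau detection (not heat's code).

    The class name, the ``step`` method and its return value are assumptions
    made for illustration only.
    """

    def __init__(self, mode: str = "min", patience: int = 10, threshold: float = 1e-4,
                 threshold_mode: str = "rel", cooldown: int = 0) -> None:
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        if threshold_mode not in ("rel", "abs"):
            raise ValueError(f"threshold_mode must be 'rel' or 'abs', got {threshold_mode!r}")
        self.mode = mode
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.num_bad_epochs = 0
        # Start from the worst possible value so the first metric always counts as better.
        self.best = math.inf if mode == "min" else -math.inf

    def is_better(self, a: float, best: float) -> bool:
        # Compare `a` against the dynamic threshold derived from `best`.
        if self.mode == "min" and self.threshold_mode == "rel":
            return a < best * (1.0 - self.threshold)
        if self.mode == "min":
            return a < best - self.threshold
        if self.threshold_mode == "rel":
            return a > best * (1.0 + self.threshold)
        return a > best + self.threshold

    def step(self, metric: float) -> bool:
        """Record one epoch's metric; return True if a plateau is detected."""
        current = float(metric)
        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.cooldown_counter > 0:
            # Epochs without improvement during the cooldown are ignored.
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0
        if self.num_bad_epochs > self.patience:
            # Plateau found: start the cooldown and reset the bad-epoch counter.
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
            return True
        return False


if __name__ == "__main__":
    detector = PlateauSketch(mode="min", patience=2)
    losses = [1.0, 0.5, 0.5, 0.5, 0.5]
    print([detector.step(loss) for loss in losses])
    # [False, False, False, False, True]
```

With ``patience=2``, the two epochs after the improvement to 0.5 are tolerated, and the plateau is only reported on the third consecutive epoch without improvement. This matches the `patience` description above. The relative threshold multiplies `best`, so with ``threshold_mode='rel'`` a metric of 0.5 has to fall below 0.49995 to count as an improvement.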