Source code for heat.optim

"""
Optimizer module.

It contains data parallel specific optimizers and learning rate schedulers. It also includes all of the
optimizers and learning rate schedulers in the torch namespace
"""

from .lr_scheduler import *
from . import utils

import sys
import torch
import unittest


if sys.version_info.minor >= 7:
    from .dp_optimizer import *

[docs] def __getattr__(name): """ When a function is called for the heat.optim module it will attempt to run the heat optimizer with that name, then, if there is no such heat optimizer, it will attempt to get the torch optimizer of that name. """ # this will call the Heat optimizers if available, # otherwise, it falls back to call a torch optimizer if name in dp_optimizer.__all__: return dp_optimizer.__getattribute__(name) try: return torch.optim.__getattribute__(name) except AttributeError: try: unittest.__getattribute__(name) except AttributeError: if name is not None: raise AttributeError(f"module {name} not implemented in torch.optim")
else: from . import dp_optimizer from . import tests class _Wrapper36(object): """ Wrapper to handle the dynamic calling of torch.optim modules in the heat namespace """ def __init__(self, wrapped): # noqa: D107 self.wrapped = wrapped def __getattr__(self, name): """ When a function is called for the heat.optim module it will attempt to run the heat optimizer with that name, then, if there is no such heat optimizer, it will attempt to get the torch optimizer of that name. """ # this will call the Heat optimizers if available, # otherwise, it falls back to call a torch optimizer if name in dp_optimizer.__all__: return dp_optimizer.__getattribute__(name) elif name == "tests": return tests try: return torch.optim.__getattribute__(name) except AttributeError: try: unittest.__getattribute__(name) except AttributeError: if name is not None: raise AttributeError(f"module '{name}' not implemented in torch or heat") sys.modules[__name__] = _Wrapper36(sys.modules[__name__])