diff --git a/hfdocs/source/reference/optimizers.mdx b/hfdocs/source/reference/optimizers.mdx index 637e7f0a74..0e124192d2 100644 --- a/hfdocs/source/reference/optimizers.mdx +++ b/hfdocs/source/reference/optimizers.mdx @@ -6,22 +6,29 @@ This page contains the API reference documentation for learning rate optimizers ### Factory functions -[[autodoc]] timm.optim.optim_factory.create_optimizer -[[autodoc]] timm.optim.optim_factory.create_optimizer_v2 +[[autodoc]] timm.optim.create_optimizer_v2 +[[autodoc]] timm.optim.list_optimizers +[[autodoc]] timm.optim.get_optimizer_class ### Optimizer Classes [[autodoc]] timm.optim.adabelief.AdaBelief [[autodoc]] timm.optim.adafactor.Adafactor +[[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision [[autodoc]] timm.optim.adahessian.Adahessian [[autodoc]] timm.optim.adamp.AdamP [[autodoc]] timm.optim.adamw.AdamW +[[autodoc]] timm.optim.adan.Adan +[[autodoc]] timm.optim.adopt.Adopt [[autodoc]] timm.optim.lamb.Lamb [[autodoc]] timm.optim.lars.Lars +[[autodoc]] timm.optim.lion.Lion [[autodoc]] timm.optim.lookahead.Lookahead [[autodoc]] timm.optim.madgrad.MADGRAD [[autodoc]] timm.optim.nadam.Nadam +[[autodoc]] timm.optim.nadamw.NAdamW [[autodoc]] timm.optim.nvnovograd.NvNovoGrad [[autodoc]] timm.optim.radam.RAdam [[autodoc]] timm.optim.rmsprop_tf.RMSpropTF [[autodoc]] timm.optim.sgdp.SGDP +[[autodoc]] timm.optim.sgdw.SGDW \ No newline at end of file diff --git a/tests/test_optim.py b/tests/test_optim.py index 66aaadbf95..d70ec98d34 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -12,11 +12,10 @@ from torch.testing._internal.common_utils import TestCase from torch.nn import Parameter -from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay +from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo +from timm.optim import param_groups_layer_decay, param_groups_weight_decay from timm.scheduler import PlateauLRScheduler -from timm.optim import create_optimizer_v2 - import importlib import os @@ -177,7 +176,7 @@ def _test_basic_cases(constructor, scheduler_constructors=None): ) -def _test_model(optimizer, params, device=torch.device('cpu')): +def _test_model(optimizer, params, device=torch.device('cpu'), after_step=0): weight = torch.tensor( [[-0.2109, -0.4976], [-0.1413, -0.3420], [-0.2524, 0.6976]], device=device, requires_grad=True) @@ -208,7 +207,8 @@ def _test_model(optimizer, params, device=torch.device('cpu')): loss = output.sum() loss.backward() loss = loss.item() - assert loss < prev_loss + if i > after_step: + assert loss < prev_loss prev_loss = loss optimizer.step() @@ -237,31 +237,44 @@ def _test_rosenbrock(constructor, scheduler_constructors=None): solution = torch.tensor([1, 1]) initial_dist = params.clone().detach().dist(solution) - def eval(params, w): + + def get_grad(_param, _sparse_grad, _w): + grad = drosenbrock(params.clone().detach()) + # Depending on w, provide only the x or y gradient + if _sparse_grad: + if _w: + i = torch.tensor([[0, 0]], dtype=torch.int64) + x = grad[0] + v = torch.tensor([x / 4.0, x - x / 4.0]) + else: + i = torch.tensor([[1, 1]], dtype=torch.int64) + y = grad[1] + v = torch.tensor([y - y / 4.0, y / 4.0]) + grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype) + else: + if _w: + grad_out = torch.tensor([grad[0], 0], dtype=_param.dtype) + else: + grad_out = torch.tensor([0, grad[1]], dtype=_param.dtype) + return grad_out + + + def eval(_param, _sparse_grad, _w): # Depending on w, provide only the x or y gradient 
optimizer.zero_grad() - loss = rosenbrock(params) + loss = rosenbrock(_param) loss.backward() - grad = drosenbrock(params.clone().detach()) - # NB: We torture test the optimizer by returning an - # uncoalesced sparse tensor - if w: - i = torch.LongTensor([[0, 0]]) - x = grad[0] - v = torch.tensor([x / 4., x - x / 4.]) - else: - i = torch.LongTensor([[1, 1]]) - y = grad[1] - v = torch.tensor([y - y / 4., y / 4.]) - x = torch.sparse.DoubleTensor(i, v, torch.Size([2])).to(dtype=v.dtype) + + grad_out = get_grad(_param, _sparse_grad, _w) with torch.no_grad(): - params.grad = x.to_dense() + _param.grad = grad_out.to_dense() + return loss for i in range(2000): # Do cyclic coordinate descent w = i % 2 - optimizer.step(functools.partial(eval, params, w)) + optimizer.step(functools.partial(eval, params, True, w)) for scheduler in schedulers: if isinstance(scheduler, PlateauLRScheduler): scheduler.step(rosenbrock(params)) @@ -279,29 +292,40 @@ def _build_params_dict_single(weight, bias, **kwargs): return [dict(params=bias, **kwargs)] +@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*'))) +def test_optim_factory(optimizer): + assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer) + + opt_info = get_optimizer_info(optimizer) + assert isinstance(opt_info, OptimInfo) + + if not opt_info.second_order: # basic tests don't support second order right now + # test basic cases that don't need specific tuning via factory test + _test_basic_cases( + lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) + ) + _test_basic_cases( + lambda weight, bias: create_optimizer_v2( + _build_params_dict(weight, bias, lr=1e-2), + optimizer, + lr=1e-3) + ) + _test_basic_cases( + lambda weight, bias: create_optimizer_v2( + _build_params_dict_single(weight, bias, lr=1e-2), + optimizer, + lr=1e-3) + ) + _test_basic_cases( + lambda weight, bias: create_optimizer_v2( + _build_params_dict_single(weight, bias, lr=1e-2), optimizer) + ) + + #@pytest.mark.parametrize('optimizer', ['sgd', 'momentum']) # FIXME momentum variant frequently fails in GitHub runner, but never local after many attempts @pytest.mark.parametrize('optimizer', ['sgd']) def test_sgd(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=1e-2), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=1e-2), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=1e-2), optimizer) - ) # _test_basic_cases( # lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3), # [lambda opt: StepLR(opt, gamma=0.9, step_size=10)] @@ -342,50 +366,25 @@ def test_sgd(optimizer): _test_model(optimizer, dict(lr=1e-3)) -@pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax']) +@pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax', 'nadamw']) def test_adam(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, 
- lr=1e-3) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=5e-2) ) _test_model(optimizer, dict(lr=5e-2)) +@pytest.mark.parametrize('optimizer', ['adopt', 'adoptw']) +def test_adopt(optimizer): + # FIXME rosenbrock is not passing for ADOPT + # _test_rosenbrock( + # lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) + # ) + _test_model(optimizer, dict(lr=5e-2), after_step=1) # note no convergence in first step for ADOPT + + @pytest.mark.parametrize('optimizer', ['adabelief']) def test_adabelief(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_basic_cases( lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1) ) @@ -397,21 +396,6 @@ def test_adabelief(optimizer): @pytest.mark.parametrize('optimizer', ['radam', 'radabelief']) def test_rectified(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -420,25 +404,6 @@ def test_rectified(optimizer): @pytest.mark.parametrize('optimizer', ['adadelta', 'adagrad']) def test_adaother(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_basic_cases( lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1) ) @@ -448,26 +413,8 @@ def test_adaother(optimizer): _test_model(optimizer, dict(lr=5e-2)) -@pytest.mark.parametrize('optimizer', ['adafactor']) +@pytest.mark.parametrize('optimizer', ['adafactor', 'adafactorbv']) def test_adafactor(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2(_build_params_dict_single(weight, bias), optimizer) - ) _test_basic_cases( lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1) ) @@ -479,25 +426,6 @@ def test_adafactor(optimizer): @pytest.mark.parametrize('optimizer', ['lamb', 'lambc']) 
def test_lamb(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=1e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=1e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=1e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -506,25 +434,6 @@ def test_lamb(optimizer): @pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc']) def test_lars(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=1e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=1e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=1e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -533,25 +442,6 @@ def test_lars(optimizer): @pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw']) def test_madgrad(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-2) ) @@ -560,25 +450,6 @@ def test_madgrad(optimizer): @pytest.mark.parametrize('optimizer', ['novograd']) def test_novograd(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -587,25 +458,6 @@ def test_novograd(optimizer): @pytest.mark.parametrize('optimizer', ['rmsprop', 'rmsproptf']) def test_rmsprop(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: 
create_optimizer_v2(params, optimizer, lr=1e-2) ) @@ -614,25 +466,6 @@ def test_rmsprop(optimizer): @pytest.mark.parametrize('optimizer', ['adamp']) def test_adamp(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=5e-2) ) @@ -641,25 +474,6 @@ def test_adamp(optimizer): @pytest.mark.parametrize('optimizer', ['sgdp']) def test_sgdp(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -668,25 +482,6 @@ def test_sgdp(optimizer): @pytest.mark.parametrize('optimizer', ['lookahead_sgd', 'lookahead_momentum']) def test_lookahead_sgd(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -694,25 +489,6 @@ def test_lookahead_sgd(optimizer): @pytest.mark.parametrize('optimizer', ['lookahead_adamw', 'lookahead_adam']) def test_lookahead_adam(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=5e-2) ) @@ -720,25 +496,6 @@ def test_lookahead_adam(optimizer): @pytest.mark.parametrize('optimizer', ['lookahead_radam']) def test_lookahead_radam(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - 
optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-4) ) diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index d2eb67bab5..552585c91b 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -1,11 +1,14 @@ from .adabelief import AdaBelief from .adafactor import Adafactor +from .adafactor_bv import AdafactorBigVision from .adahessian import Adahessian from .adamp import AdamP from .adamw import AdamW from .adan import Adan +from .adopt import Adopt from .lamb import Lamb from .lars import Lars +from .lion import Lion from .lookahead import Lookahead from .madgrad import MADGRAD from .nadam import Nadam @@ -13,5 +16,7 @@ from .radam import RAdam from .rmsprop_tf import RMSpropTF from .sgdp import SGDP -from .lion import Lion -from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs + +from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \ + create_optimizer_v2, create_optimizer, optimizer_kwargs +from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py new file mode 100644 index 0000000000..9e06e1a50a --- /dev/null +++ b/timm/optim/_optim_factory.py @@ -0,0 +1,960 @@ +""" Optimizer Factory w/ custom Weight Decay & Layer Decay support + +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator +from fnmatch import fnmatch +import importlib + +import torch +import torch.nn as nn +import torch.optim as optim + +from ._param_groups import param_groups_layer_decay, param_groups_weight_decay +from .adabelief import AdaBelief +from .adafactor import Adafactor +from .adafactor_bv import AdafactorBigVision +from .adahessian import Adahessian +from .adamp import AdamP +from .adan import Adan +from .adopt import Adopt +from .lamb import Lamb +from .lars import Lars +from .lion import Lion +from .lookahead import Lookahead +from .madgrad import MADGRAD +from .nadam import Nadam +from .nadamw import NAdamW +from .nvnovograd import NvNovoGrad +from .radam import RAdam +from .rmsprop_tf import RMSpropTF +from .sgdp import SGDP +from .sgdw import SGDW + +_logger = logging.getLogger(__name__) + +# Type variables +T = TypeVar('T') +Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]] +OptimType = TypeVar('OptimType', bound='optim.Optimizer') + + +def _import_class(class_string: str) -> Type: + """Dynamically import a class from a string.""" + try: + module_name, class_name = class_string.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ImportError(f"Could not import {class_string}: {e}") + + +class OptimizerCallable(Protocol): + """Protocol for optimizer constructor signatures.""" + + def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ... + + +@dataclass(frozen=True) +class OptimInfo: + """Immutable configuration for an optimizer. 
+ + Attributes: + name: Unique identifier for the optimizer + opt_class: The optimizer class + description: Brief description of the optimizer's characteristics and behavior + has_eps: Whether the optimizer accepts epsilon parameter + has_momentum: Whether the optimizer accepts momentum parameter + has_betas: Whether the optimizer accepts a tuple of beta parameters + num_betas: number of betas in tuple (valid IFF has_betas = True) + defaults: Optional default parameters for the optimizer + """ + name: str + opt_class: Union[str, Type[optim.Optimizer]] + description: str = '' + has_eps: bool = True + has_momentum: bool = False + has_betas: bool = False + num_betas: int = 2 + second_order: bool = False + defaults: Optional[Dict[str, Any]] = None + + +class OptimizerRegistry: + """Registry managing optimizer configurations and instantiation. + + This class provides a central registry for optimizer configurations and handles + their instantiation with appropriate parameter groups and settings. + """ + + def __init__(self) -> None: + self._optimizers: Dict[str, OptimInfo] = {} + self._foreach_defaults: Set[str] = {'lion'} + + def register(self, info: OptimInfo) -> None: + """Register an optimizer configuration. + + Args: + info: The OptimInfo configuration containing name, type and description + """ + name = info.name.lower() + if name in self._optimizers: + _logger.warning(f'Optimizer {name} already registered, overwriting') + self._optimizers[name] = info + + def register_alias(self, alias: str, target: str) -> None: + """Register an alias for an existing optimizer. + + Args: + alias: The alias name + target: The target optimizer name + + Raises: + KeyError: If target optimizer doesn't exist + """ + target = target.lower() + if target not in self._optimizers: + raise KeyError(f'Cannot create alias for non-existent optimizer {target}') + self._optimizers[alias.lower()] = self._optimizers[target] + + def register_foreach_default(self, name: str) -> None: + """Register an optimizer as defaulting to foreach=True.""" + self._foreach_defaults.add(name.lower()) + + def list_optimizers( + self, + filter: Union[str, List[str]] = '', + exclude_filters: Optional[List[str]] = None, + with_description: bool = False + ) -> List[Union[str, Tuple[str, str]]]: + """List available optimizer names, optionally filtered. + + Args: + filter: Wildcard style filter string (e.g., 'adam*') + exclude_filters: Optional list of wildcard patterns to exclude + with_description: If True, return tuples of (name, description) + + Returns: + List of either optimizer names or (name, description) tuples + """ + names = sorted(self._optimizers.keys()) + + if filter: + if isinstance(filter, str): + filters = [filter] + else: + filters = filter + filtered_names = set() + for f in filters: + filtered_names.update(n for n in names if fnmatch(n, f)) + names = sorted(filtered_names) + + if exclude_filters: + for exclude_filter in exclude_filters: + names = [n for n in names if not fnmatch(n, exclude_filter)] + + if with_description: + return [(name, self._optimizers[name].description) for name in names] + + return names + + def get_optimizer_info(self, name: str) -> OptimInfo: + """Get the OptimInfo for an optimizer. 
+ + Args: + name: Name of the optimizer + + Returns: + OptimInfo configuration + + Raises: + ValueError: If optimizer is not found + """ + name = name.lower() + if name not in self._optimizers: + raise ValueError(f'Optimizer {name} not found in registry') + return self._optimizers[name] + + def get_optimizer_class( + self, + name_or_info: Union[str, OptimInfo], + bind_defaults: bool = True, + ) -> Union[Type[optim.Optimizer], OptimizerCallable]: + """Get the optimizer class with any default arguments applied. + + This allows direct instantiation of optimizers with their default configs + without going through the full factory. + + Args: + name_or_info: Name of the optimizer + bind_defaults: Bind default arguments to optimizer class via `partial` before returning + + Returns: + Optimizer class or partial with defaults applied + + Raises: + ValueError: If optimizer not found + """ + if isinstance(name_or_info, str): + opt_info = self.get_optimizer_info(name_or_info) + else: + assert isinstance(name_or_info, OptimInfo) + opt_info = name_or_info + + if isinstance(opt_info.opt_class, str): + # Special handling for APEX and BNB optimizers + if opt_info.opt_class.startswith('apex.'): + assert torch.cuda.is_available(), 'CUDA required for APEX optimizers' + try: + opt_class = _import_class(opt_info.opt_class) + except ImportError as e: + raise ImportError('APEX optimizers require apex to be installed') from e + elif opt_info.opt_class.startswith('bitsandbytes.'): + assert torch.cuda.is_available(), 'CUDA required for bitsandbytes optimizers' + try: + opt_class = _import_class(opt_info.opt_class) + except ImportError as e: + raise ImportError('bitsandbytes optimizers require bitsandbytes to be installed') from e + else: + opt_class = _import_class(opt_info.opt_class) + else: + opt_class = opt_info.opt_class + + # Return class or partial with defaults + if bind_defaults and opt_info.defaults: + opt_class = partial(opt_class, **opt_info.defaults) + + return opt_class + + def create_optimizer( + self, + model_or_params: Union[nn.Module, Params], + opt: str, + lr: Optional[float] = None, + weight_decay: float = 0., + momentum: float = 0.9, + foreach: Optional[bool] = None, + weight_decay_exclude_1d: bool = True, + layer_decay: Optional[float] = None, + param_group_fn: Optional[Callable[[nn.Module], Params]] = None, + **kwargs: Any, + ) -> optim.Optimizer: + """Create an optimizer instance. 
+ + Args: + model_or_params: Model or parameters to optimize + opt: Name of optimizer to create + lr: Learning rate + weight_decay: Weight decay factor + momentum: Momentum factor for applicable optimizers + foreach: Enable/disable foreach operation + weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine) + layer_decay: Layer-wise learning rate decay + param_group_fn: Optional custom parameter grouping function + **kwargs: Additional optimizer-specific arguments + + Returns: + Configured optimizer instance + + Raises: + ValueError: If optimizer not found or configuration invalid + """ + + # Get parameters to optimize + if isinstance(model_or_params, nn.Module): + # Extract parameters from a nn.Module, build param groups w/ weight-decay and/or layer-decay applied + no_weight_decay = getattr(model_or_params, 'no_weight_decay', lambda: set())() + + if param_group_fn: + # run custom fn to generate param groups from nn.Module + params = param_group_fn(model_or_params) + elif layer_decay is not None: + params = param_groups_layer_decay( + model_or_params, + weight_decay=weight_decay, + layer_decay=layer_decay, + no_weight_decay_list=no_weight_decay, + weight_decay_exclude_1d=weight_decay_exclude_1d, + ) + weight_decay = 0. + elif weight_decay and weight_decay_exclude_1d: + params = param_groups_weight_decay( + model_or_params, + weight_decay=weight_decay, + no_weight_decay_list=no_weight_decay, + ) + weight_decay = 0. + else: + params = model_or_params.parameters() + else: + # pass parameters / parameter groups through to optimizer + params = model_or_params + + # Parse optimizer name + opt_split = opt.lower().split('_') + opt_name = opt_split[-1] + use_lookahead = opt_split[0] == 'lookahead' if len(opt_split) > 1 else False + + opt_info = self.get_optimizer_info(opt_name) + + # Build optimizer arguments + opt_args: Dict[str, Any] = {'weight_decay': weight_decay, **kwargs} + + # Add LR to args, if None optimizer default is used, some optimizers manage LR internally if None. + if lr is not None: + opt_args['lr'] = lr + + # Apply optimizer-specific settings + if opt_info.defaults: + for k, v in opt_info.defaults.items(): + opt_args.setdefault(k, v) + + # timm has always defaulted momentum to 0.9 if optimizer supports momentum, keep for backward compat. + if opt_info.has_momentum: + opt_args.setdefault('momentum', momentum) + + # Remove commonly used kwargs that aren't always supported + if not opt_info.has_eps: + opt_args.pop('eps', None) + if not opt_info.has_betas: + opt_args.pop('betas', None) + + if foreach is not None: + # Explicitly activate or deactivate multi-tensor foreach impl. + # Not all optimizers support this, and those that do usually default to using + # multi-tensor impl if foreach is left as default 'None' and can be enabled. 
+ opt_args.setdefault('foreach', foreach) + + # Create optimizer + opt_class = self.get_optimizer_class(opt_info, bind_defaults=False) + optimizer = opt_class(params, **opt_args) + + # Apply Lookahead if requested + if use_lookahead: + optimizer = Lookahead(optimizer) + + return optimizer + + +def _register_sgd_variants(registry: OptimizerRegistry) -> None: + """Register SGD-based optimizers""" + sgd_optimizers = [ + OptimInfo( + name='sgd', + opt_class=optim.SGD, + description='Stochastic Gradient Descent with Nesterov momentum (default)', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='momentum', + opt_class=optim.SGD, + description='Stochastic Gradient Descent with classical momentum', + has_eps=False, + has_momentum=True, + defaults={'nesterov': False} + ), + OptimInfo( + name='sgdp', + opt_class=SGDP, + description='SGD with built-in projection to unit norm sphere', + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='sgdw', + opt_class=SGDW, + description='SGD with decoupled weight decay and Nesterov momentum', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True} + ), + ] + for opt in sgd_optimizers: + registry.register(opt) + + +def _register_adam_variants(registry: OptimizerRegistry) -> None: + """Register Adam-based optimizers""" + adam_optimizers = [ + OptimInfo( + name='adam', + opt_class=optim.Adam, + description='torch.optim Adam (Adaptive Moment Estimation)', + has_betas=True + ), + OptimInfo( + name='adamw', + opt_class=optim.AdamW, + description='torch.optim Adam with decoupled weight decay regularization', + has_betas=True + ), + OptimInfo( + name='adamp', + opt_class=AdamP, + description='Adam with built-in projection to unit norm sphere', + has_betas=True, + defaults={'wd_ratio': 0.01, 'nesterov': True} + ), + OptimInfo( + name='nadam', + opt_class=Nadam, + description='Adam with Nesterov momentum', + has_betas=True + ), + OptimInfo( + name='nadamw', + opt_class=NAdamW, + description='Adam with Nesterov momentum and decoupled weight decay', + has_betas=True + ), + OptimInfo( + name='radam', + opt_class=RAdam, + description='Rectified Adam with variance adaptation', + has_betas=True + ), + OptimInfo( + name='adamax', + opt_class=optim.Adamax, + description='torch.optim Adamax, Adam with infinity norm for more stable updates', + has_betas=True + ), + OptimInfo( + name='adafactor', + opt_class=Adafactor, + description='Memory-efficient implementation of Adam with factored gradients', + ), + OptimInfo( + name='adafactorbv', + opt_class=AdafactorBigVision, + description='Big Vision variant of Adafactor with factored gradients, half precision momentum', + ), + OptimInfo( + name='adopt', + opt_class=Adopt, + description='Modified Adam that can converge with any β2 with the optimal rate', + ), + OptimInfo( + name='adoptw', + opt_class=Adopt, + description='Modified AdamW (decoupled decay) that can converge with any β2 with the optimal rate', + defaults={'decoupled': True} + ), + ] + for opt in adam_optimizers: + registry.register(opt) + + +def _register_lamb_lars(registry: OptimizerRegistry) -> None: + """Register LAMB and LARS variants""" + lamb_lars_optimizers = [ + OptimInfo( + name='lamb', + opt_class=Lamb, + description='Layer-wise Adaptive Moments for batch optimization', + has_betas=True + ), + OptimInfo( + name='lambc', + opt_class=Lamb, + description='LAMB with trust ratio clipping for stability', + has_betas=True, + defaults={'trust_clip': True} + ), + OptimInfo( + name='lars', + 
opt_class=Lars, + description='Layer-wise Adaptive Rate Scaling', + has_momentum=True + ), + OptimInfo( + name='larc', + opt_class=Lars, + description='LARS with trust ratio clipping for stability', + has_momentum=True, + defaults={'trust_clip': True} + ), + OptimInfo( + name='nlars', + opt_class=Lars, + description='LARS with Nesterov momentum', + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='nlarc', + opt_class=Lars, + description='LARS with Nesterov momentum & trust ratio clipping', + has_momentum=True, + defaults={'nesterov': True, 'trust_clip': True} + ), + ] + for opt in lamb_lars_optimizers: + registry.register(opt) + + +def _register_other_optimizers(registry: OptimizerRegistry) -> None: + """Register miscellaneous optimizers""" + other_optimizers = [ + OptimInfo( + name='adabelief', + opt_class=AdaBelief, + description='Adapts learning rate based on gradient prediction error', + has_betas=True, + defaults={'rectify': False} + ), + OptimInfo( + name='radabelief', + opt_class=AdaBelief, + description='Rectified AdaBelief with variance adaptation', + has_betas=True, + defaults={'rectify': True} + ), + OptimInfo( + name='adadelta', + opt_class=optim.Adadelta, + description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients' + ), + OptimInfo( + name='adagrad', + opt_class=optim.Adagrad, + description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients', + defaults={'eps': 1e-8} + ), + OptimInfo( + name='adan', + opt_class=Adan, + description='Adaptive Nesterov Momentum Algorithm', + defaults={'no_prox': False}, + has_betas=True, + num_betas=3 + ), + OptimInfo( + name='adanw', + opt_class=Adan, + description='Adaptive Nesterov Momentum with decoupled weight decay', + defaults={'no_prox': True}, + has_betas=True, + num_betas=3 + ), + OptimInfo( + name='adahessian', + opt_class=Adahessian, + description='An Adaptive Second Order Optimizer', + has_betas=True, + second_order=True, + ), + OptimInfo( + name='lion', + opt_class=Lion, + description='Evolved Sign Momentum optimizer for improved convergence', + has_eps=False, + has_betas=True + ), + OptimInfo( + name='madgrad', + opt_class=MADGRAD, + description='Momentum-based Adaptive gradient method', + has_momentum=True + ), + OptimInfo( + name='madgradw', + opt_class=MADGRAD, + description='MADGRAD with decoupled weight decay', + has_momentum=True, + defaults={'decoupled_decay': True} + ), + OptimInfo( + name='novograd', + opt_class=NvNovoGrad, + description='Normalized Adam with L2 norm gradient normalization', + has_betas=True + ), + OptimInfo( + name='rmsprop', + opt_class=optim.RMSprop, + description='torch.optim RMSprop, Root Mean Square Propagation', + has_momentum=True, + defaults={'alpha': 0.9} + ), + OptimInfo( + name='rmsproptf', + opt_class=RMSpropTF, + description='TensorFlow-style RMSprop implementation, Root Mean Square Propagation', + has_momentum=True, + defaults={'alpha': 0.9} + ), + ] + for opt in other_optimizers: + registry.register(opt) + registry.register_foreach_default('lion') + + +def _register_apex_optimizers(registry: OptimizerRegistry) -> None: + """Register APEX optimizers (lazy import)""" + apex_optimizers = [ + OptimInfo( + name='fusedsgd', + opt_class='apex.optimizers.FusedSGD', + description='NVIDIA APEX fused SGD implementation for faster training', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='fusedadam', + opt_class='apex.optimizers.FusedAdam', + description='NVIDIA APEX 
fused Adam implementation', + has_betas=True, + defaults={'adam_w_mode': False} + ), + OptimInfo( + name='fusedadamw', + opt_class='apex.optimizers.FusedAdam', + description='NVIDIA APEX fused AdamW implementation', + has_betas=True, + defaults={'adam_w_mode': True} + ), + OptimInfo( + name='fusedlamb', + opt_class='apex.optimizers.FusedLAMB', + description='NVIDIA APEX fused LAMB implementation', + has_betas=True + ), + OptimInfo( + name='fusednovograd', + opt_class='apex.optimizers.FusedNovoGrad', + description='NVIDIA APEX fused NovoGrad implementation', + has_betas=True, + defaults={'betas': (0.95, 0.98)} + ), + ] + for opt in apex_optimizers: + registry.register(opt) + + +def _register_bnb_optimizers(registry: OptimizerRegistry) -> None: + """Register bitsandbytes optimizers (lazy import)""" + bnb_optimizers = [ + OptimInfo( + name='bnbsgd', + opt_class='bitsandbytes.optim.SGD', + description='bitsandbytes SGD', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='bnbsgd8bit', + opt_class='bitsandbytes.optim.SGD8bit', + description='bitsandbytes 8-bit SGD with dynamic quantization', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='bnbadam', + opt_class='bitsandbytes.optim.Adam', + description='bitsandbytes Adam', + has_betas=True + ), + OptimInfo( + name='bnbadam8bit', + opt_class='bitsandbytes.optim.Adam', + description='bitsandbytes 8-bit Adam with dynamic quantization', + has_betas=True + ), + OptimInfo( + name='bnbadamw', + opt_class='bitsandbytes.optim.AdamW', + description='bitsandbytes AdamW', + has_betas=True + ), + OptimInfo( + name='bnbadamw8bit', + opt_class='bitsandbytes.optim.AdamW', + description='bitsandbytes 8-bit AdamW with dynamic quantization', + has_betas=True + ), + OptimInfo( + 'bnblion', + 'bitsandbytes.optim.Lion', + description='bitsandbytes Lion', + has_eps=False, + has_betas=True + ), + OptimInfo( + 'bnblion8bit', + 'bitsandbytes.optim.Lion8bit', + description='bitsandbytes 8-bit Lion with dynamic quantization', + has_eps=False, + has_betas=True + ), + OptimInfo( + 'bnbademamix', + 'bitsandbytes.optim.AdEMAMix', + description='bitsandbytes AdEMAMix', + has_betas=True, + num_betas=3, + ), + OptimInfo( + 'bnbademamix8bit', + 'bitsandbytes.optim.AdEMAMix8bit', + description='bitsandbytes 8-bit AdEMAMix with dynamic quantization', + has_betas=True, + num_betas=3, + ), + ] + for opt in bnb_optimizers: + registry.register(opt) + + +default_registry = OptimizerRegistry() + +def _register_default_optimizers() -> None: + """Register all default optimizers to the global registry.""" + # Register all optimizer groups + _register_sgd_variants(default_registry) + _register_adam_variants(default_registry) + _register_lamb_lars(default_registry) + _register_other_optimizers(default_registry) + _register_apex_optimizers(default_registry) + _register_bnb_optimizers(default_registry) + + # Register aliases + default_registry.register_alias('nesterov', 'sgd') + default_registry.register_alias('nesterovw', 'sgdw') + + +# Initialize default registry +_register_default_optimizers() + +# Public API + +def list_optimizers( + filter: Union[str, List[str]] = '', + exclude_filters: Optional[List[str]] = None, + with_description: bool = False, +) -> List[Union[str, Tuple[str, str]]]: + """List available optimizer names, optionally filtered. + + List all registered optimizers, with optional filtering using wildcard patterns. 
+ Optimizers can be filtered using include and exclude patterns, and can optionally + return descriptions with each optimizer name. + + Args: + filter: Wildcard style filter string or list of filter strings + (e.g., 'adam*' for all Adam variants, or ['adam*', '*8bit'] for + Adam variants and 8-bit optimizers). Empty string means no filtering. + exclude_filters: Optional list of wildcard patterns to exclude. For example, + ['*8bit', 'fused*'] would exclude 8-bit and fused implementations. + with_description: If True, returns tuples of (name, description) instead of + just names. Descriptions provide brief explanations of optimizer characteristics. + + Returns: + If with_description is False: + List of optimizer names as strings (e.g., ['adam', 'adamw', ...]) + If with_description is True: + List of tuples of (name, description) (e.g., [('adam', 'Adaptive Moment...'), ...]) + + Examples: + >>> list_optimizers() + ['adam', 'adamw', 'sgd', ...] + + >>> list_optimizers(['la*', 'nla*']) # List lamb & lars + ['lamb', 'lambc', 'larc', 'lars', 'nlarc', 'nlars'] + + >>> list_optimizers('*adam*', exclude_filters=['bnb*', 'fused*']) # Exclude bnb & apex adam optimizers + ['adam', 'adamax', 'adamp', 'adamw', 'nadam', 'nadamw', 'radam'] + + >>> list_optimizers(with_description=True) # Get descriptions + [('adabelief', 'Adapts learning rate based on gradient prediction error'), + ('adadelta', 'torch.optim Adadelta, Adapts learning rates based on running windows of gradients'), + ('adafactor', 'Memory-efficient implementation of Adam with factored gradients'), + ...] + """ + return default_registry.list_optimizers(filter, exclude_filters, with_description) + + +def get_optimizer_info(name: str) -> OptimInfo: + """Get the OptimInfo for an optimizer. + + Args: + name: Name of the optimizer + + Returns: + OptimInfo configuration + + Raises: + ValueError: If optimizer is not found + """ + return default_registry.get_optimizer_info(name) + + +def get_optimizer_class( + name: str, + bind_defaults: bool = False, +) -> Union[Type[optim.Optimizer], OptimizerCallable]: + """Get optimizer class by name with option to bind default arguments. + + Retrieves the optimizer class or a partial function with default arguments bound. + This allows direct instantiation of optimizers with their default configurations + without going through the full factory. + + Args: + name: Name of the optimizer to retrieve (e.g., 'adam', 'sgd') + bind_defaults: If True, returns a partial function with default arguments from OptimInfo bound. + If False, returns the raw optimizer class. 
+ + Returns: + If bind_defaults is False: + The optimizer class (e.g., torch.optim.Adam) + If bind_defaults is True: + A partial function with default arguments bound + + Raises: + ValueError: If optimizer name is not found in registry + + Examples: + >>> # Get raw optimizer class + >>> Adam = get_optimizer_class('adam') + >>> opt = Adam(model.parameters(), lr=1e-3) + + >>> # Get optimizer with defaults bound + >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True) + >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3) + + >>> # Get SGD with nesterov momentum default + >>> SGD = get_optimizer_class('sgd', bind_defaults=True) # nesterov=True bound + >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9) + """ + return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults) + + +def create_optimizer_v2( + model_or_params: Union[nn.Module, Params], + opt: str = 'sgd', + lr: Optional[float] = None, + weight_decay: float = 0., + momentum: float = 0.9, + foreach: Optional[bool] = None, + filter_bias_and_bn: bool = True, + layer_decay: Optional[float] = None, + param_group_fn: Optional[Callable[[nn.Module], Params]] = None, + **kwargs: Any, +) -> optim.Optimizer: + """Create an optimizer instance via timm registry. + + Creates and configures an optimizer with appropriate parameter groups and settings. + Supports automatic parameter group creation for weight decay and layer-wise learning + rates, as well as custom parameter grouping. + + Args: + model_or_params: A PyTorch model or an iterable of parameters/parameter groups. + If a model is provided, parameters will be automatically extracted and grouped + based on the other arguments. + opt: Name of the optimizer to create (e.g., 'adam', 'adamw', 'sgd'). + Use list_optimizers() to see available options. + lr: Learning rate. If None, will use the optimizer's default. + weight_decay: Weight decay factor. Will be used to create param groups if model_or_params is a model. + momentum: Momentum factor for optimizers that support it. Only used if the + chosen optimizer accepts a momentum parameter. + foreach: Enable/disable foreach (multi-tensor) implementation if available. + If None, will use optimizer-specific defaults. + filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have + weight decay applied. Only used when model_or_params is a model and + weight_decay > 0. + layer_decay: Optional layer-wise learning rate decay factor. If provided, + learning rates will be scaled by layer_decay^(max_depth - layer_depth). + Only used when model_or_params is a model. + param_group_fn: Optional function to create custom parameter groups. + If provided, other parameter grouping options will be ignored. + **kwargs: Additional optimizer-specific arguments (e.g., betas for Adam). + + Returns: + Configured optimizer instance. + + Examples: + >>> # Basic usage with a model + >>> optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3) + + >>> # SGD with momentum and weight decay + >>> optimizer = create_optimizer_v2( + ... model, 'sgd', lr=0.1, momentum=0.9, weight_decay=1e-4 + ... ) + + >>> # Adam with layer-wise learning rate decay + >>> optimizer = create_optimizer_v2( + ... model, 'adam', lr=1e-3, layer_decay=0.7 + ... ) + + >>> # Custom parameter groups + >>> def group_fn(model): + ... return [ + ... {'params': model.backbone.parameters(), 'lr': 1e-4}, + ... {'params': model.head.parameters(), 'lr': 1e-3} + ... ] + >>> optimizer = create_optimizer_v2( + ... 
model, 'sgd', param_group_fn=group_fn + ... ) + + Note: + Parameter group handling precedence: + 1. If param_group_fn is provided, it will be used exclusively + 2. If layer_decay is provided, layer-wise groups will be created + 3. If weight_decay > 0 and filter_bias_and_bn is True, weight decay groups will be created + 4. Otherwise, all parameters will be in a single group + """ + + return default_registry.create_optimizer( + model_or_params, + opt=opt, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + foreach=foreach, + weight_decay_exclude_1d=filter_bias_and_bn, + layer_decay=layer_decay, + param_group_fn=param_group_fn, + **kwargs + ) + + +def optimizer_kwargs(cfg): + """ cfg/argparse to kwargs helper + Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn. + """ + kwargs = dict( + opt=cfg.opt, + lr=cfg.lr, + weight_decay=cfg.weight_decay, + momentum=cfg.momentum, + ) + if getattr(cfg, 'opt_eps', None) is not None: + kwargs['eps'] = cfg.opt_eps + if getattr(cfg, 'opt_betas', None) is not None: + kwargs['betas'] = cfg.opt_betas + if getattr(cfg, 'layer_decay', None) is not None: + kwargs['layer_decay'] = cfg.layer_decay + if getattr(cfg, 'opt_args', None) is not None: + kwargs.update(cfg.opt_args) + if getattr(cfg, 'opt_foreach', None) is not None: + kwargs['foreach'] = cfg.opt_foreach + return kwargs + + +def create_optimizer(args, model, filter_bias_and_bn=True): + """ Legacy optimizer factory for backwards compatibility. + NOTE: Use create_optimizer_v2 for new code. + """ + return create_optimizer_v2( + model, + **optimizer_kwargs(cfg=args), + filter_bias_and_bn=filter_bias_and_bn, + ) + diff --git a/timm/optim/_param_groups.py b/timm/optim/_param_groups.py new file mode 100644 index 0000000000..a756c5e0c0 --- /dev/null +++ b/timm/optim/_param_groups.py @@ -0,0 +1,131 @@ +import logging +from itertools import islice +from typing import Collection, Optional + +from torch import nn as nn + +from timm.models import group_parameters + + +_logger = logging.getLogger(__name__) + + +def param_groups_weight_decay( + model: nn.Module, + weight_decay: float = 1e-5, + no_weight_decay_list: Collection[str] = (), +): + no_weight_decay_list = set(no_weight_decay_list) + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: + no_decay.append(param) + else: + decay.append(param) + + return [ + {'params': no_decay, 'weight_decay': 0.}, + {'params': decay, 'weight_decay': weight_decay}] + + +def _group(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def auto_group_layers(model, layers_per_group=12, num_groups=None): + def _in_head(n, hp): + if not hp: + return True + elif isinstance(hp, (tuple, list)): + return any([n.startswith(hpi) for hpi in hp]) + else: + return n.startswith(hp) + + head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None) + names_trunk = [] + names_head = [] + for n, _ in model.named_parameters(): + names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n) + + # group non-head layers + num_trunk_layers = len(names_trunk) + if num_groups is not None: + layers_per_group = -(num_trunk_layers // -num_groups) + names_trunk = list(_group(names_trunk, layers_per_group)) + + num_trunk_groups = len(names_trunk) + layer_map = {n: i for i, l in enumerate(names_trunk) for n in l} + layer_map.update({n: num_trunk_groups for n in 
names_head}) + return layer_map + +_layer_map = auto_group_layers # backward compat + + +def param_groups_layer_decay( + model: nn.Module, + weight_decay: float = 0.05, + no_weight_decay_list: Collection[str] = (), + weight_decay_exclude_1d: bool = True, + layer_decay: float = .75, + end_layer_decay: Optional[float] = None, + verbose: bool = False, +): + """ + Parameter groups for layer-wise lr decay & weight decay + Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + no_weight_decay_list = set(no_weight_decay_list) + param_group_names = {} # NOTE for debugging + param_groups = {} + + if hasattr(model, 'group_matcher'): + # FIXME interface needs more work + layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True) + else: + # fallback + layer_map = auto_group_layers(model) + num_layers = max(layer_map.values()) + 1 + layer_max = num_layers - 1 + layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers)) + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if (weight_decay_exclude_1d and param.ndim <= 1) or name in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0. + else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = layer_map.get(name, layer_max) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_groups: + this_scale = layer_scales[layer_id] + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "param_names": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["param_names"].append(name) + param_groups[group_name]["params"].append(param) + + if verbose: + import json + _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) diff --git a/timm/optim/adabelief.py b/timm/optim/adabelief.py index 951d715cc0..dd2abb6ba9 100644 --- a/timm/optim/adabelief.py +++ b/timm/optim/adabelief.py @@ -40,9 +40,18 @@ class AdaBelief(Optimizer): """ def __init__( - self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False, - decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True): - + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-16, + weight_decay=0, + amsgrad=False, + decoupled_decay=True, + fixed_decay=False, + rectify=True, + degenerated_to_sgd=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -58,9 +67,17 @@ def __init__( param['buffer'] = [[None, None, None] for _ in range(10)] defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, - degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify, - fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)]) + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + degenerated_to_sgd=degenerated_to_sgd, + decoupled_decay=decoupled_decay, + rectify=rectify, + fixed_decay=fixed_decay, + buffer=[[None, None, None] for _ in range(10)] + ) super(AdaBelief, self).__init__(params, defaults) def __setstate__(self, state): diff --git a/timm/optim/adafactor.py b/timm/optim/adafactor.py index 06057433a9..01c25ff2fb 100644 --- a/timm/optim/adafactor.py +++ b/timm/optim/adafactor.py @@ -2,8 +2,9 @@ Lifted 
from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py -Original header/copyright below. +Modified by Ross Wightman to fix some issues with factorization dims for non nn.Linear layers +Original header/copyright below. """ # Copyright (c) Facebook, Inc. and its affiliates. # @@ -15,6 +16,7 @@ class Adafactor(torch.optim.Optimizer): """Implements Adafactor algorithm. + This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` (see https://arxiv.org/abs/1804.04235) @@ -38,16 +40,38 @@ class Adafactor(torch.optim.Optimizer): whether warm-up initialization is being used (default: False) """ - def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0, - decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False): + def __init__( + self, + params, + lr=None, + eps=1e-30, + eps_scale=1e-3, + clip_threshold=1.0, + decay_rate=-0.8, + betas=None, + weight_decay=0.0, + scale_parameter=True, + warmup_init=False, + min_dim_size_to_factor=32, + ): relative_step = not lr if warmup_init and not relative_step: raise ValueError('warmup_init requires relative_step=True') beta1 = None if betas is None else betas[0] # make it compat with standard betas arg - defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate, - beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, - relative_step=relative_step, warmup_init=warmup_init) + defaults = dict( + lr=lr, + eps=eps, + eps_scale=eps_scale, + clip_threshold=clip_threshold, + decay_rate=decay_rate, + beta1=beta1, + weight_decay=weight_decay, + scale_parameter=scale_parameter, + relative_step=relative_step, + warmup_init=warmup_init, + min_dim_size_to_factor=min_dim_size_to_factor, + ) super(Adafactor, self).__init__(params, defaults) @staticmethod @@ -62,20 +86,34 @@ def _get_lr(param_group, param_state): return param_group['lr'] @staticmethod - def _get_options(param_group, param_shape): - factored = len(param_shape) >= 2 + def _get_options(param_group, param_shape, min_size_to_factor=32): use_first_moment = param_group['beta1'] is not None + factored = None + ndim = len(param_shape) + # Use a simple heuristic to pick factorization row & col, note other PyTorch impl tend to + # always use -2, -1 BUT this will not pick correct dims for convolutions. This is a simple + # approach that should work in most cases, compare to the slightly more involved approach + # in AdafactorBigVision that sorts dims by size, please report if wrong dims chosen. 
+ if ndim > 2 and param_shape[0] > min_size_to_factor and param_shape[1] > min_size_to_factor: + # nD convs in torch are ND + 2 dim weights with leading in/out chs + factored = 0, 1 + elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor: + # if the criteria above didn't match, test trailing dims for eligibility + factored = ndim - 2, ndim - 1 + return factored, use_first_moment @staticmethod def _rms(tensor): return tensor.norm(2) / (tensor.numel() ** 0.5) - def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): - r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) - c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row): + # from our dim heuristic, always dim_col < dim_row, so col reduction dim for factored row = dim_col + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=dim_col, keepdim=True)).rsqrt_().unsqueeze(dim_row) + c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt() return torch.mul(r_factor, c_factor) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -99,7 +137,11 @@ def step(self, closure=None): state = self.state[p] - factored, use_first_moment = self._get_options(group, grad.shape) + factored_dims, use_first_moment = self._get_options( + group, + grad.shape, + min_size_to_factor=group['min_dim_size_to_factor'], + ) # State Initialization if len(state) == 0: state['step'] = 0 @@ -107,9 +149,12 @@ def step(self, closure=None): if use_first_moment: # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(grad) - if factored: - state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad) - state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad) + if factored_dims is not None: + dim_col, dim_row = factored_dims + def _remove_dim(shape, dim): + return shape[:dim] + shape[dim + 1:] + state['exp_avg_sq_row'] = torch.zeros(_remove_dim(grad.shape, dim_row)).to(grad) + state['exp_avg_sq_col'] = torch.zeros(_remove_dim(grad.shape, dim_col)).to(grad) else: state['exp_avg_sq'] = torch.zeros_like(grad) @@ -117,7 +162,7 @@ def step(self, closure=None): else: if use_first_moment: state['exp_avg'] = state['exp_avg'].to(grad) - if factored: + if factored_dims is not None: state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) else: @@ -133,15 +178,16 @@ def step(self, closure=None): beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) update = grad ** 2 + group['eps'] - if factored: + if factored_dims is not None: + dim_col, dim_row = factored_dims exp_avg_sq_row = state['exp_avg_sq_row'] exp_avg_sq_col = state['exp_avg_sq_col'] - exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) - exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=dim_row), alpha=1.0 - beta2t) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=dim_col), alpha=1.0 - beta2t) # Approximation of exponential moving average of square of gradient - update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row) update.mul_(grad) else: exp_avg_sq = state['exp_avg_sq'] diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py new file mode 100644 index 0000000000..3bb6e9592b --- /dev/null +++ b/timm/optim/adafactor_bv.py 
@@ -0,0 +1,307 @@ +""" Adafactor (Big Vision variant) for PyTorch + +Adapted from the implementation in big vision: https://github.com/google-research/big_vision + +Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560 + +Adaptation and PyTorch modifications by Ross Wightman +""" + +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.optim import Optimizer + + +def _get_scalar_dtype(): + """Get the scalar dtype that the optimizer uses for state""" + return torch.float64 + + +def _factored_dims( + shape: Tuple[int, ...], + factored: bool, + min_dim_size_to_factor: int +) -> Optional[tuple[int, int]]: + """Whether to use a factored second moment estimator. + + This function returns a tuple with the two largest axes to reduce over. + If no two dimensions have size >= min_dim_size_to_factor, return None. + + Args: + shape: an input shape + factored: whether to use factored second-moment estimator for > 2d vars. + min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size. + + Returns: + None or a tuple of ints + """ + if not factored or len(shape) < 2: + return None + sorted_dims = sorted(((x, i) for i, x in enumerate(shape))) + if shape[sorted_dims[-2][1]] < min_dim_size_to_factor: + return None + return int(sorted_dims[-2][1]), int(sorted_dims[-1][1]) + + +class AdafactorBigVision(Optimizer): + """ + PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations. + + Adapted from https://github.com/google-research/big_vision by Ross Wightman + """ + + def __init__( + self, + params, + lr: float = 1.0, + min_dim_size_to_factor: int = 32, + decay_rate: float = 0.8, + decay_offset: int = 0, + beta2_cap: float = 0.999, + momentum: Optional[float] = 0.9, + momentum_dtype: Union[str, torch.dtype] = torch.bfloat16, + eps: Optional[float] = None, + weight_decay: float = 0.0, + clipping_threshold: Optional[float] = None, + unscaled_wd: bool = False, + *, + foreach: Optional[bool] = False, + ): + if isinstance(momentum_dtype, str): + if momentum_dtype == 'float16': + momentum_dtype = torch.float16 + elif momentum_dtype == 'bfloat16': + momentum_dtype = torch.bfloat16 + else: + assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported' + momentum_dtype = torch.float32 + # FIXME try to check if momentum dtype is appropriate for device? Torch API not great for this. + + defaults = dict( + lr=lr, + min_dim_size_to_factor=min_dim_size_to_factor, + decay_rate=decay_rate, + decay_offset=decay_offset, + beta2_cap=beta2_cap, + momentum=momentum, + momentum_dtype=momentum_dtype, + eps=eps, + weight_decay=weight_decay, + clipping_threshold=clipping_threshold, + unscaled_wd=unscaled_wd, + foreach=foreach, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('foreach', None) + for p in group['params']: + p_state = self.state.get(p, {}) + if len(p_state) != 0 and not torch.is_tensor(p_state['step']): + p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype()) + + if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']): + # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast + # the momentum to float32 (it's half precision in the state_dict), need to + # look into this further. Better to override _process_value_according_to_param_policy? 
+ p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype']) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avg_sq_rs = [] + exp_avg_sq_cs = [] + exp_avg_sqs = [] + state_steps = [] + exp_avgs = [] # For momentum + + for p in group['params']: + if p.grad is None: + continue + + if p.grad.is_sparse: + raise RuntimeError("Sparse gradients not supported") + + params_with_grad.append(p) + grads.append(p.grad) + + state = self.state[p] + + if len(state) == 0: + # NOTE step on CPU, probably need some more though to make capturable + state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype()) + + shape = p.grad.shape + factored_dims = _factored_dims( + shape, + factored=True, + min_dim_size_to_factor=self.defaults['min_dim_size_to_factor'] + ) + + if factored_dims is not None: + dc, dr = factored_dims + row_shape = list(p.grad.shape) + row_shape[dr] = 1 + col_shape = list(p.grad.shape) + col_shape[dc] = 1 + state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape) + state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape) + else: + state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format) + + if self.defaults['momentum'] is not None: + state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype']) + + state_steps.append(state['step']) + exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None)) + exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None)) + exp_avg_sqs.append(state.get('exp_avg_sq', None)) + exp_avgs.append(state.get('exp_avg', None)) + + if group['foreach']: + func = _multi_tensor_adafactor + else: + func = _single_tensor_adafactor + + func( + params=params_with_grad, + grads=grads, + exp_avg_sq_rs=exp_avg_sq_rs, + exp_avg_sq_cs=exp_avg_sq_cs, + exp_avg_sqs=exp_avg_sqs, + exp_avgs=exp_avgs, + state_steps=state_steps, + beta2_decay=group['decay_rate'], + beta2_cap=group['beta2_cap'], + min_dim_size_to_factor=group['min_dim_size_to_factor'], + eps=group['eps'], + lr=group['lr'], + weight_decay=group['weight_decay'], + momentum=group['momentum'], + momentum_dtype=group['momentum_dtype'], + clipping_threshold=group['clipping_threshold'], + unscaled_wd=group['unscaled_wd'], + ) + + return loss + + +def _single_tensor_adafactor( + params: List[Tensor], + grads: List[Tensor], + exp_avg_sq_rs: List[Optional[Tensor]], + exp_avg_sq_cs: List[Optional[Tensor]], + exp_avg_sqs: List[Optional[Tensor]], + exp_avgs: List[Optional[Tensor]], + state_steps: List[Tensor], + *, + beta2_decay: float, + beta2_cap: float, + min_dim_size_to_factor: int, + eps: float, + lr: float, + weight_decay: float, + momentum: Optional[float], + momentum_dtype: Union[str, torch.dtype], + clipping_threshold: Optional[float], + unscaled_wd: bool, +): + for i, param in enumerate(params): + grad = grads[i] + exp_avg_sq_r = exp_avg_sq_rs[i] + exp_avg_sq_c = exp_avg_sq_cs[i] + exp_avg_sq = exp_avg_sqs[i] + exp_avg = exp_avgs[i] + step_t = state_steps[i] + if eps is None: + # default eps for avoiding div by zero, diff from float type eps + eps = 1e-7 if grad.dtype == torch.float16 else 1e-30 + + # Update step + step_t += 1 + beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay)) + one_minus_beta2_t = 1 - beta2_t + + grad_sqr = torch.square(grad) + eps + # NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach + if exp_avg_sq is None: + # factorized second moment + dc, dr = 
_factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor) + exp_avg_sq_r.lerp_(grad_sqr.mean(dim=dr, keepdim=True), one_minus_beta2_t) + exp_avg_sq_c.lerp_(grad_sqr.mean(dim=dc, keepdim=True), one_minus_beta2_t) + + reduce_dc = dc - 1 if dc > dr else dc + row_col_mean = exp_avg_sq_r.mean(dim=reduce_dc, keepdim=True) + row_factor = (exp_avg_sq_r / row_col_mean).rsqrt() + col_factor = exp_avg_sq_c.rsqrt() + + update = grad * row_factor * col_factor + else: + # non-factorized second moment + assert exp_avg_sq_r is None and exp_avg_sq_c is None + exp_avg_sq.lerp_(grad_sqr, one_minus_beta2_t) + update = grad * exp_avg_sq.rsqrt() + + # Clip by RMS value + if clipping_threshold is not None: + denom = (update.norm(2) / ((update.numel() ** 0.5) / clipping_threshold)).clamp_(max=1.0) + update.div_(denom) + + # Apply momentum (in different dtype) + if momentum is not None and exp_avg is not None: + if momentum_dtype != grad.dtype: + exp_avg.lerp_(update.to(momentum_dtype), 1 - momentum) # ema + update = exp_avg.to(grad.dtype) + else: + exp_avg.lerp_(update, 1 - momentum) # ema + update = exp_avg.clone() + + # Scale by learning rate + update.mul_(lr) + + # Perform weight decay + if weight_decay != 0: + if unscaled_wd: + # match big vision impl, 'fully decoupled' decay w/o LR scaling + param.mul_(1. - weight_decay) + else: + # match typical pytorch behaviour for decoupled decay, eg adamw where wd is scaled by LR + param.mul_(1. - lr * weight_decay) + + # Update parameters + param.add_(update, alpha=-1.0) + + +def _multi_tensor_adafactor( + params: List[Tensor], + grads: List[Tensor], + exp_avg_sq_rs: List[Optional[Tensor]], + exp_avg_sq_cs: List[Optional[Tensor]], + exp_avg_sqs: List[Optional[Tensor]], + exp_avgs: List[Optional[Tensor]], + state_steps: List[Tensor], + *, + beta2_decay: float, + beta2_cap: float, + min_dim_size_to_factor: int, + eps: float, + lr: float, + weight_decay: float, + momentum: Optional[float], + momentum_dtype: Union[str, torch.dtype], + clipping_threshold: Optional[float], + unscaled_wd: bool, +): + # FIXME TODO + assert False, 'multi-tensor fn (foreach=True) not implemented yet' diff --git a/timm/optim/adahessian.py b/timm/optim/adahessian.py index 985c67ca68..9067cc66cf 100644 --- a/timm/optim/adahessian.py +++ b/timm/optim/adahessian.py @@ -23,8 +23,18 @@ class Adahessian(torch.optim.Optimizer): n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1) """ - def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, - hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False): + def __init__( + self, + params, + lr=0.1, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0, + hessian_power=1.0, + update_each=1, + n_samples=1, + avg_conv_kernel=False, + ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: @@ -44,7 +54,13 @@ def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0. 
self.seed = 2147483647 self.generator = torch.Generator().manual_seed(self.seed) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + hessian_power=hessian_power, + ) super(Adahessian, self).__init__(params, defaults) for p in self.get_params(): diff --git a/timm/optim/adamp.py b/timm/optim/adamp.py index ee187633ab..5a9ac3395d 100644 --- a/timm/optim/adamp.py +++ b/timm/optim/adamp.py @@ -41,11 +41,26 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float): class AdamP(Optimizer): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + delta=0.1, + wd_ratio=0.1, + nesterov=False, + ): defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, - delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + delta=delta, + wd_ratio=wd_ratio, + nesterov=nesterov, + ) super(AdamP, self).__init__(params, defaults) @torch.no_grad() diff --git a/timm/optim/adamw.py b/timm/optim/adamw.py index 66478bc6ef..b755a57ca8 100644 --- a/timm/optim/adamw.py +++ b/timm/optim/adamw.py @@ -36,8 +36,16 @@ class AdamW(Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + ): + # NOTE: deprecated in favour of builtin torch.optim.AdamW if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -46,8 +54,13 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ) super(AdamW, self).__init__(params, defaults) def __setstate__(self, state): diff --git a/timm/optim/adopt.py b/timm/optim/adopt.py new file mode 100644 index 0000000000..486cb6263c --- /dev/null +++ b/timm/optim/adopt.py @@ -0,0 +1,497 @@ +""" ADOPT PyTorch Optimizer + +ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate: https://arxiv.org/abs/2411.02853 + +Modified for reduced dependencies on PyTorch internals from original at: https://github.com/iShohei220/adopt + +@inproceedings{taniguchi2024adopt, + author={Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka}, + booktitle = {Advances in Neural Information Processing Systems}, + title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate}, + year = {2024} +} + +""" + +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torch.optim.optimizer import Optimizer + +__all__ = ["Adopt", "adopt"] + +def _view_as_real(params, *state_and_grads): + for i, p in enumerate(params): + if torch.is_complex(p): + params[i] = torch.view_as_real(params[i]) + for s 
in state_and_grads: + s[i] = torch.view_as_real(s[i]) + + +def _get_scalar_dtype(is_fused=None): + if is_fused: + return torch.float32 + return ( + torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32 + ) + + +def _is_compiling(): + return torch.compiler.is_compiling() if hasattr(torch, 'compiler') else False + + +def _get_value(x): + # item is significantly faster than a cpu tensor in eager mode + if not torch.jit.is_scripting() and _is_compiling(): + return x + else: + return x.item() if isinstance(x, torch.Tensor) else x + + +class Adopt(Optimizer): + """ + ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate: https://arxiv.org/abs/2411.02853 + + """ + def __init__( + self, + params, + lr: Union[float, Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.9999), + eps: float = 1e-6, + weight_decay: float = 0.0, + decoupled: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + ): + if isinstance(lr, Tensor): + if foreach and not capturable: + raise ValueError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + if lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + decoupled=decoupled, + maximize=maximize, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("capturable", False) + group.setdefault("differentiable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, + dtype=_get_scalar_dtype(), + device=p.device, + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError( + "ADOPT does not support sparse gradients" + ) + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + # note(crcrpar): [special device hosting for step] + # Deliberately host `step` on CPU if both capturable and fused are off. + # This is because kernel launches are costly on CUDA and XLA. 
+ state["step"] = ( + torch.zeros( + (), + dtype=_get_scalar_dtype(), + device=p.grad.device, + ) + if group["capturable"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p.grad, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p.grad, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if group["differentiable"] and state["step"].requires_grad: + raise RuntimeError( + "`requires_grad` is not supported for `step` in differentiable mode" + ) + + # Foreach without capturable does not support a tensor lr + if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]: + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + state_steps.append(state["step"]) + return has_complex + + #@_use_grad_for_differentiable # FIXME internal context mgr, can't use + @torch.no_grad() + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + beta1, beta2 = group["betas"] + + has_complex = self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + ) + + adopt( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + has_complex=has_complex, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + decoupled=group["decoupled"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def _single_tensor_adopt( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + has_complex: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + decoupled: bool, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if capturable and not _is_compiling(): + from torch.optim.optimizer import _get_capturable_supported_devices + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in 
capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + # update step + step_t += 1 + + if weight_decay != 0: + if decoupled: + param.add_(param, alpha=-lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + if exp_avg is not None: + exp_avg = torch.view_as_real(exp_avg) + if exp_avg_sq is not None: + exp_avg_sq = torch.view_as_real(exp_avg_sq) + param = torch.view_as_real(param) + + step = step_t if capturable or differentiable else _get_value(step_t) + if step == 1: + exp_avg_sq.addcmul_(grad, grad.conj()) + continue + + denom = torch.clamp(exp_avg_sq.sqrt(), eps) + if step == 2: + exp_avg.addcdiv_(grad, denom) + else: + exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1) + + param.add_(exp_avg, alpha=-lr) + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + + +def _multi_tensor_adopt( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + has_complex: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + decoupled: bool, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + if len(params) == 0: + return + + if isinstance(lr, Tensor) and not capturable: + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if capturable and not _is_compiling(): + from torch.optim.optimizer import _get_capturable_supported_devices + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + assert grad_scale is None and found_inf is None + + assert not differentiable, "_foreach ops don't support autograd" + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_state_steps_, + ), _ in grouped_tensors.values(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + # Handle complex parameters + if has_complex: + _view_as_real( + device_params, device_grads, device_exp_avgs, device_exp_avg_sqs + ) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
+ if not _is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if weight_decay != 0: + if decoupled: + torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay) + else: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + if device_state_steps[0] == 1: + torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads) + continue + + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps) + + if device_state_steps[0] == 2: + torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt) + else: + torch._foreach_mul_(device_exp_avgs, beta1) + torch._foreach_addcdiv_( + device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1 + ) + + torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr) + torch._foreach_mul_(device_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2 + ) + + +#@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) # FIXME internal context mgr, can't use +def adopt( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + *, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + decoupled: bool, + eps: float, + maximize: bool, +): + r"""Functional API that performs ADOPT algorithm computation. 
+ + """ + if foreach is None: + foreach = False + + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not _is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adopt + else: + func = _single_tensor_adopt + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + has_complex=has_complex, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + decoupled=decoupled, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py index 12c7c49b8a..9d3a3421df 100644 --- a/timm/optim/lamb.py +++ b/timm/optim/lamb.py @@ -85,14 +85,49 @@ class Lamb(Optimizer): """ def __init__( - self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, - weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False): + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + grad_averaging=True, + max_grad_norm=1.0, + trust_clip=False, + always_adapt=False, + ): defaults = dict( - lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, - grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, - trust_clip=trust_clip, always_adapt=always_adapt) + lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + max_grad_norm=max_grad_norm, + trust_clip=trust_clip, + always_adapt=always_adapt, + ) super().__init__(params, defaults) + def _get_clip_grad_norm(self): + max_grad_norm = self.defaults['max_grad_norm'] + if max_grad_norm is None: + return None + + norms = [] + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.') + norms.append(torch.linalg.vector_norm(grad)) + global_norm = torch.linalg.vector_norm(torch.stack(norms)) + clip_global_norm = (global_norm / max_grad_norm).clamp_(min=1.0) + return clip_global_norm + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. 
@@ -105,26 +140,7 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() - device = self.param_groups[0]['params'][0].device - one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly - global_grad_norm = torch.zeros(1, device=device) - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - grad = p.grad - if grad.is_sparse: - raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') - global_grad_norm.add_(grad.pow(2).sum()) - - global_grad_norm = torch.sqrt(global_grad_norm) - # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes - # scalar types properly https://github.com/pytorch/pytorch/issues/9190 - max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device) - clip_global_grad_norm = torch.where( - global_grad_norm > max_grad_norm, - global_grad_norm / max_grad_norm, - one_tensor) + clip_grad_norm = self._get_clip_grad_norm() # None if disabled for group in self.param_groups: bias_correction = 1 if group['bias_correction'] else 0 @@ -148,7 +164,11 @@ def step(self, closure=None): for p in group['params']: if p.grad is None: continue - grad = p.grad.div_(clip_global_grad_norm) + grad = p.grad + + if clip_grad_norm is not None: + grad.div_(clip_grad_norm) + state = self.state[p] # State initialization @@ -176,15 +196,17 @@ def step(self, closure=None): # excluded from weight decay, unless always_adapt == True, then always enabled. w_norm = p.norm(2.0) g_norm = update.norm(2.0) + trust_ratio = w_norm / g_norm # FIXME nested where required since logical and/or not working in PT XLA + # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero trust_ratio = torch.where( w_norm > 0, - torch.where(g_norm > 0, w_norm / g_norm, one_tensor), - one_tensor, + torch.where(g_norm > 0, trust_ratio, 1.0), + 1.0, ) if group['trust_clip']: # LAMBC trust clipping, upper bound fixed at one - trust_ratio = torch.minimum(trust_ratio, one_tensor) + trust_ratio = torch.clamp(trust_ratio, max=1.0) update.mul_(trust_ratio) p.add_(update, alpha=-group['lr']) diff --git a/timm/optim/lars.py b/timm/optim/lars.py index 38ca9e0b5c..d49efc6d0e 100644 --- a/timm/optim/lars.py +++ b/timm/optim/lars.py @@ -84,9 +84,6 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() - device = self.param_groups[0]['params'][0].device - one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly - for group in self.param_groups: weight_decay = group['weight_decay'] momentum = group['momentum'] @@ -107,13 +104,14 @@ def step(self, closure=None): g_norm = grad.norm(2.0) trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps) # FIXME nested where required since logical and/or not working in PT XLA + # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero trust_ratio = torch.where( w_norm > 0, - torch.where(g_norm > 0, trust_ratio, one_tensor), - one_tensor, + torch.where(g_norm > 0, trust_ratio, 1.0), + 1.0, ) if group['trust_clip']: - trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor) + trust_ratio = torch.clamp(trust_ratio / group['lr'], max=1.0) grad.add_(p, alpha=weight_decay) grad.mul_(trust_ratio) diff --git a/timm/optim/lion.py b/timm/optim/lion.py index 4d8086424d..3bcb273cac 100644 --- a/timm/optim/lion.py +++ b/timm/optim/lion.py @@ -137,7 +137,7 @@ def lion( """ if foreach is None: # Placeholder 
for more complex foreach logic to be added when value is not set - foreach = False + foreach = True if foreach and torch.jit.is_scripting(): raise RuntimeError('torch.jit.script not supported with foreach optimizers') diff --git a/timm/optim/madgrad.py b/timm/optim/madgrad.py index a76713bf27..8e449dce3d 100644 --- a/timm/optim/madgrad.py +++ b/timm/optim/madgrad.py @@ -71,7 +71,12 @@ def __init__( raise ValueError(f"Eps must be non-negative") defaults = dict( - lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay) + lr=lr, + eps=eps, + momentum=momentum, + weight_decay=weight_decay, + decoupled_decay=decoupled_decay, + ) super().__init__(params, defaults) @property diff --git a/timm/optim/nadam.py b/timm/optim/nadam.py index 4e911420ef..892262cfac 100644 --- a/timm/optim/nadam.py +++ b/timm/optim/nadam.py @@ -27,8 +27,15 @@ class Nadam(Optimizer): NOTE: Has potential issues but does work well on some problems. """ - def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, schedule_decay=4e-3): + def __init__( + self, + params, + lr=2e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + schedule_decay=4e-3, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) defaults = dict( diff --git a/timm/optim/nvnovograd.py b/timm/optim/nvnovograd.py index fda3f4a620..068e5aa2c1 100644 --- a/timm/optim/nvnovograd.py +++ b/timm/optim/nvnovograd.py @@ -29,8 +29,16 @@ class NvNovoGrad(Optimizer): (default: False) """ - def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, - weight_decay=0, grad_averaging=False, amsgrad=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.95, 0.98), + eps=1e-8, + weight_decay=0, + grad_averaging=False, + amsgrad=False, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -39,10 +47,14 @@ def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, - grad_averaging=grad_averaging, - amsgrad=amsgrad) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + amsgrad=amsgrad, + ) super(NvNovoGrad, self).__init__(params, defaults) diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 8187b55a1b..a4227a9848 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -1,423 +1,7 @@ -""" Optimizer Factory w/ Custom Weight Decay -Hacked together by / Copyright 2021 Ross Wightman -""" -import logging -from itertools import islice -from typing import Optional, Callable, Tuple +# lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/ -import torch -import torch.nn as nn -import torch.optim as optim +from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs +from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group -from timm.models import group_parameters - -from .adabelief import AdaBelief -from .adafactor import Adafactor -from .adahessian import Adahessian -from .adamp import AdamP -from .adan import Adan -from .lamb import Lamb -from .lars import Lars -from .lion import Lion -from .lookahead import Lookahead -from .madgrad 
import MADGRAD -from .nadam import Nadam -from .nadamw import NAdamW -from .nvnovograd import NvNovoGrad -from .radam import RAdam -from .rmsprop_tf import RMSpropTF -from .sgdp import SGDP -from .sgdw import SGDW - - -_logger = logging.getLogger(__name__) - - -# optimizers to default to multi-tensor -_DEFAULT_FOREACH = { - 'lion', -} - - -def param_groups_weight_decay( - model: nn.Module, - weight_decay=1e-5, - no_weight_decay_list=() -): - no_weight_decay_list = set(no_weight_decay_list) - decay = [] - no_decay = [] - for name, param in model.named_parameters(): - if not param.requires_grad: - continue - - if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: - no_decay.append(param) - else: - decay.append(param) - - return [ - {'params': no_decay, 'weight_decay': 0.}, - {'params': decay, 'weight_decay': weight_decay}] - - -def _group(it, size): - it = iter(it) - return iter(lambda: tuple(islice(it, size)), ()) - - -def _layer_map(model, layers_per_group=12, num_groups=None): - def _in_head(n, hp): - if not hp: - return True - elif isinstance(hp, (tuple, list)): - return any([n.startswith(hpi) for hpi in hp]) - else: - return n.startswith(hp) - - head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None) - names_trunk = [] - names_head = [] - for n, _ in model.named_parameters(): - names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n) - - # group non-head layers - num_trunk_layers = len(names_trunk) - if num_groups is not None: - layers_per_group = -(num_trunk_layers // -num_groups) - names_trunk = list(_group(names_trunk, layers_per_group)) - - num_trunk_groups = len(names_trunk) - layer_map = {n: i for i, l in enumerate(names_trunk) for n in l} - layer_map.update({n: num_trunk_groups for n in names_head}) - return layer_map - - -def param_groups_layer_decay( - model: nn.Module, - weight_decay: float = 0.05, - no_weight_decay_list: Tuple[str] = (), - layer_decay: float = .75, - end_layer_decay: Optional[float] = None, - verbose: bool = False, -): - """ - Parameter groups for layer-wise lr decay & weight decay - Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 - """ - no_weight_decay_list = set(no_weight_decay_list) - param_group_names = {} # NOTE for debugging - param_groups = {} - - if hasattr(model, 'group_matcher'): - # FIXME interface needs more work - layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True) - else: - # fallback - layer_map = _layer_map(model) - num_layers = max(layer_map.values()) + 1 - layer_max = num_layers - 1 - layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers)) - - for name, param in model.named_parameters(): - if not param.requires_grad: - continue - - # no decay: all 1D parameters and model specific ones - if param.ndim == 1 or name in no_weight_decay_list: - g_decay = "no_decay" - this_decay = 0. 
- else: - g_decay = "decay" - this_decay = weight_decay - - layer_id = layer_map.get(name, layer_max) - group_name = "layer_%d_%s" % (layer_id, g_decay) - - if group_name not in param_groups: - this_scale = layer_scales[layer_id] - param_group_names[group_name] = { - "lr_scale": this_scale, - "weight_decay": this_decay, - "param_names": [], - } - param_groups[group_name] = { - "lr_scale": this_scale, - "weight_decay": this_decay, - "params": [], - } - - param_group_names[group_name]["param_names"].append(name) - param_groups[group_name]["params"].append(param) - - if verbose: - import json - _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) - - return list(param_groups.values()) - - -def optimizer_kwargs(cfg): - """ cfg/argparse to kwargs helper - Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn. - """ - kwargs = dict( - opt=cfg.opt, - lr=cfg.lr, - weight_decay=cfg.weight_decay, - momentum=cfg.momentum, - ) - if getattr(cfg, 'opt_eps', None) is not None: - kwargs['eps'] = cfg.opt_eps - if getattr(cfg, 'opt_betas', None) is not None: - kwargs['betas'] = cfg.opt_betas - if getattr(cfg, 'layer_decay', None) is not None: - kwargs['layer_decay'] = cfg.layer_decay - if getattr(cfg, 'opt_args', None) is not None: - kwargs.update(cfg.opt_args) - if getattr(cfg, 'opt_foreach', None) is not None: - kwargs['foreach'] = cfg.opt_foreach - return kwargs - - -def create_optimizer(args, model, filter_bias_and_bn=True): - """ Legacy optimizer factory for backwards compatibility. - NOTE: Use create_optimizer_v2 for new code. - """ - return create_optimizer_v2( - model, - **optimizer_kwargs(cfg=args), - filter_bias_and_bn=filter_bias_and_bn, - ) - - -def create_optimizer_v2( - model_or_params, - opt: str = 'sgd', - lr: Optional[float] = None, - weight_decay: float = 0., - momentum: float = 0.9, - foreach: Optional[bool] = None, - filter_bias_and_bn: bool = True, - layer_decay: Optional[float] = None, - param_group_fn: Optional[Callable] = None, - **kwargs, -): - """ Create an optimizer. - - TODO currently the model is passed in and all parameters are selected for optimization. - For more general use an interface that allows selection of parameters to optimize and lr groups, one of: - * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion - * expose the parameters interface and leave it up to caller - - Args: - model_or_params (nn.Module): model containing parameters to optimize - opt: name of optimizer to create - lr: initial learning rate - weight_decay: weight decay to apply in optimizer - momentum: momentum for momentum based optimizers (others may use betas via kwargs) - foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None - filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay - **kwargs: extra optimizer specific kwargs to pass through - - Returns: - Optimizer - """ - if isinstance(model_or_params, nn.Module): - # a model was passed in, extract parameters and add weight decays to appropriate layers - no_weight_decay = {} - if hasattr(model_or_params, 'no_weight_decay'): - no_weight_decay = model_or_params.no_weight_decay() - - if param_group_fn: - parameters = param_group_fn(model_or_params) - elif layer_decay is not None: - parameters = param_groups_layer_decay( - model_or_params, - weight_decay=weight_decay, - layer_decay=layer_decay, - no_weight_decay_list=no_weight_decay, - ) - weight_decay = 0. 
- elif weight_decay and filter_bias_and_bn: - parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay) - weight_decay = 0. - else: - parameters = model_or_params.parameters() - else: - # iterable of parameters or param groups passed in - parameters = model_or_params - - opt_lower = opt.lower() - opt_split = opt_lower.split('_') - opt_lower = opt_split[-1] - - if opt_lower.startswith('fused'): - try: - from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD - has_apex = True - except ImportError: - has_apex = False - assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' - - if opt_lower.startswith('bnb'): - try: - import bitsandbytes as bnb - has_bnb = True - except ImportError: - has_bnb = False - assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers' - - opt_args = dict(weight_decay=weight_decay, **kwargs) - - if lr is not None: - opt_args.setdefault('lr', lr) - - if foreach is None: - if opt in _DEFAULT_FOREACH: - opt_args.setdefault('foreach', True) - else: - opt_args['foreach'] = foreach - - # basic SGD & related - if opt_lower == 'sgd' or opt_lower == 'nesterov': - # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons - opt_args.pop('eps', None) - optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args) - elif opt_lower == 'momentum': - opt_args.pop('eps', None) - optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args) - elif opt_lower == 'sgdp': - optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args) - elif opt_lower == 'sgdw' or opt_lower == 'nesterovw': - # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons - opt_args.pop('eps', None) - optimizer = SGDW(parameters, momentum=momentum, nesterov=True, **opt_args) - elif opt_lower == 'momentumw': - opt_args.pop('eps', None) - optimizer = SGDW(parameters, momentum=momentum, nesterov=False, **opt_args) - - # adaptive - elif opt_lower == 'adam': - optimizer = optim.Adam(parameters, **opt_args) - elif opt_lower == 'adamw': - optimizer = optim.AdamW(parameters, **opt_args) - elif opt_lower == 'adamp': - optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) - elif opt_lower == 'nadam': - try: - # NOTE PyTorch >= 1.10 should have native NAdam - optimizer = optim.Nadam(parameters, **opt_args) - except AttributeError: - optimizer = Nadam(parameters, **opt_args) - elif opt_lower == 'nadamw': - optimizer = NAdamW(parameters, **opt_args) - elif opt_lower == 'radam': - optimizer = RAdam(parameters, **opt_args) - elif opt_lower == 'adamax': - optimizer = optim.Adamax(parameters, **opt_args) - elif opt_lower == 'adabelief': - optimizer = AdaBelief(parameters, rectify=False, **opt_args) - elif opt_lower == 'radabelief': - optimizer = AdaBelief(parameters, rectify=True, **opt_args) - elif opt_lower == 'adadelta': - optimizer = optim.Adadelta(parameters, **opt_args) - elif opt_lower == 'adagrad': - opt_args.setdefault('eps', 1e-8) - optimizer = optim.Adagrad(parameters, **opt_args) - elif opt_lower == 'adafactor': - optimizer = Adafactor(parameters, **opt_args) - elif opt_lower == 'adanp': - optimizer = Adan(parameters, no_prox=False, **opt_args) - elif opt_lower == 'adanw': - optimizer = Adan(parameters, no_prox=True, **opt_args) - elif opt_lower == 'lamb': - optimizer = Lamb(parameters, **opt_args) - elif opt_lower == 'lambc': - optimizer = Lamb(parameters, trust_clip=True, 
**opt_args) - elif opt_lower == 'larc': - optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args) - elif opt_lower == 'lars': - optimizer = Lars(parameters, momentum=momentum, **opt_args) - elif opt_lower == 'nlarc': - optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args) - elif opt_lower == 'nlars': - optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args) - elif opt_lower == 'madgrad': - optimizer = MADGRAD(parameters, momentum=momentum, **opt_args) - elif opt_lower == 'madgradw': - optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args) - elif opt_lower == 'novograd' or opt_lower == 'nvnovograd': - optimizer = NvNovoGrad(parameters, **opt_args) - elif opt_lower == 'rmsprop': - optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args) - elif opt_lower == 'rmsproptf': - optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args) - elif opt_lower == 'lion': - opt_args.pop('eps', None) - optimizer = Lion(parameters, **opt_args) - - # second order - elif opt_lower == 'adahessian': - optimizer = Adahessian(parameters, **opt_args) - - # NVIDIA fused optimizers, require APEX to be installed - elif opt_lower == 'fusedsgd': - opt_args.pop('eps', None) - optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args) - elif opt_lower == 'fusedmomentum': - opt_args.pop('eps', None) - optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args) - elif opt_lower == 'fusedadam': - optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) - elif opt_lower == 'fusedadamw': - optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) - elif opt_lower == 'fusedlamb': - optimizer = FusedLAMB(parameters, **opt_args) - elif opt_lower == 'fusednovograd': - opt_args.setdefault('betas', (0.95, 0.98)) - optimizer = FusedNovoGrad(parameters, **opt_args) - - # bitsandbytes optimizers, require bitsandbytes to be installed - elif opt_lower == 'bnbsgd': - opt_args.pop('eps', None) - optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args) - elif opt_lower == 'bnbsgd8bit': - opt_args.pop('eps', None) - optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args) - elif opt_lower == 'bnbmomentum': - opt_args.pop('eps', None) - optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args) - elif opt_lower == 'bnbmomentum8bit': - opt_args.pop('eps', None) - optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args) - elif opt_lower == 'bnbadam': - optimizer = bnb.optim.Adam(parameters, **opt_args) - elif opt_lower == 'bnbadam8bit': - optimizer = bnb.optim.Adam8bit(parameters, **opt_args) - elif opt_lower == 'bnbadamw': - optimizer = bnb.optim.AdamW(parameters, **opt_args) - elif opt_lower == 'bnbadamw8bit': - optimizer = bnb.optim.AdamW8bit(parameters, **opt_args) - elif opt_lower == 'bnblamb': - optimizer = bnb.optim.LAMB(parameters, **opt_args) - elif opt_lower == 'bnblamb8bit': - optimizer = bnb.optim.LAMB8bit(parameters, **opt_args) - elif opt_lower == 'bnblars': - optimizer = bnb.optim.LARS(parameters, **opt_args) - elif opt_lower == 'bnblarsb8bit': - optimizer = bnb.optim.LAMB8bit(parameters, **opt_args) - elif opt_lower == 'bnblion': - optimizer = bnb.optim.Lion(parameters, **opt_args) - elif opt_lower == 'bnblion8bit': - optimizer = bnb.optim.Lion8bit(parameters, **opt_args) - - else: - assert False and "Invalid optimizer" - raise ValueError - - if 
len(opt_split) > 1: - if opt_split[0] == 'lookahead': - optimizer = Lookahead(optimizer) - - return optimizer +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning) diff --git a/timm/optim/radam.py b/timm/optim/radam.py index eb8d22e06c..d6f8d30a67 100644 --- a/timm/optim/radam.py +++ b/timm/optim/radam.py @@ -9,10 +9,21 @@ class RAdam(Optimizer): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + ): defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, - buffer=[[None, None, None] for _ in range(10)]) + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + buffer=[[None, None, None] for _ in range(10)] + ) super(RAdam, self).__init__(params, defaults) def __setstate__(self, state): diff --git a/timm/optim/rmsprop_tf.py b/timm/optim/rmsprop_tf.py index 0817887db3..8511b3b482 100644 --- a/timm/optim/rmsprop_tf.py +++ b/timm/optim/rmsprop_tf.py @@ -45,8 +45,18 @@ class RMSpropTF(Optimizer): """ - def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, - decoupled_decay=False, lr_in_momentum=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.9, + eps=1e-10, + weight_decay=0, + momentum=0., + centered=False, + decoupled_decay=False, + lr_in_momentum=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -59,8 +69,15 @@ def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, moment raise ValueError("Invalid alpha value: {}".format(alpha)) defaults = dict( - lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, - decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) + lr=lr, + momentum=momentum, + alpha=alpha, + eps=eps, + centered=centered, + weight_decay=weight_decay, + decoupled_decay=decoupled_decay, + lr_in_momentum=lr_in_momentum, + ) super(RMSpropTF, self).__init__(params, defaults) def __setstate__(self, state): diff --git a/timm/optim/sgdp.py b/timm/optim/sgdp.py index baf05fa55c..87b89f6f0b 100644 --- a/timm/optim/sgdp.py +++ b/timm/optim/sgdp.py @@ -17,11 +17,28 @@ class SGDP(Optimizer): - def __init__(self, params, lr=required, momentum=0, dampening=0, - weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): + def __init__( + self, + params, + lr=required, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + eps=1e-8, + delta=0.1, + wd_ratio=0.1 + ): defaults = dict( - lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, - nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + eps=eps, + delta=delta, + wd_ratio=wd_ratio, + ) super(SGDP, self).__init__(params, defaults) @torch.no_grad() diff --git a/timm/optim/sgdw.py b/timm/optim/sgdw.py index b3d2c12f03..c5b44063d6 100644 --- a/timm/optim/sgdw.py +++ b/timm/optim/sgdw.py @@ -35,10 +35,15 @@ def __init__( raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( - lr=lr, momentum=momentum, dampening=dampening, - weight_decay=weight_decay, nesterov=nesterov, - maximize=maximize, foreach=foreach, - differentiable=differentiable) + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + maximize=maximize, + 
foreach=foreach, + differentiable=differentiable, + ) if nesterov and (momentum <= 0 or dampening != 0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") super().__init__(params, defaults) diff --git a/train.py b/train.py index 5179e31d2e..ff11622bb6 100755 --- a/train.py +++ b/train.py @@ -15,6 +15,7 @@ Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ import argparse +import copy import importlib import json import logging @@ -554,6 +555,13 @@ def main(): **optimizer_kwargs(cfg=args), **args.opt_kwargs, ) + if utils.is_primary(args): + defaults = copy.deepcopy(optimizer.defaults) + defaults['weight_decay'] = args.weight_decay # this isn't stored in optimizer.defaults + defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()]) + logging.info( + f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}' + ) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing
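The row/column selection that both the updated `adafactor.py` and the new `adafactor_bv.py` rely on reduces to picking the two largest axes of a parameter, provided the second-largest clears `min_dim_size_to_factor`. Below is a standalone, hedged illustration of the `_factored_dims` helper added in `timm/optim/adafactor_bv.py` above; the shapes fed to it are illustrative, not taken from this diff.

```python
# Standalone copy of the `_factored_dims` helper from timm/optim/adafactor_bv.py,
# shown on a few representative parameter shapes (shapes are illustrative).
from typing import Optional, Tuple


def factored_dims(shape: Tuple[int, ...], factored: bool, min_dim_size_to_factor: int) -> Optional[Tuple[int, int]]:
    # pick the two largest axes; give up if the second largest is still too small to factor
    if not factored or len(shape) < 2:
        return None
    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
        return None
    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])


print(factored_dims((256, 128, 3, 3), True, 32))  # (1, 0): factor over the in/out channel dims of a conv weight
print(factored_dims((128, 3, 3, 3), True, 32))    # None: second-largest dim (3) too small, keep the full second moment
print(factored_dims((768,), True, 32))            # None: 1d params are never factored
```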
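The single-tensor path of the new ADOPT optimizer orders its updates differently from Adam: the gradient is normalized by the second moment from the previous step before the momentum update, and the second moment itself is refreshed only afterwards. The following is a condensed, hedged restatement of that loop from `adopt.py` on a toy parameter; the quadratic loss, hyper-parameters, and iteration count are illustrative.

```python
# Condensed restatement of the single-tensor ADOPT update from adopt.py above, applied to a toy
# 2-element parameter with a synthetic quadratic loss (values and loop length are illustrative).
import torch

param = torch.tensor([1.0, -2.0])
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)
beta1, beta2, lr, eps = 0.9, 0.9999, 1e-3, 1e-6

for step in range(1, 101):
    grad = 2 * param  # gradient of sum(param ** 2)
    if step == 1:
        # first step only seeds the second moment; no parameter update
        exp_avg_sq.addcmul_(grad, grad)
        continue
    denom = torch.clamp(exp_avg_sq.sqrt(), eps)
    if step == 2:
        exp_avg.addcdiv_(grad, denom)                     # seed momentum with the normalized gradient
    else:
        exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
    param.add_(exp_avg, alpha=-lr)                        # parameter update uses the current momentum
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment refreshed only after the update
```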
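The factory and param-group helpers removed from `optim_factory.py` above are re-exported from `timm.optim` (their implementations moving to `_optim_factory.py` / `_param_groups.py`), so existing call sites keep working while importing the old module path only raises a `FutureWarning`. A minimal usage sketch under that assumption follows; the model and hyper-parameter values are illustrative, not part of this diff.

```python
import torch.nn as nn
from timm.optim import create_optimizer_v2, param_groups_weight_decay

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))  # stand-in model

# Name-based creation is unchanged: 1d params are filtered from weight decay by default.
optimizer = create_optimizer_v2(model, opt='lamb', lr=1e-3, weight_decay=0.02)
print(f'Created {type(optimizer).__name__} optimizer: {optimizer.defaults}')  # mirrors the new train.py logging

# Param groups can also be built explicitly and passed in place of the model.
groups = param_groups_weight_decay(model, weight_decay=0.02)
optimizer = create_optimizer_v2(groups, opt='lion', lr=1e-4)

# The legacy module path still imports, but now warns:
#   FutureWarning: Importing from timm.optim.optim_factory is deprecated, please import via timm.optim
import timm.optim.optim_factory as optim_factory  # noqa: F401
```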