Merged
1 change: 1 addition & 0 deletions tensorflow_addons/optimizers/__init__.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Additional optimizers that conform to Keras API."""

from tensorflow_addons.optimizers.constants import KerasLegacyOptimizer
from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper
from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient
from tensorflow_addons.optimizers.cyclical_learning_rate import CyclicalLearningRate
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/adabelief.py
@@ -17,11 +17,12 @@
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typing import Union, Callable, Dict


@tf.keras.utils.register_keras_serializable(package="Addons")
class AdaBelief(tf.keras.optimizers.Optimizer):
class AdaBelief(KerasLegacyOptimizer):
"""Variant of the Adam optimizer.

It achieves fast convergence as Adam and generalization comparable to SGD.
11 changes: 7 additions & 4 deletions tensorflow_addons/optimizers/average_wrapper.py
@@ -17,12 +17,12 @@
import warnings

import tensorflow as tf
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils import types

from typeguard import typechecked


class AveragedOptimizerWrapper(tf.keras.optimizers.Optimizer, metaclass=abc.ABCMeta):
class AveragedOptimizerWrapper(KerasLegacyOptimizer, metaclass=abc.ABCMeta):
@typechecked
def __init__(
self, optimizer: types.Optimizer, name: str = "AverageOptimizer", **kwargs
@@ -32,9 +32,12 @@ def __init__(
if isinstance(optimizer, str):
optimizer = tf.keras.optimizers.get(optimizer)

if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
if not isinstance(
optimizer, (tf.keras.optimizers.Optimizer, KerasLegacyOptimizer)
):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.Optimizer"
"optimizer is not an object of tf.keras.optimizers.Optimizer "
"or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.9.0)."
)

self._optimizer = optimizer
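For reference, a minimal usage sketch of the widened type check (assumes TF >= 2.9.0, where the legacy namespace exists; the hyperparameter values are illustrative):

```python
import tensorflow as tf
import tensorflow_addons as tfa

# A legacy optimizer instance now passes the isinstance check in
# AveragedOptimizerWrapper subclasses such as MovingAverage.
sgd = tf.keras.optimizers.legacy.SGD(learning_rate=0.01)
opt = tfa.optimizers.MovingAverage(sgd)
```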
4 changes: 3 additions & 1 deletion tensorflow_addons/optimizers/cocob.py
@@ -17,9 +17,11 @@
from typeguard import typechecked
import tensorflow as tf

from tensorflow_addons.optimizers import KerasLegacyOptimizer


@tf.keras.utils.register_keras_serializable(package="Addons")
class COCOB(tf.keras.optimizers.Optimizer):
class COCOB(KerasLegacyOptimizer):
"""Optimizer that implements COCOB Backprop Algorithm

Reference:
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/conditional_gradient.py
@@ -15,14 +15,15 @@
"""Conditional Gradient optimizer."""

import tensorflow as tf
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils.types import FloatTensorLike

from typeguard import typechecked
from typing import Union, Callable


@tf.keras.utils.register_keras_serializable(package="Addons")
class ConditionalGradient(tf.keras.optimizers.Optimizer):
class ConditionalGradient(KerasLegacyOptimizer):
"""Optimizer that implements the Conditional Gradient optimization.

This optimizer helps handle constraints well.
21 changes: 21 additions & 0 deletions tensorflow_addons/optimizers/constants.py
@@ -0,0 +1,21 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import importlib
import tensorflow as tf
Contributor: missing copyright header
Contributor Author: done

if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
KerasLegacyOptimizer = tf.keras.optimizers.legacy.Optimizer
else:
KerasLegacyOptimizer = tf.keras.optimizers.Optimizer
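For context, this is how downstream addon code consumes the alias; `MyOptimizer` below is a hypothetical sketch, not part of this PR:

```python
import tensorflow as tf
from tensorflow_addons.optimizers import KerasLegacyOptimizer

# KerasLegacyOptimizer resolves to tf.keras.optimizers.legacy.Optimizer on
# TF >= 2.9.0 and falls back to tf.keras.optimizers.Optimizer otherwise.
class MyOptimizer(KerasLegacyOptimizer):
    def __init__(self, learning_rate=0.01, name="MyOptimizer", **kwargs):
        super().__init__(name=name, **kwargs)
        self._set_hyper("learning_rate", learning_rate)
```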
8 changes: 4 additions & 4 deletions tensorflow_addons/optimizers/cyclical_learning_rate.py
@@ -58,7 +58,7 @@ def __init__(
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
`tf.keras.optimizers.legacy.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -146,7 +146,7 @@ def __init__(
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
`tf.keras.optimizers.legacy.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -215,7 +215,7 @@ def __init__(
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
`tf.keras.optimizers.legacy.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -286,7 +286,7 @@ def __init__(
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
`tf.keras.optimizers.legacy.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
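A short sketch of the updated docstring guidance (schedule values are illustrative; assumes TF >= 2.9.0 for the legacy class):

```python
import tensorflow as tf
import tensorflow_addons as tfa

clr = tfa.optimizers.TriangularCyclicalLearningRate(
    initial_learning_rate=1e-4,
    maximal_learning_rate=1e-2,
    step_size=2000,
)
# The schedule goes wherever a fixed learning rate would.
opt = tf.keras.optimizers.legacy.SGD(learning_rate=clr)
```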
8 changes: 5 additions & 3 deletions tensorflow_addons/optimizers/discriminative_layer_training.py
@@ -17,11 +17,13 @@
from typing import List, Union

import tensorflow as tf

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class MultiOptimizer(tf.keras.optimizers.Optimizer):
class MultiOptimizer(KerasLegacyOptimizer):
"""Multi Optimizer Wrapper for Discriminative Layer Training.

Creates a wrapper around a set of instantiated optimizer layer pairs.
@@ -30,7 +32,7 @@ class MultiOptimizer(tf.keras.optimizers.Optimizer):
Each optimizer will optimize only the weights associated with its paired layer.
This can be used to implement discriminative layer training by assigning
different learning rates to each optimizer layer pair.
`(tf.keras.optimizers.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
`(tf.keras.optimizers.legacy.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
Please note that the layers must be instantiated before instantiating the optimizer.

Args:
@@ -130,7 +132,7 @@ def get_config(self):
@classmethod
def create_optimizer_spec(
cls,
optimizer: tf.keras.optimizers.Optimizer,
optimizer: KerasLegacyOptimizer,
layers_or_model: Union[
tf.keras.Model,
tf.keras.Sequential,
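A hedged sketch of the optimizer-layer pairing the docstring describes (model architecture and learning rates are illustrative only):

```python
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, input_shape=(8,)),
    tf.keras.layers.Dense(1),
])
# Each optimizer updates only the weights of its paired layer(s).
optimizers_and_layers = [
    (tf.keras.optimizers.legacy.Adam(1e-4), model.layers[0]),  # slow base
    (tf.keras.optimizers.legacy.Adam(1e-2), model.layers[1]),  # fast head
]
opt = tfa.optimizers.MultiOptimizer(optimizers_and_layers)
model.compile(optimizer=opt, loss="mse")
```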
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/lamb.py
@@ -24,12 +24,13 @@
from typeguard import typechecked

import tensorflow as tf
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes


@tf.keras.utils.register_keras_serializable(package="Addons")
class LAMB(tf.keras.optimizers.Optimizer):
class LAMB(KerasLegacyOptimizer):
"""Optimizer that implements the Layer-wise Adaptive Moments (LAMB).

See paper [Large Batch Optimization for Deep Learning: Training BERT
9 changes: 8 additions & 1 deletion tensorflow_addons/optimizers/lazy_adam.py
@@ -20,15 +20,22 @@
original Adam algorithm, and may lead to different empirical results.
"""

import importlib
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from typeguard import typechecked
from typing import Union, Callable


if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
adam_optimizer_class = tf.keras.optimizers.legacy.Adam
else:
adam_optimizer_class = tf.keras.optimizers.Adam


@tf.keras.utils.register_keras_serializable(package="Addons")
class LazyAdam(tf.keras.optimizers.Adam):
class LazyAdam(adam_optimizer_class):
"""Variant of the Adam optimizer that handles sparse updates more
efficiently.

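Since only the base class changes, LazyAdam remains a drop-in Adam replacement; a minimal sketch:

```python
import tensorflow_addons as tfa

# Constructor arguments are unchanged; only the sparse-update behavior
# differs from the (now legacy) Adam base class.
opt = tfa.optimizers.LazyAdam(learning_rate=0.001)
```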
10 changes: 7 additions & 3 deletions tensorflow_addons/optimizers/lookahead.py
@@ -16,11 +16,12 @@
import tensorflow as tf
from tensorflow_addons.utils import types
Contributor: isn't it required to update types.Optimizer?
Contributor Author: good catch! done

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class Lookahead(tf.keras.optimizers.Optimizer):
class Lookahead(KerasLegacyOptimizer):
"""This class allows to extend optimizers with the lookahead mechanism.

The mechanism is proposed by Michael R. Zhang et.al in the paper
@@ -71,9 +72,12 @@ def __init__(

if isinstance(optimizer, str):
optimizer = tf.keras.optimizers.get(optimizer)
if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
if not isinstance(
optimizer, (tf.keras.optimizers.Optimizer, KerasLegacyOptimizer)
):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.Optimizer"
"optimizer is not an object of tf.keras.optimizers.Optimizer "
"or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.9.0)."
)

self._optimizer = optimizer
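A minimal sketch of wrapping a legacy optimizer (assumes TF >= 2.9.0; `sync_period` and `slow_step_size` are shown at their documented defaults):

```python
import tensorflow as tf
import tensorflow_addons as tfa

# The inner optimizer may be a legacy instance or a string name.
inner = tf.keras.optimizers.legacy.Adam(learning_rate=1e-3)
opt = tfa.optimizers.Lookahead(inner, sync_period=6, slow_step_size=0.5)
```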
2 changes: 1 addition & 1 deletion tensorflow_addons/optimizers/moving_average.py
@@ -55,7 +55,7 @@ def __init__(
r"""Construct a new MovingAverage optimizer.

Args:
optimizer: str or `tf.keras.optimizers.Optimizer` that will be
optimizer: str or `tf.keras.optimizers.legacy.Optimizer` that will be
used to compute and apply gradients.
average_decay: float. Decay to use to maintain the moving averages
of trained variables.
4 changes: 2 additions & 2 deletions tensorflow_addons/optimizers/novograd.py
@@ -16,13 +16,13 @@

import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typing import Union, Callable
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class NovoGrad(tf.keras.optimizers.Optimizer):
class NovoGrad(KerasLegacyOptimizer):
"""Optimizer that implements NovoGrad.

The NovoGrad Optimizer was first proposed in [Stochastic Gradient
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/proximal_adagrad.py
@@ -19,11 +19,12 @@
import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils.types import FloatTensorLike


@tf.keras.utils.register_keras_serializable(package="Addons")
class ProximalAdagrad(tf.keras.optimizers.Optimizer):
class ProximalAdagrad(KerasLegacyOptimizer):
"""Optimizer that implements the Proximal Adagrad algorithm.

References:
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/rectified_adam.py
@@ -16,12 +16,13 @@
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typing import Union, Callable, Dict
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class RectifiedAdam(tf.keras.optimizers.Optimizer):
class RectifiedAdam(KerasLegacyOptimizer):
"""Variant of the Adam optimizer whose adaptive learning rate is rectified
so as to have a consistent variance.

7 changes: 3 additions & 4 deletions tensorflow_addons/optimizers/tests/standard_test.py
@@ -18,6 +18,7 @@
import tensorflow as tf

from tensorflow_addons import optimizers
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils.test_utils import discover_classes

class_exceptions = [
@@ -29,12 +30,10 @@
"ConditionalGradient", # is wrapper
"Lookahead", # is wrapper
"MovingAverage", # is wrapper
"KerasLegacyOptimizer", # is a constantc
]


classes_to_test = discover_classes(
optimizers, tf.keras.optimizers.Optimizer, class_exceptions
)
classes_to_test = discover_classes(optimizers, KerasLegacyOptimizer, class_exceptions)


@pytest.mark.parametrize("optimizer", classes_to_test)
tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Tests for optimizers with weight decay."""

import importlib
import numpy as np
import pytest
import tensorflow as tf
@@ -401,13 +402,17 @@ def test_var_list_with_exclude_list_sgdw(dtype):
)


if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
optimizer_class = tf.keras.optimizers.legacy.SGD
else:
optimizer_class = tf.keras.optimizers.SGD


@pytest.mark.parametrize(
"optimizer",
[
weight_decay_optimizers.SGDW,
weight_decay_optimizers.extend_with_decoupled_weight_decay(
tf.keras.optimizers.SGD
),
weight_decay_optimizers.extend_with_decoupled_weight_decay(optimizer_class),
],
)
@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
13 changes: 11 additions & 2 deletions tensorflow_addons/optimizers/weight_decay_optimizers.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Base class to make optimizers weight decay ready."""

import importlib
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes
@@ -261,10 +262,18 @@ def _do_use_weight_decay(self, var):
return var.ref() in self._decay_var_list


if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
keras_legacy_optimizer = Union[
tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer
]
else:
keras_legacy_optimizer = tf.keras.optimizers.Optimizer


@typechecked
def extend_with_decoupled_weight_decay(
base_optimizer: Type[tf.keras.optimizers.Optimizer],
) -> Type[tf.keras.optimizers.Optimizer]:
base_optimizer: Type[keras_legacy_optimizer],
) -> Type[keras_legacy_optimizer]:
"""Factory function returning an optimizer class with decoupled weight
decay.

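A usage sketch of the factory with the widened annotation (assumes TF >= 2.9.0; the weight-decay value is illustrative):

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Build an SGD variant with decoupled weight decay from the legacy class.
SGDW = tfa.optimizers.extend_with_decoupled_weight_decay(
    tf.keras.optimizers.legacy.SGD
)
opt = SGDW(weight_decay=1e-4, learning_rate=0.01)
```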
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/yogi.py
@@ -25,6 +25,7 @@
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typeguard import typechecked
from typing import Union, Callable

@@ -50,7 +51,7 @@ def _solve(a, b, c):


@tf.keras.utils.register_keras_serializable(package="Addons")
class Yogi(tf.keras.optimizers.Optimizer):
class Yogi(KerasLegacyOptimizer):
"""Optimizer that implements the Yogi algorithm in Keras.

See Algorithm 2 of