Merged
1 change: 1 addition & 0 deletions tensorflow_addons/optimizers/__init__.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Additional optimizers that conform to Keras API."""

from tensorflow_addons.optimizers.constants import BASE_OPTIMIZER_CLASS
Contributor: this breaks alphabetic order

Contributor Author: This order matters: if this import does not go before the other optimizer imports, it creates a circular import.
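For illustration, a minimal sketch of the ordering constraint described above (module paths are taken from this diff; the real package contains more imports than shown):

```python
# tensorflow_addons/optimizers/__init__.py -- sketch of the import ordering.
# Modules such as average_wrapper.py do
#     from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
# i.e. they import the name back out of this package while it is still being
# initialized, so the name must already be bound before those modules load.
from tensorflow_addons.optimizers.constants import BASE_OPTIMIZER_CLASS  # must stay first
from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper
```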

from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper
from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient
from tensorflow_addons.optimizers.cyclical_learning_rate import CyclicalLearningRate
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/adabelief.py
@@ -17,11 +17,12 @@
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
from typing import Union, Callable, Dict


@tf.keras.utils.register_keras_serializable(package="Addons")
class AdaBelief(tf.keras.optimizers.Optimizer):
class AdaBelief(BASE_OPTIMIZER_CLASS):
"""Variant of the Adam optimizer.

It achieves fast convergence as Adam and generalization comparable to SGD.
22 changes: 16 additions & 6 deletions tensorflow_addons/optimizers/average_wrapper.py
@@ -17,12 +17,12 @@
import warnings

import tensorflow as tf
from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
from tensorflow_addons.utils import types

from typeguard import typechecked


class AveragedOptimizerWrapper(tf.keras.optimizers.Optimizer, metaclass=abc.ABCMeta):
class AveragedOptimizerWrapper(BASE_OPTIMIZER_CLASS, metaclass=abc.ABCMeta):
@typechecked
def __init__(
self, optimizer: types.Optimizer, name: str = "AverageOptimizer", **kwargs
@@ -32,10 +32,20 @@ def __init__(
if isinstance(optimizer, str):
optimizer = tf.keras.optimizers.get(optimizer)

if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.Optimizer"
)
if tf.__version__[:3] <= "2.8":
Contributor: why do we need this condition? What if we just wrote isinstance(optimizer, (tf.keras.optimizers.Optimizer, BASE_OPTIMIZER_CLASS))?

Contributor Author: I did this so the error message would be accurate. But rethinking it, I am switching to the (tf.keras.optimizers.Optimizer, BASE_OPTIMIZER_CLASS) check and implying in the message that from 2.9 onward you should expect tf.keras.optimizers.legacy.Optimizer.
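For illustration, a sketch of the combined check described in the reply (hedged: the hunk below still shows the version-branched variant from this revision):

```python
# Sketch only: a single isinstance check against both base classes.
# Assumes the BASE_OPTIMIZER_CLASS export added by this PR, which resolves to
# tf.keras.optimizers.legacy.Optimizer on TF >= 2.9 and to
# tf.keras.optimizers.Optimizer on older releases.
import tensorflow as tf
from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS


def validate_optimizer(optimizer):
    if not isinstance(optimizer, (tf.keras.optimizers.Optimizer, BASE_OPTIMIZER_CLASS)):
        raise TypeError(
            "optimizer is not an object of tf.keras.optimizers.Optimizer "
            "(or tf.keras.optimizers.legacy.Optimizer on TF >= 2.9)."
        )
    return optimizer
```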

if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.Optimizer."
)
else:
if not isinstance(
optimizer,
(tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer),
):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.legacy.Optimizer "
"or tf.keras.optimizers.Optimizer."
)

self._optimizer = optimizer
self._track_trackable(self._optimizer, "awg_optimizer")
4 changes: 3 additions & 1 deletion tensorflow_addons/optimizers/cocob.py
@@ -17,9 +17,11 @@
from typeguard import typechecked
import tensorflow as tf

from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS


@tf.keras.utils.register_keras_serializable(package="Addons")
class COCOB(tf.keras.optimizers.Optimizer):
class COCOB(BASE_OPTIMIZER_CLASS):
"""Optimizer that implements COCOB Backprop Algorithm

Reference:
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/conditional_gradient.py
@@ -17,12 +17,13 @@
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
Contributor: this line should be added after line 17

Contributor Author: done

from typeguard import typechecked
from typing import Union, Callable


@tf.keras.utils.register_keras_serializable(package="Addons")
class ConditionalGradient(tf.keras.optimizers.Optimizer):
class ConditionalGradient(BASE_OPTIMIZER_CLASS):
"""Optimizer that implements the Conditional Gradient optimization.
This optimizer helps handle constraints well.
5 changes: 5 additions & 0 deletions tensorflow_addons/optimizers/constants.py
@@ -0,0 +1,5 @@
import tensorflow as tf
Contributor: missing copyright header

Contributor Author: done


BASE_OPTIMIZER_CLASS = tf.keras.optimizers.legacy.Optimizer
Contributor: Why not if ... else? It will crash if tf.keras.optimizers.legacy doesn't exist (TF 2.8), right?

Are you sure it has to be CAPITAL_LETTER? I'd rename it to KerasLegacyOptimizer (or KERAS_LEGACY_OPTIMIZER - I'm not an expert in conventions 😄) and rename the file to keras.py.

Contributor Author: Thanks! I changed the check condition.

For the naming, yes, we probably want camel case since it is a class itself. But I don't want to imply "legacy" in the name, which is not yet 100% correct - we still use tf.keras.optimizers.Optimizer in many places. I am renaming it to BaseOptimizerClass, wdyt?

Contributor: @chenmoneygithub - I expect that after the refactor in Keras is done, we will gradually switch back to tf.keras.optimizers.Optimizer. BaseOptimizerClass will be confusing then, because it'll be the name for the legacy class.

Contributor Author: sg, renamed to KerasLegacyOptimizer.

if tf.__version__[:3] <= "2.8":
Contributor: As I wrote previously:

>>> tf.__version__
'2.10.0-dev20220531'
>>> tf.__version__[:3]
'2.1'
>>> tf.__version__[:3] > "2.8"
False

Please fix this condition (also in the other files), or just check whether tf.keras.optimizers.legacy exists - maybe the code would be cleaner then...

Contributor Author: good catch! done

BASE_OPTIMIZER_CLASS = tf.keras.optimizers.Optimizer
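For reference, a minimal sketch of the attribute-existence check the author says they switched to (the post-review constants module is not shown in this hunk, so the final code may differ; the KerasLegacyOptimizer name comes from the naming thread above):

```python
# Sketch only: select the legacy optimizer base class without parsing version strings.
import tensorflow as tf

if hasattr(tf.keras.optimizers, "legacy"):
    # TF >= 2.9 exposes the pre-refactor optimizer under tf.keras.optimizers.legacy.
    KerasLegacyOptimizer = tf.keras.optimizers.legacy.Optimizer
else:
    # Older releases only ship the original base class.
    KerasLegacyOptimizer = tf.keras.optimizers.Optimizer
```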
8 changes: 4 additions & 4 deletions tensorflow_addons/optimizers/cyclical_learning_rate.py
@@ -58,7 +58,7 @@ def __init__(
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
`tf.keras.optimizers.legacy.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -146,7 +146,7 @@ def __init__(
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
`tf.keras.optimizers.legacy.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -215,7 +215,7 @@ def __init__(
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
`tf.keras.optimizers.legacy.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -286,7 +286,7 @@ def __init__(
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
`tf.keras.optimizers.legacy.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
8 changes: 5 additions & 3 deletions tensorflow_addons/optimizers/discriminative_layer_training.py
@@ -17,11 +17,13 @@
from typing import List, Union

import tensorflow as tf

from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class MultiOptimizer(tf.keras.optimizers.Optimizer):
class MultiOptimizer(BASE_OPTIMIZER_CLASS):
"""Multi Optimizer Wrapper for Discriminative Layer Training.

Creates a wrapper around a set of instantiated optimizer layer pairs.
@@ -30,7 +32,7 @@ class MultiOptimizer(tf.keras.optimizers.Optimizer):
Each optimizer will optimize only the weights associated with its paired layer.
This can be used to implement discriminative layer training by assigning
different learning rates to each optimizer layer pair.
`(tf.keras.optimizers.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
`(tf.keras.optimizers.legacy.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
Please note that the layers must be instantiated before instantiating the optimizer.

Args:
@@ -130,7 +132,7 @@ def get_config(self):
@classmethod
def create_optimizer_spec(
cls,
optimizer: tf.keras.optimizers.Optimizer,
optimizer: BASE_OPTIMIZER_CLASS,
layers_or_model: Union[
tf.keras.Model,
tf.keras.Sequential,
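To make the pairing described in the MultiOptimizer docstring concrete, a brief usage sketch (the model, layer split, and learning rates are illustrative only; on newer TF releases the legacy optimizer classes may be required for the wrapped optimizers):

```python
# Illustrative MultiOptimizer usage: one optimizer per layer group (sketch only).
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(32, activation="relu", input_shape=(16,)),
        tf.keras.layers.Dense(1),
    ]
)
optimizers_and_layers = [
    (tf.keras.optimizers.Adam(1e-4), model.layers[0]),  # smaller LR for the base layer
    (tf.keras.optimizers.Adam(1e-2), model.layers[1]),  # larger LR for the head
]
model.compile(
    optimizer=tfa.optimizers.MultiOptimizer(optimizers_and_layers), loss="mse"
)
```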
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/lamb.py
@@ -24,12 +24,13 @@
from typeguard import typechecked

import tensorflow as tf
from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes


@tf.keras.utils.register_keras_serializable(package="Addons")
class LAMB(tf.keras.optimizers.Optimizer):
class LAMB(BASE_OPTIMIZER_CLASS):
"""Optimizer that implements the Layer-wise Adaptive Moments (LAMB).

See paper [Large Batch Optimization for Deep Learning: Training BERT
8 changes: 7 additions & 1 deletion tensorflow_addons/optimizers/lazy_adam.py
@@ -27,8 +27,14 @@
from typing import Union, Callable


if tf.__version__[:3] > "2.8":
Member (@fsx950223, Jun 3, 2022): Use from packaging.version import Version

adam_optimizer_class = tf.keras.optimizers.legacy.Adam
else:
adam_optimizer_class = tf.keras.optimizers.Adam
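A hedged sketch of the packaging.version approach suggested above (assumes the packaging package is importable; per the later thread, the author instead switches to checking whether optimizers.legacy exists):

```python
# Sketch only: robust version comparison instead of slicing tf.__version__.
from packaging.version import Version

import tensorflow as tf

if Version(tf.__version__) >= Version("2.9"):
    adam_optimizer_class = tf.keras.optimizers.legacy.Adam
else:
    adam_optimizer_class = tf.keras.optimizers.Adam
```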


@tf.keras.utils.register_keras_serializable(package="Addons")
class LazyAdam(tf.keras.optimizers.Adam):
class LazyAdam(adam_optimizer_class):
"""Variant of the Adam optimizer that handles sparse updates more
efficiently.
21 changes: 16 additions & 5 deletions tensorflow_addons/optimizers/lookahead.py
@@ -16,11 +16,12 @@
import tensorflow as tf
from tensorflow_addons.utils import types
Contributor: isn't it required to update types.Optimizer as well?

Contributor Author: good catch! done
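For illustration only, one way the types.Optimizer alias could be widened (hypothetical: the actual change to tensorflow_addons/utils/types.py is not part of this hunk):

```python
# Hypothetical sketch of a widened types.Optimizer alias; the real
# tensorflow_addons/utils/types.py may define it differently.
from typing import Union

import tensorflow as tf

if hasattr(tf.keras.optimizers, "legacy"):
    Optimizer = Union[
        str, tf.keras.optimizers.Optimizer, tf.keras.optimizers.legacy.Optimizer
    ]
else:
    Optimizer = Union[str, tf.keras.optimizers.Optimizer]
```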


from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class Lookahead(tf.keras.optimizers.Optimizer):
class Lookahead(BASE_OPTIMIZER_CLASS):
"""This class allows to extend optimizers with the lookahead mechanism.

The mechanism is proposed by Michael R. Zhang et.al in the paper
@@ -71,10 +72,20 @@ def __init__(

if isinstance(optimizer, str):
optimizer = tf.keras.optimizers.get(optimizer)
if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.Optimizer"
)
if tf.__version__[:3] <= "2.8":
if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.Optimizer."
)
else:
if not isinstance(
optimizer,
(tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer),
):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.legacy.Optimizer "
"or tf.keras.optimizers.Optimizer."
)

self._optimizer = optimizer
self._set_hyper("sync_period", sync_period)
2 changes: 1 addition & 1 deletion tensorflow_addons/optimizers/moving_average.py
@@ -55,7 +55,7 @@ def __init__(
r"""Construct a new MovingAverage optimizer.

Args:
optimizer: str or `tf.keras.optimizers.Optimizer` that will be
optimizer: str or `tf.keras.optimizers.legacy.Optimizer` that will be
used to compute and apply gradients.
average_decay: float. Decay to use to maintain the moving averages
of trained variables.
4 changes: 2 additions & 2 deletions tensorflow_addons/optimizers/novograd.py
@@ -16,13 +16,13 @@

import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
from typing import Union, Callable
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class NovoGrad(tf.keras.optimizers.Optimizer):
class NovoGrad(BASE_OPTIMIZER_CLASS):
"""Optimizer that implements NovoGrad.

The NovoGrad Optimizer was first proposed in [Stochastic Gradient
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/proximal_adagrad.py
@@ -19,11 +19,12 @@
import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
from tensorflow_addons.utils.types import FloatTensorLike


@tf.keras.utils.register_keras_serializable(package="Addons")
class ProximalAdagrad(tf.keras.optimizers.Optimizer):
class ProximalAdagrad(BASE_OPTIMIZER_CLASS):
"""Optimizer that implements the Proximal Adagrad algorithm.

References:
3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/rectified_adam.py
@@ -16,12 +16,13 @@
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
from typing import Union, Callable, Dict
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class RectifiedAdam(tf.keras.optimizers.Optimizer):
class RectifiedAdam(BASE_OPTIMIZER_CLASS):
"""Variant of the Adam optimizer whose adaptive learning rate is rectified
so as to have a consistent variance.

7 changes: 3 additions & 4 deletions tensorflow_addons/optimizers/tests/standard_test.py
@@ -18,6 +18,7 @@
import tensorflow as tf

from tensorflow_addons import optimizers
from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
from tensorflow_addons.utils.test_utils import discover_classes

class_exceptions = [
@@ -29,12 +30,10 @@
"ConditionalGradient", # is wrapper
"Lookahead", # is wrapper
"MovingAverage", # is wrapper
"BASE_OPTIMIZER_CLASS", # is a constantc
]


classes_to_test = discover_classes(
optimizers, tf.keras.optimizers.Optimizer, class_exceptions
)
classes_to_test = discover_classes(optimizers, BASE_OPTIMIZER_CLASS, class_exceptions)


@pytest.mark.parametrize("optimizer", classes_to_test)
@@ -401,13 +401,17 @@ def test_var_list_with_exclude_list_sgdw(dtype):
)


if tf.__version__[:3] > "2.8":
Member: Use from packaging.version import Version

Contributor Author: thanks! I am switching to checking whether "optimizers.legacy" exists, to keep consistency.

optimizer_class = tf.keras.optimizers.legacy.SGD
else:
optimizer_class = tf.keras.optimizers.SGD


@pytest.mark.parametrize(
"optimizer",
[
weight_decay_optimizers.SGDW,
weight_decay_optimizers.extend_with_decoupled_weight_decay(
tf.keras.optimizers.SGD
),
weight_decay_optimizers.extend_with_decoupled_weight_decay(optimizer_class),
],
)
@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
11 changes: 9 additions & 2 deletions tensorflow_addons/optimizers/weight_decay_optimizers.py
@@ -261,10 +261,17 @@ def _do_use_weight_decay(self, var):
return var.ref() in self._decay_var_list


optimizer_class = Union[
Member: Why not BaseOptimizerClass?

Contributor Author: I am renaming this to keras_legacy_optimizer, to keep it aligned with the changes suggested by Michal.

tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer
]
if tf.__version__[:3] <= "2.8":
optimizer_class = tf.keras.optimizers.Optimizer


@typechecked
def extend_with_decoupled_weight_decay(
base_optimizer: Type[tf.keras.optimizers.Optimizer],
) -> Type[tf.keras.optimizers.Optimizer]:
base_optimizer: Type[optimizer_class],
) -> Type[optimizer_class]:
"""Factory function returning an optimizer class with decoupled weight
decay.

3 changes: 2 additions & 1 deletion tensorflow_addons/optimizers/yogi.py
@@ -25,6 +25,7 @@
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
from typeguard import typechecked
from typing import Union, Callable

@@ -50,7 +51,7 @@ def _solve(a, b, c):


@tf.keras.utils.register_keras_serializable(package="Addons")
class Yogi(tf.keras.optimizers.Optimizer):
class Yogi(BASE_OPTIMIZER_CLASS):
"""Optimizer that implements the Yogi algorithm in Keras.

See Algorithm 2 of