Point optimizer to tf.keras.optimizer.legacy.Optimizer to be compatib… #2706
Changes from 9 commits
tensorflow_addons/optimizers/average_wrapper.py

```diff
@@ -17,12 +17,12 @@
 import warnings

 import tensorflow as tf
+from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
 from tensorflow_addons.utils import types

 from typeguard import typechecked


-class AveragedOptimizerWrapper(tf.keras.optimizers.Optimizer, metaclass=abc.ABCMeta):
+class AveragedOptimizerWrapper(BASE_OPTIMIZER_CLASS, metaclass=abc.ABCMeta):
     @typechecked
     def __init__(
         self, optimizer: types.Optimizer, name: str = "AverageOptimizer", **kwargs
```
```diff
@@ -32,10 +32,20 @@ def __init__(
         if isinstance(optimizer, str):
             optimizer = tf.keras.optimizers.get(optimizer)

-        if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
-            raise TypeError(
-                "optimizer is not an object of tf.keras.optimizers.Optimizer"
-            )
+        if tf.__version__[:3] <= "2.8":
+            if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
+                raise TypeError(
+                    "optimizer is not an object of tf.keras.optimizers.Optimizer."
+                )
+        else:
+            if not isinstance(
+                optimizer,
+                (tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer),
+            ):
+                raise TypeError(
+                    "optimizer is not an object of tf.keras.optimizers.legacy.Optimizer "
+                    "or tf.keras.optimizers.Optimizer."
+                )

         self._optimizer = optimizer
         self._track_trackable(self._optimizer, "awg_optimizer")
```
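Not part of the diff: a minimal smoke test of the widened check. It assumes this branch of TensorFlow Addons is installed and uses `tfa.optimizers.MovingAverage`, an existing concrete subclass of `AveragedOptimizerWrapper`.

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Pick whichever SGD variant this TF build provides; the widened
# isinstance check should accept both without raising TypeError.
sgd = (
    tf.keras.optimizers.legacy.SGD(learning_rate=0.01)
    if hasattr(tf.keras.optimizers, "legacy")
    else tf.keras.optimizers.SGD(learning_rate=0.01)
)
avg = tfa.optimizers.MovingAverage(sgd)
```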
tensorflow_addons/optimizers/conditional_gradient.py

```diff
@@ -17,12 +17,13 @@
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike

+from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
 from typeguard import typechecked
 from typing import Union, Callable


 @tf.keras.utils.register_keras_serializable(package="Addons")
-class ConditionalGradient(tf.keras.optimizers.Optimizer):
+class ConditionalGradient(BASE_OPTIMIZER_CLASS):
     """Optimizer that implements the Conditional Gradient optimization.
     This optimizer helps handle constraints well.
```
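A quick way to see what the base-class swap buys, sketched under the same assumption that this branch is installed: on TF > 2.8 the addon should now derive from the legacy base class.

```python
import tensorflow as tf
import tensorflow_addons as tfa

if hasattr(tf.keras.optimizers, "legacy"):
    # TF > 2.8: BASE_OPTIMIZER_CLASS resolved to the legacy base.
    assert issubclass(
        tfa.optimizers.ConditionalGradient, tf.keras.optimizers.legacy.Optimizer
    )
else:
    # TF <= 2.8: unchanged behavior.
    assert issubclass(
        tfa.optimizers.ConditionalGradient, tf.keras.optimizers.Optimizer
    )
```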
tensorflow_addons/optimizers/constants.py (new file)

```diff
@@ -0,0 +1,5 @@
+import tensorflow as tf
+
+BASE_OPTIMIZER_CLASS = tf.keras.optimizers.legacy.Optimizer
+if tf.__version__[:3] <= "2.8":
+    BASE_OPTIMIZER_CLASS = tf.keras.optimizers.Optimizer
```

Review comment: missing copyright header
Author: done
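Two fragile points in this gate are worth flagging. `tf.__version__[:3]` is a string comparison, so a release such as "2.10.0" slices to "2.1" and sorts before "2.8"; and the `legacy` attribute is referenced before the version check runs, which would raise `AttributeError` on any TF build that lacks it. A defensive variant (a sketch only, not what the PR ships):

```python
import tensorflow as tf

# Compare "major.minor" numerically instead of slicing the version string;
# "2.10.0"[:3] == "2.1", which would wrongly sort before "2.8".
_major, _minor = (int(part) for part in tf.__version__.split(".")[:2])

if (_major, _minor) > (2, 8) and hasattr(tf.keras.optimizers, "legacy"):
    BASE_OPTIMIZER_CLASS = tf.keras.optimizers.legacy.Optimizer
else:
    BASE_OPTIMIZER_CLASS = tf.keras.optimizers.Optimizer
```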
tensorflow_addons/optimizers/lazy_adam.py

```diff
@@ -27,8 +27,14 @@
 from typing import Union, Callable


+if tf.__version__[:3] > "2.8":
+    adam_optimizer_class = tf.keras.optimizers.legacy.Adam
+else:
+    adam_optimizer_class = tf.keras.optimizers.Adam
+
+
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class LazyAdam(tf.keras.optimizers.Adam):
+class LazyAdam(adam_optimizer_class):
     """Variant of the Adam optimizer that handles sparse updates more
     efficiently.
```
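Caller-facing usage is unchanged; a short sketch (again assuming this branch is installed) of what the alias resolves to:

```python
import tensorflow as tf
import tensorflow_addons as tfa

opt = tfa.optimizers.LazyAdam(learning_rate=1e-3)
if hasattr(tf.keras.optimizers, "legacy"):
    # TF > 2.8: LazyAdam now derives from the legacy Adam implementation.
    assert isinstance(opt, tf.keras.optimizers.legacy.Adam)
```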
tensorflow_addons/optimizers/lookahead.py

```diff
@@ -16,11 +16,12 @@
 import tensorflow as tf
 from tensorflow_addons.utils import types

+from tensorflow_addons.optimizers import BASE_OPTIMIZER_CLASS
 from typeguard import typechecked


 @tf.keras.utils.register_keras_serializable(package="Addons")
-class Lookahead(tf.keras.optimizers.Optimizer):
+class Lookahead(BASE_OPTIMIZER_CLASS):
     """This class allows to extend optimizers with the lookahead mechanism.

     The mechanism is proposed by Michael R. Zhang et.al in the paper
```

Review comment: isn't it required to update types.Optimizer?
Author: good catch! done
```diff
@@ -71,10 +72,20 @@ def __init__(
         if isinstance(optimizer, str):
             optimizer = tf.keras.optimizers.get(optimizer)
-        if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
-            raise TypeError(
-                "optimizer is not an object of tf.keras.optimizers.Optimizer"
-            )
+        if tf.__version__[:3] <= "2.8":
+            if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
+                raise TypeError(
+                    "optimizer is not an object of tf.keras.optimizers.Optimizer."
+                )
+        else:
+            if not isinstance(
+                optimizer,
+                (tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer),
+            ):
+                raise TypeError(
+                    "optimizer is not an object of tf.keras.optimizers.legacy.Optimizer "
+                    "or tf.keras.optimizers.Optimizer."
+                )

         self._optimizer = optimizer
         self._set_hyper("sync_period", sync_period)
```
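As above, a hedged usage sketch: wrapping a legacy optimizer, which the widened check now accepts (`sync_period` and `slow_step_size` are Lookahead's existing parameters).

```python
import tensorflow as tf
import tensorflow_addons as tfa

inner = (
    tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
    if hasattr(tf.keras.optimizers, "legacy")
    else tf.keras.optimizers.SGD(learning_rate=0.1)
)
opt = tfa.optimizers.Lookahead(inner, sync_period=6, slow_step_size=0.5)
```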
tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py

```diff
@@ -401,13 +401,17 @@ def test_var_list_with_exclude_list_sgdw(dtype):
 )


+if tf.__version__[:3] > "2.8":
+    optimizer_class = tf.keras.optimizers.legacy.SGD
+else:
+    optimizer_class = tf.keras.optimizers.SGD
+
+
 @pytest.mark.parametrize(
     "optimizer",
     [
         weight_decay_optimizers.SGDW,
-        weight_decay_optimizers.extend_with_decoupled_weight_decay(
-            tf.keras.optimizers.SGD
-        ),
+        weight_decay_optimizers.extend_with_decoupled_weight_decay(optimizer_class),
     ],
 )
 @pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
```
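The same factory call outside of pytest, as a sketch mirroring the version gate the test introduces (names follow the existing tfa API):

```python
import tensorflow as tf
from tensorflow_addons.optimizers import weight_decay_optimizers

sgd_class = (
    tf.keras.optimizers.legacy.SGD
    if hasattr(tf.keras.optimizers, "legacy")
    else tf.keras.optimizers.SGD
)
# Build an SGD variant with decoupled weight decay, as the test does.
SGDW = weight_decay_optimizers.extend_with_decoupled_weight_decay(sgd_class)
opt = SGDW(weight_decay=1e-4, learning_rate=0.1)
```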
tensorflow_addons/optimizers/weight_decay_optimizers.py

```diff
@@ -261,10 +261,17 @@ def _do_use_weight_decay(self, var):
         return var.ref() in self._decay_var_list


+optimizer_class = Union[
+    tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer
+]
+if tf.__version__[:3] <= "2.8":
+    optimizer_class = tf.keras.optimizers.Optimizer
+
+
 @typechecked
 def extend_with_decoupled_weight_decay(
-    base_optimizer: Type[tf.keras.optimizers.Optimizer],
-) -> Type[tf.keras.optimizers.Optimizer]:
+    base_optimizer: Type[optimizer_class],
+) -> Type[optimizer_class]:
     """Factory function returning an optimizer class with decoupled weight
     decay.
```
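On TF > 2.8 the alias is a `typing.Union`, so `Type[optimizer_class]` accepts a class from either optimizer family. Spelled out for clarity (a sketch of the annotation only, using a hypothetical alias name):

```python
from typing import Type, Union

import tensorflow as tf

# Either base class, or any subclass of one of them, satisfies this hint;
# typeguard enforces it at call time on the decorated factory.
OptimizerClassHint = Type[
    Union[tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer]
]
```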
Review comment: this breaks alphabetic order
Author: This order matters: if this line does not go before the other optimizers, it creates a cyclic import.