diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD index a2da66c776..1e49a0f1bf 100644 --- a/tensorflow_addons/optimizers/BUILD +++ b/tensorflow_addons/optimizers/BUILD @@ -8,6 +8,7 @@ py_library( "__init__.py", "lazy_adam.py", "moving_average.py", + "weight_decay_optimizers.py", ], srcs_version = "PY2AND3", deps = [ @@ -40,3 +41,16 @@ py_test( ":optimizers", ], ) + +py_test( + name = "weight_decay_optimizers_test", + size = "small", + srcs = [ + "weight_decay_optimizers_test.py", + ], + main = "weight_decay_optimizers_test.py", + srcs_version = "PY2AND3", + deps = [ + ":optimizers", + ], +) diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md index 8804ebd69f..89b29d2de9 100644 --- a/tensorflow_addons/optimizers/README.md +++ b/tensorflow_addons/optimizers/README.md @@ -5,12 +5,15 @@ |:---------- |:------------- |:--------------| | lazy_adam | SIG-Addons | addons@tensorflow.org | | moving_average | Dheeraj R. Reddy | dheeraj98reddy@gmail.com | +| weight_decay_optimizers | Phil Jund | ijund.phil@googlemail.com | + ## Components | Submodule | Optimizer | Reference | -|:----------------------- |:---------------------- |:---------| +|:--------- |:---------- |:---------| | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 | | moving_average | MovingAverage | | +| weight_decay_optimizers | SGDW, AdamW, extend_with_decoupled_weight_decay | https://arxiv.org/pdf/1711.05101.pdf | ## Contribution Guidelines @@ -18,7 +21,7 @@ In order to conform with the current API standard, all optimizers must: * Inherit from either `keras.optimizer_v2.OptimizerV2` or its subclasses. - * [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/python/keras_utils.py) + * [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/keras_utils.py) so it can be serialized properly. * Add the addon to the `py_library` in this sub-package's BUILD file. diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py index 79bbcf04f5..7e189fa954 100644 --- a/tensorflow_addons/optimizers/__init__.py +++ b/tensorflow_addons/optimizers/__init__.py @@ -19,4 +19,8 @@ from __future__ import print_function from tensorflow_addons.optimizers.lazy_adam import LazyAdam +from tensorflow_addons.optimizers.weight_decay_optimizers import AdamW +from tensorflow_addons.optimizers.weight_decay_optimizers import SGDW +from tensorflow_addons.optimizers.weight_decay_optimizers import ( + extend_with_decoupled_weight_decay) from tensorflow_addons.optimizers.moving_average import MovingAverage diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py new file mode 100644 index 0000000000..4759538d9a --- /dev/null +++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py @@ -0,0 +1,425 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class to make optimizers weight decay ready."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+
+
+class DecoupledWeightDecayExtension(object):
+    """This class allows extending optimizers with decoupled weight decay.
+
+    It implements the decoupled weight decay described by Loshchilov & Hutter
+    (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
+    decoupled from the optimization steps w.r.t. the loss function.
+    For SGD variants, this simplifies hyperparameter search since it decouples
+    the settings of weight decay and learning rate.
+    For adaptive gradient algorithms, it regularizes variables with large
+    gradients more than L2 regularization would, which was shown to yield
+    better training loss and generalization error in the paper above.
+
+    This class alone is not an optimizer but rather extends existing
+    optimizers with decoupled weight decay. We explicitly define the two
+    examples used in the above paper (SGDW and AdamW), but in general this
+    can extend any OptimizerX by using
+    `extend_with_decoupled_weight_decay(OptimizerX)(
+        weight_decay=weight_decay, ...)`.
+    In order for it to work, it must be the first class the Optimizer with
+    weight decay inherits from, e.g.
+
+    ```python
+    class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
+        def __init__(self, weight_decay, *args, **kwargs):
+            super(AdamW, self).__init__(weight_decay, *args, **kwargs)
+    ```
+
+    Note: this extension decays weights BEFORE applying the update based
+    on the gradient, i.e. this extension only has the desired behaviour for
+    optimizers which do not depend on the value of 'var' in the update step!
+
+    Note: when applying a decay to the learning rate, be sure to manually apply
+    the decay to the `weight_decay` as well. For example:
+
+    ```python
+    step = tf.Variable(0, trainable=False)
+    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
+        [10000, 15000], [1e-0, 1e-1, 1e-2])
+    # lr and wd can be a function or a tensor
+    lr = 1e-1 * schedule(step)
+    wd = lambda: 1e-4 * schedule(step)
+
+    # ...
+
+    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
+    ```
+    """
+
+    def __init__(self, weight_decay, **kwargs):
+        """Extension class that adds weight decay to an optimizer.
+
+        Args:
+            weight_decay: A `Tensor` or a floating point value, the factor by
+                which a variable is decayed in the update step.
+            **kwargs: Optional keyword arguments passed on to the base
+                optimizer's constructor.
+        """
+        wd = kwargs.pop('weight_decay', weight_decay)
+        super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
+        self._decay_var_list = None  # is set in minimize or apply_gradients
+        self._set_hyper('weight_decay', wd)
+
+    def get_config(self):
+        config = super(DecoupledWeightDecayExtension, self).get_config()
+        config.update({
+            'weight_decay':
+                self._serialize_hyperparameter('weight_decay'),
+        })
+        return config
+
+    def minimize(self,
+                 loss,
+                 var_list,
+                 grad_loss=None,
+                 name=None,
+                 decay_var_list=None):
+        """Minimize `loss` by updating `var_list`.
+
+        This method simply computes gradients using `tf.GradientTape` and calls
+        `apply_gradients()`.
If you want to process the gradient before + applying then call `tf.GradientTape` and `apply_gradients()` explicitly + instead of using this function. + + Args: + loss: A callable taking no arguments which returns the value to + minimize. + var_list: list or tuple of `Variable` objects to update to + minimize `loss`, or a callable returning the list or tuple of + `Variable` objects. Use callable when the variable list would + otherwise be incomplete before `minimize` since the variables + are created at the first time `loss` is called. + grad_loss: Optional. A `Tensor` holding the gradient computed for + `loss`. + decay_var_list: Optional list of variables to be decayed. Defaults + to all variables in var_list. + name: Optional name for the returned operation. + Returns: + An Operation that updates the variables in `var_list`. If + `global_step` was not `None`, that operation also increments + `global_step`. + Raises: + ValueError: If some of the variables are not `Variable` objects. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).minimize( + loss, var_list=var_list, grad_loss=grad_loss, name=name) + + def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None): + """Apply gradients to variables. + + This is the second part of `minimize()`. It returns an `Operation` that + applies gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + name: Optional name for the returned operation. Default to the + name passed to the `Optimizer` constructor. + decay_var_list: Optional list of variables to be decayed. Defaults + to all variables in var_list. + Returns: + An `Operation` that applies the specified gradients. If + `global_step` was not None, that operation also increments + `global_step`. + Raises: + TypeError: If `grads_and_vars` is malformed. + ValueError: If none of the variables have gradients. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).apply_gradients( + grads_and_vars, name=name) + + def _decay_weights_op(self, var): + if not self._decay_var_list or var in self._decay_var_list: + return var.assign_sub( + self._get_hyper('weight_decay', var.dtype) * var, + self._use_locking) + return tf.no_op() + + def _decay_weights_sparse_op(self, var, indices): + if not self._decay_var_list or var in self._decay_var_list: + update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather( + var, indices)) + return self._resource_scatter_add(var, indices, update) + return tf.no_op() + + # Here, we overwrite the apply functions that the base optimizer calls. + # super().apply_x resolves to the apply_x function of the BaseOptimizer. + + def _resource_apply_dense(self, grad, var): + with tf.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, + self)._resource_apply_dense(grad, var) + + def _resource_apply_sparse(self, grad, var, indices): + decay_op = self._decay_weights_sparse_op(var, indices) + with tf.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, + self)._resource_apply_sparse(grad, var, indices) + + +def extend_with_decoupled_weight_decay(base_optimizer): + """Factory function returning an optimizer class with decoupled weight + decay. + + Returns an optimizer class. An instance of the returned class computes the + update step of `base_optimizer` and additionally decays the weights. 
+    E.g., the class returned by
+    `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is
+    equivalent to `tfa.optimizers.AdamW`.
+
+    The API of the new optimizer class slightly differs from the API of the
+    base optimizer:
+    - The first argument to the constructor is the weight decay rate.
+    - `minimize` and `apply_gradients` accept the optional keyword argument
+      `decay_var_list`, which specifies the variables that should be decayed.
+      If `None`, all variables that are optimized are decayed.
+
+    Usage example:
+    ```python
+    # MyAdamW is a new class
+    MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
+    # Create a MyAdamW object
+    optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
+    # update var1, var2 but only decay var1
+    optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1])
+    ```
+
+    Note: this extension decays weights BEFORE applying the update based
+    on the gradient, i.e. this extension only has the desired behaviour for
+    optimizers which do not depend on the value of 'var' in the update step!
+
+    Note: when applying a decay to the learning rate, be sure to manually apply
+    the decay to the `weight_decay` as well. For example:
+
+    ```python
+    step = tf.Variable(0, trainable=False)
+    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
+        [10000, 15000], [1e-0, 1e-1, 1e-2])
+    # lr and wd can be a function or a tensor
+    lr = 1e-1 * schedule(step)
+    wd = lambda: 1e-4 * schedule(step)
+
+    # ...
+
+    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
+    ```
+
+    Note: you might want to register your own custom optimizer using
+    `tf.keras.utils.get_custom_objects()`.
+
+    Args:
+        base_optimizer: An optimizer class that inherits from
+            tf.optimizers.Optimizer.
+
+    Returns:
+        A new optimizer class that inherits from DecoupledWeightDecayExtension
+        and base_optimizer.
+    """
+
+    class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
+                                            base_optimizer):
+        """Base_optimizer with decoupled weight decay.
+
+        This class computes the update step of `base_optimizer` and
+        additionally decays the variable with the weight decay being
+        decoupled from the optimization steps w.r.t. the loss
+        function, as described by Loshchilov & Hutter
+        (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this
+        simplifies hyperparameter search since it decouples the settings
+        of weight decay and learning rate. For adaptive gradient
+        algorithms, it regularizes variables with large gradients more
+        than L2 regularization would, which was shown to yield better
+        training loss and generalization error in the paper above.
+        """
+
+        def __init__(self, weight_decay, *args, **kwargs):
+            # super delegation is necessary here
+            super(OptimizerWithDecoupledWeightDecay, self).__init__(
+                weight_decay, *args, **kwargs)
+
+    return OptimizerWithDecoupledWeightDecay
+
+
+@keras_utils.register_keras_custom_object
+class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
+    """Optimizer that implements the Momentum algorithm with weight decay.
+
+    This is an implementation of the SGDW optimizer described in "Decoupled
+    Weight Decay Regularization" by Loshchilov & Hutter
+    (https://arxiv.org/abs/1711.05101)
+    ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
+    It computes the update step of `tf.keras.optimizers.SGD` and additionally
+    decays the variable. Note that this is different from adding
+    L2 regularization on the variables to the loss.
+    Decoupling the weight decay from other hyperparameters (in particular
+    the learning rate) simplifies hyperparameter search.
+
+    For further information see the documentation of the SGD Optimizer.
+
+    This optimizer can also be instantiated as
+    ```python
+    extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD)(
+        weight_decay=weight_decay, **kwargs)
+    ```
+
+    Note: when applying a decay to the learning rate, be sure to manually apply
+    the decay to the `weight_decay` as well. For example:
+
+    ```python
+    step = tf.Variable(0, trainable=False)
+    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
+        [10000, 15000], [1e-0, 1e-1, 1e-2])
+    # lr and wd can be a function or a tensor
+    lr = 1e-1 * schedule(step)
+    wd = lambda: 1e-4 * schedule(step)
+
+    # ...
+
+    optimizer = tfa.optimizers.SGDW(
+        learning_rate=lr, weight_decay=wd, momentum=0.9)
+    ```
+    """
+
+    def __init__(self,
+                 weight_decay,
+                 learning_rate=0.001,
+                 momentum=0.0,
+                 nesterov=False,
+                 name='SGDW',
+                 **kwargs):
+        """Construct a new SGDW optimizer.
+
+        For further information see the documentation of the SGD Optimizer.
+
+        Args:
+            weight_decay: A `Tensor` or a floating point value. The weight
+                decay.
+            learning_rate: float hyperparameter >= 0. Learning rate.
+            momentum: float hyperparameter >= 0 that accelerates SGD in the
+                relevant direction and dampens oscillations.
+            nesterov: boolean. Whether to apply Nesterov momentum.
+            name: Optional name prefix for the operations created when applying
+                gradients. Defaults to 'SGDW'.
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
+                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
+                norm; `clipvalue` is clip gradients by value, `decay` is
+                included for backward compatibility to allow time inverse decay
+                of learning rate. `lr` is included for backward compatibility,
+                recommended to use `learning_rate` instead.
+        """
+        super(SGDW, self).__init__(
+            weight_decay,
+            learning_rate=learning_rate,
+            momentum=momentum,
+            nesterov=nesterov,
+            name=name,
+            **kwargs)
+
+
+@keras_utils.register_keras_custom_object
+class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
+    """Optimizer that implements the Adam algorithm with weight decay.
+
+    This is an implementation of the AdamW optimizer described in "Decoupled
+    Weight Decay Regularization" by Loshchilov & Hutter
+    (https://arxiv.org/abs/1711.05101)
+    ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
+
+    It computes the update step of `tf.keras.optimizers.Adam` and additionally
+    decays the variable. Note that this is different from adding L2
+    regularization on the variables to the loss: it regularizes variables with
+    large gradients more than L2 regularization would, which was shown to yield
+    better training loss and generalization error in the paper above.
+
+    For further information see the documentation of the Adam Optimizer.
+
+    This optimizer can also be instantiated as
+    ```python
+    extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)(
+        weight_decay=weight_decay, **kwargs)
+    ```
+
+    Note: when applying a decay to the learning rate, be sure to manually apply
+    the decay to the `weight_decay` as well. For example:
+
+    ```python
+    step = tf.Variable(0, trainable=False)
+    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
+        [10000, 15000], [1e-0, 1e-1, 1e-2])
+    # lr and wd can be a function or a tensor
+    lr = 1e-1 * schedule(step)
+    wd = lambda: 1e-4 * schedule(step)
+
+    # ...
+ + optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) + ``` + """ + + def __init__(self, + weight_decay, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-07, + amsgrad=False, + name="AdamW", + **kwargs): + """Construct a new AdamW optimizer. + + For further information see the documentation of the Adam Optimizer. + + Args: + weight_decay: A Tensor or a floating point value. The weight decay. + learning_rate: A Tensor or a floating point value. The learning + rate. + beta_1: A float value or a constant float tensor. The exponential + decay rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential + decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just + before Section 2.1), not the epsilon in Algorithm 1 of the + paper. + amsgrad: boolean. Whether to apply AMSGrad variant of this + algorithm from the paper "On the Convergence of Adam and + beyond". + name: Optional name for the operations created when applying + gradients. Defaults to "AdamW". + **kwargs: keyword arguments. Allowed to be {`clipnorm`, + `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by + norm; `clipvalue` is clip gradients by value, `decay` is + included for backward compatibility to allow time inverse decay + of learning rate. `lr` is included for backward compatibility, + recommended to use `learning_rate` instead. + """ + super(AdamW, self).__init__( + weight_decay, + learning_rate=learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon, + amsgrad=amsgrad, + name=name, + **kwargs) diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers_test.py b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py new file mode 100644 index 0000000000..e265eecb3c --- /dev/null +++ b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py @@ -0,0 +1,276 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optimizers with weight decay.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow_addons.utils import test_utils +from tensorflow_addons.optimizers import weight_decay_optimizers + +WEIGHT_DECAY = 0.01 + + +class OptimizerTestBase(tf.test.TestCase): + """Base class for optimizer tests. + + Optimizer tests may inherit from this class and define test + functions using doTest. Usually this should include the functions + testSparse, testBasic, and testBasicCallableParams. See + weight_decay_optimizers_test for an example. + """ + + def doTest(self, optimizer, update_fn, do_sparse=False, + **optimizer_kwargs): + """The major test function. + + Args: + optimizer: The tensorflow optimizer class to be tested. 
+            update_fn: The numpy update function of the optimizer, the function
+                signature must be
+                update_fn(var: np.array,
+                          grad_t: np.array,
+                          slot_vars: dict,
+                          **kwargs) -> (updated_var, updated_slot_vars)
+                Note that slot_vars will be initialized to an empty dictionary
+                for each variable, initial values should be handled in the
+                update_fn.
+            do_sparse: If True, test sparse update. Defaults to False, i.e.,
+                dense update.
+            **optimizer_kwargs: The parameters to pass to the constructor of
+                the optimizer. Either a constant or a callable. This is also
+                passed to the optimizer_params in the update_fn.
+        """
+        for i, dtype in enumerate([tf.half, tf.float32, tf.float64]):
+            # Initialize variables for numpy implementation.
+            np_slot_vars0, np_slot_vars1 = {}, {}
+            var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+            grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+            var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+            grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+            # Create Tensorflow variables.
+            var0 = tf.Variable(var0_np, name="var0_%d" % i)
+            var1 = tf.Variable(var1_np, name="var1_%d" % i)
+            if do_sparse:
+                grads0_np_indices = np.array([0, 1], dtype=np.int32)
+                grads0 = tf.IndexedSlices(
+                    tf.constant(grads0_np), tf.constant(grads0_np_indices),
+                    tf.constant([2]))
+                grads1_np_indices = np.array([0, 1], dtype=np.int32)
+                grads1 = tf.IndexedSlices(
+                    tf.constant(grads1_np), tf.constant(grads1_np_indices),
+                    tf.constant([2]))
+            else:
+                grads0 = tf.constant(grads0_np)
+                grads1 = tf.constant(grads1_np)
+            opt = optimizer(**optimizer_kwargs)
+            # Validate initial values.
+            if not tf.executing_eagerly():
+                update = opt.apply_gradients(
+                    zip([grads0, grads1], [var0, var1]))
+                self.evaluate(tf.compat.v1.global_variables_initializer())
+                self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+                self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+            # Create the update op.
+            # Run 3 steps of the optimizer
+            for _ in range(3):
+                if tf.executing_eagerly():
+                    opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+                else:
+                    self.evaluate(update)
+                var0_np, np_slot_vars0 = update_fn(
+                    var0_np, grads0_np, np_slot_vars0, **optimizer_kwargs)
+                var1_np, np_slot_vars1 = update_fn(
+                    var1_np, grads1_np, np_slot_vars1, **optimizer_kwargs)
+                # Validate updated params
+                self.assertAllCloseAccordingToType(var0_np,
+                                                   self.evaluate(var0))
+                self.assertAllCloseAccordingToType(var1_np,
+                                                   self.evaluate(var1))
+
+    def doTestSparseRepeatedIndices(self, optimizer, **optimizer_kwargs):
+        """Test for repeated indices in sparse updates.
+
+        This test verifies that an update with repeated indices is the same as
+        an update with two times the gradient.
+
+        Args:
+            optimizer: The tensorflow optimizer class to be tested.
+            **optimizer_kwargs: The parameters to pass to the constructor of
+                the optimizer. Either a constant or a callable. This is also
+                passed to the optimizer_params in the update_fn.
+ """ + for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]: + repeated_index_update_var = tf.Variable([[1.0], [2.0]], + dtype=dtype) + aggregated_update_var = tf.Variable([[1.0], [2.0]], dtype=dtype) + grad_repeated_index = tf.IndexedSlices( + tf.constant([0.1, 0.1], shape=[2, 1], dtype=dtype), + tf.constant([1, 1]), tf.constant([2, 1])) + grad_aggregated = tf.IndexedSlices( + tf.constant([0.2], shape=[1, 1], dtype=dtype), + tf.constant([1]), tf.constant([2, 1])) + opt_repeated = optimizer(**optimizer_kwargs) + repeated_update = opt_repeated.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) + opt_aggregated = optimizer(**optimizer_kwargs) + aggregated_update = opt_aggregated.apply_gradients( + [(grad_aggregated, aggregated_update_var)]) + self.evaluate(tf.compat.v1.global_variables_initializer()) + self.assertAllClose( + self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) + for _ in range(3): + if not tf.executing_eagerly(): + self.evaluate(repeated_update) + self.evaluate(aggregated_update) + else: + opt_repeated.apply_gradients([(grad_repeated_index, + repeated_index_update_var)]) + opt_aggregated.apply_gradients([(grad_aggregated, + aggregated_update_var)]) + self.assertAllClose( + self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) + + +def adamw_update_numpy(param, grad_t, slot_vars, learning_rate, beta_1, beta_2, + epsilon, weight_decay): + """Numpy update function for AdamW.""" + lr, beta1, beta2, eps, wd = (v() if callable(v) else v + for v in (learning_rate, beta_1, beta_2, + epsilon, weight_decay)) + t = slot_vars.get("t", 0) + 1 + lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) + slot_vars["m"] = beta1 * slot_vars.get("m", 0) + (1 - beta1) * grad_t + slot_vars["v"] = beta2 * slot_vars.get("v", 0) + (1 - beta2) * grad_t**2 + param_t = (param * (1 - wd) - + lr_t * slot_vars["m"] / (np.sqrt(slot_vars["v"]) + eps)) + slot_vars["t"] = t + return param_t, slot_vars + + +def sgdw_update_numpy(param, grad_t, slot_vars, learning_rate, momentum, + weight_decay): + """Numpy update function for SGDW.""" + m = slot_vars.get("m", 0) + lr, momentum, wd = (v() if callable(v) else v + for v in (learning_rate, momentum, weight_decay)) + slot_vars["m"] = momentum * m + grad_t + param_t = param * (1 - wd) - lr * slot_vars["m"] + return param_t, slot_vars + + +class AdamWTest(OptimizerTestBase): + + optimizer = weight_decay_optimizers.AdamW + + @test_utils.run_in_graph_and_eager_modes(reset_test=True) + def testSparse(self): + self.doTest( + self.optimizer, + adamw_update_numpy, + do_sparse=True, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY) + + @test_utils.run_in_graph_and_eager_modes(reset_test=True) + def testSparseRepeatedIndices(self): + self.doTestSparseRepeatedIndices( + self.optimizer, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY) + + @test_utils.run_in_graph_and_eager_modes(reset_test=True) + def testBasic(self): + self.doTest( + self.optimizer, + adamw_update_numpy, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY) + + def testBasicCallableParams(self): + self.doTest( + self.optimizer, + adamw_update_numpy, + learning_rate=lambda: 0.001, + beta_1=lambda: 0.9, + beta_2=lambda: 0.999, + epsilon=lambda: 1e-8, + weight_decay=lambda: WEIGHT_DECAY) + + +class SGDWTest(OptimizerTestBase): + + optimizer = weight_decay_optimizers.SGDW + + 
@test_utils.run_in_graph_and_eager_modes(reset_test=True)
+    def testSparse(self):
+        self.doTest(
+            self.optimizer,
+            sgdw_update_numpy,
+            do_sparse=True,
+            learning_rate=0.001,
+            momentum=0.9,
+            weight_decay=WEIGHT_DECAY)
+
+    @test_utils.run_in_graph_and_eager_modes(reset_test=True)
+    def testSparseRepeatedIndices(self):
+        self.doTestSparseRepeatedIndices(
+            self.optimizer,
+            learning_rate=0.001,
+            momentum=0.9,
+            weight_decay=WEIGHT_DECAY)
+
+    @test_utils.run_in_graph_and_eager_modes(reset_test=True)
+    def testBasic(self):
+        self.doTest(
+            self.optimizer,
+            sgdw_update_numpy,
+            learning_rate=0.001,
+            momentum=0.9,
+            weight_decay=WEIGHT_DECAY)
+
+    def testBasicCallableParams(self):
+        self.doTest(
+            self.optimizer,
+            sgdw_update_numpy,
+            learning_rate=lambda: 0.001,
+            momentum=lambda: 0.9,
+            weight_decay=lambda: WEIGHT_DECAY)
+
+
+class ExtendWithWeightDecayTest(SGDWTest):
+    """Verify that the optimizer returned by the factory matches SGDW."""
+
+    optimizer = weight_decay_optimizers.extend_with_decoupled_weight_decay(
+        tf.optimizers.SGD)
+
+
+if __name__ == "__main__":
+    tf.test.main()
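
The docstrings above already show the learning-rate/weight-decay schedule pattern; a complementary end-to-end sketch of the factory plus `decay_var_list` may help reviewers. This is a minimal sketch, assuming TF 2.x and that this addon is importable as `tensorflow_addons`; the variable names `kernel` and `bias` are purely illustrative.

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Build a weight-decayed variant of plain SGD via the factory.
MySGDW = tfa.optimizers.extend_with_decoupled_weight_decay(
    tf.keras.optimizers.SGD)
opt = MySGDW(weight_decay=1e-4, learning_rate=0.1)

kernel = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
bias = tf.Variable([0.5, 0.5])

def loss():
    # Toy objective; minimize() accepts any zero-argument callable.
    return tf.reduce_sum(kernel**2) + tf.reduce_sum(bias**2)

# Update both variables, but only apply weight decay to `kernel`.
opt.minimize(loss, var_list=[kernel, bias], decay_var_list=[kernel])
```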
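To make the "decays weights BEFORE applying the update" note concrete, here is a hedged numerical sketch (same install assumptions as above) that checks a single SGDW step against the closed-form update `w <- (1 - wd) * w - lr * grad` and contrasts it with what an L2 penalty added to the loss would produce. For plain SGD the two only differ by a rescaling of the decay coefficient; for adaptive optimizers such as Adam they genuinely differ, which is the point of the referenced paper.

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

lr, wd = 0.1, 0.01
w0, grad = 3.0, 0.5

var = tf.Variable(w0)
opt = tfa.optimizers.SGDW(weight_decay=wd, learning_rate=lr)  # momentum=0.0
opt.apply_gradients([(tf.constant(grad), var)])

decoupled = (1.0 - wd) * w0 - lr * grad   # decay first, then gradient step
l2_in_loss = w0 - lr * (grad + wd * w0)   # gradient of an L2 penalty instead
np.testing.assert_allclose(var.numpy(), decoupled, rtol=1e-5)
print(var.numpy(), decoupled, l2_in_loss)
```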
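Both classes are decorated with `keras_utils.register_keras_custom_object` so they "can be serialized properly", as the README requires. Assuming that utility registers the class in Keras's global custom-object table, a config round-trip should work without passing `custom_objects` by hand; a small sketch:

```python
import tensorflow as tf
import tensorflow_addons as tfa

opt = tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3)
config = tf.keras.optimizers.serialize(opt)

# Deserialization resolves the class name "AdamW" from the registry alone.
restored = tf.keras.optimizers.deserialize(config)
assert isinstance(restored, tfa.optimizers.AdamW)
print(restored.get_config()["weight_decay"])  # 0.0001
```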