From 93450a2a054b4bf194d03c38002bd9ff7a7e1229 Mon Sep 17 00:00:00 2001 From: Philipp Jund Date: Tue, 9 Apr 2019 08:38:34 +0200 Subject: [PATCH 1/8] Add decoupled weight decay optimizers and helper class for optimizer tests. --- tensorflow_addons/optimizers/BUILD | 16 + .../optimizers/optimizer_test_base.py | 154 ++++++++ .../optimizers/weight_decay_optimizers.py | 362 ++++++++++++++++++ .../weight_decay_optimizers_test.py | 130 +++++++ 4 files changed, 662 insertions(+) create mode 100644 tensorflow_addons/optimizers/optimizer_test_base.py create mode 100644 tensorflow_addons/optimizers/weight_decay_optimizers.py create mode 100644 tensorflow_addons/optimizers/weight_decay_optimizers_test.py diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD index 0fcb8088c5..3e12a01edb 100644 --- a/tensorflow_addons/optimizers/BUILD +++ b/tensorflow_addons/optimizers/BUILD @@ -7,6 +7,7 @@ py_library( srcs = [ "__init__.py", "lazy_adam.py", + "weight_decay_optimizers.py", ], srcs_version = "PY2AND3", deps = [ @@ -26,3 +27,18 @@ py_test( ":optimizers", ], ) + + +py_test( + name = "weight_decay_optimizers_test", + size = "small", + srcs = [ + "weight_decay_optimizers_test.py", + "optimizer_test_base.py", + ], + main = "weight_decay_optimizers_test.py", + srcs_version = "PY2AND3", + deps = [ + ":optimizers", + ], +) diff --git a/tensorflow_addons/optimizers/optimizer_test_base.py b/tensorflow_addons/optimizers/optimizer_test_base.py new file mode 100644 index 0000000000..b33f63ee6f --- /dev/null +++ b/tensorflow_addons/optimizers/optimizer_test_base.py @@ -0,0 +1,154 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for optimizer tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables + + +class OptimizerTestBase(tf.test.TestCase): + """Base class for optimizer tests. + + Optimizer tests may inherit from this class and define test functions + using doTest. Usually this should include the functions testSparse, + testBasic, and testBasicCallableParams. See weight_decay_optimizers_test for + an example. + """ + + def doTest(self, optimizer, update_fn, params, do_sparse=False): + """The major test function. + + Args: + optimizer: The tensorflow optimizer class to be tested. + update_fn: The numpy update function of the optimizer, the function + signature must be + update_fn(var: np.array, + grad_t: np.array, + slot_vars: dict, + optimizer_params: dict) -> updated_var, updated_slot_vars + Note that slot_vars will be initialized to an empty dictionary for + each variable, initial values should be handled in the update_fn. + params: A dict, the parameters to pass to the construcor of the + optimizer. Either a constant or a callable. This also passed to the + optimizer_params in the update_fn. + do_sparse: If True, test sparse update. Defaults to False, i.e., dense + update. + """ + for i, dtype in enumerate([tf.half, tf.float32, tf.float64]): + with self.session(graph=tf.Graph()): + # Initialize variables for numpy implementation. + np_slot_vars0, np_slot_vars1 = {}, {} + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + # Create Tensorflow variables. + var0 = tf.Variable(var0_np, name="var0_%d" % i) + var1 = tf.Variable(var1_np, name="var1_%d" % i) + if do_sparse: + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = tf.IndexedSlices( + tf.constant(grads0_np), tf.constant(grads0_np_indices), + tf.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = tf.IndexedSlices( + tf.constant(grads1_np), tf.constant(grads1_np_indices), + tf.constant([2])) + else: + grads0 = tf.constant(grads0_np) + grads1 = tf.constant(grads1_np) + opt = optimizer(**params) + # Validate initial values. + if not tf.executing_eagerly(): + update = opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + # Create the update op. + # Run 3 steps of the optimizer + for _ in range(3): + if tf.executing_eagerly(): + opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + else: + self.evaluate(update) + var0_np, np_slot_vars0 = update_fn(var0_np, grads0_np, + np_slot_vars0, params) + var1_np, np_slot_vars1 = update_fn(var1_np, grads1_np, + np_slot_vars1, params) + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, + self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, + self.evaluate(var1)) + + def doTestSparseRepeatedIndices(self, optimizer, params): + """Test for repeated indices in sparse updates. + + This test verifies that an update with repeated indices is the same as + an update with two times the gradient. + + Args: + optimizer: The tensorflow optimizer class to be tested. + params: A dict, the parameters to pass to the construcor of the + optimizer. Either a constant or a callable. This also passed to the + optimizer_params in the update_fn. + """ + for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]: + with self.cached_session(): + repeated_index_update_var = tf.Variable([[1.0], [2.0]], + dtype=dtype) + aggregated_update_var = tf.Variable([[1.0], [2.0]], + dtype=dtype) + grad_repeated_index = tf.IndexedSlices( + tf.constant([0.1, 0.1], shape=[2, 1], dtype=dtype), + tf.constant([1, 1]), tf.constant([2, 1])) + grad_aggregated = tf.IndexedSlices( + tf.constant([0.2], shape=[1, 1], dtype=dtype), + tf.constant([1]), tf.constant([2, 1])) + opt_repeated = optimizer(**params) + repeated_update = opt_repeated.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) + opt_aggregated = optimizer(**params) + aggregated_update = opt_aggregated.apply_gradients( + [(grad_aggregated, aggregated_update_var)]) + self.evaluate(variables.global_variables_initializer()) + self.assertAllClose( + self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) + for _ in range(3): + if not tf.executing_eagerly(): + self.evaluate(repeated_update) + self.evaluate(aggregated_update) + else: + opt_repeated.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) + opt_aggregated.apply_gradients( + [(grad_aggregated, aggregated_update_var)]) + self.assertAllClose( + self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py new file mode 100644 index 0000000000..ff16c2c355 --- /dev/null +++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py @@ -0,0 +1,362 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class to make optimizers weight decay ready.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow_addons.utils import keras_utils + + +class DecoupledWeightDecayExtension(object): + """This class allows to extend optimizers with decoupled weight decay. + + It implements the decoupled weight decay described by Loshchilov & Hutter + (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is + decoupled from the optimization steps w.r.t. to the loss function. + For SGD variants, this simplifies hyperparameter search since it decouples + the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + This class alone is not an optimizer but rather extends existing + optimizers with decoupled weight decay. We explicitly define the two examples + used in the above paper (SGDW and AdamW), but in general this can extend + any OptimizerX by using + `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`. + In order for it to work, it must be the first class the Optimizer with + weight decay inherits from, e.g. + + ```python + class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + def __init__(self, weight_decay, *args, **kwargs): + super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs). + ``` + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + schedule = tf.train.piecewise_constant(tf.train.get_global_step(), + [10000, 15000], [1e-0, 1e-1, 1e-2]) + lr = 1e-1 * schedule() + wd = lambda: 1e-4 * schedule() + + # ... + + optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr, + weight_decay=wd, + momentum=0.9, + use_nesterov=True) + ``` + """ + + def __init__(self, weight_decay, **kwargs): + """Construct the extension class that adds weight decay to an optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value, the factor by which + a variable is decayed in the update step. + **kwargs: Optional list or tuple or set of `Variable` objects to + decay. + """ + super(DecoupledWeightDecayExtension, self).__init__(**kwargs) + self._decay_var_list = None # is set in minimize or apply_gradients + self._set_hyper('weight_decay', kwargs.get('weight_decay', + weight_decay)) + + def get_config(self): + config = super(DecoupledWeightDecayExtension, self).get_config() + config.update({ + 'weight_decay': + self._serialize_hyperparameter('weight_decay'), + }) + return config + + def minimize(self, + loss, + var_list, + grad_loss=None, + name=None, + decay_var_list=None): + """Minimize `loss` by updating `var_list`. + + This method simply computes gradient using `tf.GradientTape` and calls + `apply_gradients()`. If you want to process the gradient before applying + then call `tf.GradientTape` and `apply_gradients()` explicitly instead + of using this function. + + Args: + loss: A callable taking no arguments which returns the value to minimize. + var_list: list or tuple of `Variable` objects to update to minimize + `loss`, or a callable returning the list or tuple of `Variable` objects. + Use callable when the variable list would otherwise be incomplete before + `minimize` since the variables are created at the first time `loss` is + called. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + decay_var_list: Optional list of variables to be decayed. Defaults to all + variables in var_list. + name: Optional name for the returned operation. + Returns: + An Operation that updates the variables in `var_list`. If `global_step` + was not `None`, that operation also increments `global_step`. + Raises: + ValueError: If some of the variables are not `Variable` objects. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).minimize( + loss, var_list=var_list, grad_loss=grad_loss, name=name) + + def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None): + """Apply gradients to variables. + This is the second part of `minimize()`. It returns an `Operation` that + applies gradients. + Args: + grads_and_vars: List of (gradient, variable) pairs. + name: Optional name for the returned operation. Default to the name + passed to the `Optimizer` constructor. + decay_var_list: Optional list of variables to be decayed. Defaults to all + variables in var_list. + Returns: + An `Operation` that applies the specified gradients. If `global_step` + was not None, that operation also increments `global_step`. + Raises: + TypeError: If `grads_and_vars` is malformed. + ValueError: If none of the variables have gradients. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).apply_gradients( + grads_and_vars, name=name) + + def _decay_weights_op(self, var): + if not self._decay_var_list or var in self._decay_var_list: + return var.assign_sub( + self._get_hyper('weight_decay', var.dtype) * var, + self._use_locking) + return tf.no_op() + + def _decay_weights_sparse_op(self, var, indices): + if not self._decay_var_list or var in self._decay_var_list: + update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather( + var, indices)) + return self._resource_scatter_add(var, indices, update) + return tf.no_op() + + # Here, we overwrite the apply functions that the base optimizer calls. + # super().apply_x resolves to the apply_x function of the BaseOptimizer. + + def _resource_apply_dense(self, grad, var): + with tf.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, + self)._resource_apply_dense(grad, var) + + def _resource_apply_sparse(self, grad, var, indices): + decay_op = self._decay_weights_sparse_op(var, indices) + with tf.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, + self)._resource_apply_sparse(grad, var, indices) + + +def extend_with_decoupled_weight_decay(base_optimizer): + """Factory function returning an optimizer class with decoupled weight decay. + + Returns an optimizer class. An instance of the returned class computes the + update step of `base_optimizer` and additionally decays the weights. + E.g., the class returned by + `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to + `tf.contrib.opt.AdamWOptimizer`. + + The API of the new optimizer class slightly differs from the API of the + base optimizer: + - The first argument to the constructor is the weight decay rate. + - `minimize` and `apply_gradients` accept the optional keyword argument + `decay_var_list`, which specifies the variables that should be decayed. + If `None`, all variables that are optimized are decayed. + + Usage example: + ```python + # MyAdamW is a new class + MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer) + # Create a MyAdamW object + optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) + sess.run(optimizer.minimize(loss, decay_variables=[var1, var2])) + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of 'var' in the update step! + ``` + + Args: + base_optimizer: An optimizer class that inherits from tf.train.Optimizer. + + Returns: + A new optimizer class that inherits from DecoupledWeightDecayExtension + and base_optimizer. + """ + + @keras_utils.register_keras_custom_object + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, + base_optimizer): + """Base_optimizer with decoupled weight decay. + + This class computes the update step of `base_optimizer` and + additionally decays the variable with the weight decay being decoupled from + the optimization steps w.r.t. to the loss function, as described by + Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf). + For SGD variants, this simplifies hyperparameter search since + it decouples the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. + """ + + def __init__(self, weight_decay, *args, **kwargs): + # super delegation is necessary here + # pylint: disable=useless-super-delegation + super(OptimizerWithDecoupledWeightDecay, self).__init__( + weight_decay, *args, **kwargs) + # pylint: enable=useless-super-delegation + + return OptimizerWithDecoupledWeightDecay + + +@keras_utils.register_keras_custom_object +class SGDWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD): + """Optimizer that implements the Momentum algorithm with weight_decay. + + This is an implementation of the SGDW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + It computes the update step of `train.MomentumOptimizer` and additionally + decays the variable. Note that this is different from adding + L2 regularization on the variables to the loss. Decoupling the weight decay + from other hyperparameters (in particular the learning rate) simplifies + hyperparameter search. + + For further information see the documentation of the SGD Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.keras.optimizers.SGD, + weight_decay=weight_decay) + ``` + """ + + def __init__(self, + weight_decay, + learning_rate=0.001, + momentum=0.0, + nesterov=False, + name='SGDW', + **kwargs): + """Construct a new SGDW optimizer. + + For further information see the documentation of the SGD Optimizer. + + Args: + Arguments: + learning_rate: float hyperparameter >= 0. Learning rate. + momentum: float hyperparameter >= 0 that accelerates SGD in the relevant + direction and dampens oscillations. + nesterov: boolean. Whether to apply Nesterov momentum. + name: Optional name prefix for the operations created when applying + gradients. Defaults to 'SGD'. + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, + `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip + gradients by value, `decay` is included for backward compatibility to + allow time inverse decay of learning rate. `lr` is included for backward + compatibility, recommended to use `learning_rate` instead. + """ + super(SGDWOptimizer, self).__init__( + weight_decay, + learning_rate=learning_rate, + momentum=momentum, + nesterov=nesterov, + name=name, + **kwargs) + + +@keras_utils.register_keras_custom_object +class AdamWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): + """Optimizer that implements the Adam algorithm with weight decay. + + This is an implementation of the AdamW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + + It computes the update step of `train.AdamOptimizer` and additionally decays + the variable. Note that this is different from adding L2 regularization on + the variables to the loss: it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + For further information see the documentation of the Adam Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.keras.optimizers.SGD, weight_decay=weight_decay) + ``` + """ + + def __init__(self, + weight_decay, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-07, + amsgrad=False, + name="AdamW", + **kwargs): + """Construct a new AdamW optimizer. + + For further information see the documentation of the Adam Optimizer. + + Args: + weight_decay: A Tensor or a floating point value. The weight decay. + learning_rate: A Tensor or a floating point value. The learning rate. + beta_1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential decay + rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from + the paper "On the Convergence of Adam and beyond". + name: Optional name for the operations created when applying gradients. + Defaults to "AdamW". + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, + `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip + gradients by value, `decay` is included for backward compatibility to + allow time inverse decay of learning rate. `lr` is included for backward + compatibility, recommended to use `learning_rate` instead. + """ + super(AdamWOptimizer, self).__init__( + weight_decay, + learning_rate=learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon, + amsgrad=amsgrad, + name=name, + **kwargs) diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers_test.py b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py new file mode 100644 index 0000000000..612a2f4316 --- /dev/null +++ b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py @@ -0,0 +1,130 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optimizers with weight decay.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.python.framework import test_util + +from tensorflow_addons.optimizers.optimizer_test_base import OptimizerTestBase +from tensorflow_addons.optimizers import weight_decay_optimizers + +WEIGHT_DECAY = 0.01 + + +def adamw_update_numpy(param, grad_t, slot_vars, optimizer_params): + """Numpy update function for AdamW.""" + opt_params = ( + optimizer_params[k] for k in + ["learning_rate", "beta_1", "beta_2", "epsilon", "weight_decay"]) + lr, beta1, beta2, eps, wd = (v() if callable(v) else v for v in opt_params) + t = slot_vars.get("t", 0) + 1 + lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) + slot_vars["m"] = beta1 * slot_vars.get("m", 0) + (1 - beta1) * grad_t + slot_vars["v"] = beta2 * slot_vars.get("v", 0) + (1 - beta2) * grad_t**2 + param_t = (param * (1 - wd) - + lr_t * slot_vars["m"] / (np.sqrt(slot_vars["v"]) + eps)) + slot_vars["t"] = t + return param_t, slot_vars + + +def sgdw_update_numpy(param, grad_t, slot_vars, optimizer_params): + """Numpy update function for SGDW.""" + m = slot_vars.get("m", 0) + optimizer_params = { + k: v() if callable(v) else v + for k, v in optimizer_params.items() + } + slot_vars["m"] = optimizer_params["momentum"] * m + grad_t + lr = optimizer_params["learning_rate"] + wd = optimizer_params["weight_decay"] + param_t = param * (1 - wd) - lr * slot_vars["m"] + return param_t, slot_vars + + +class AdamWOptimizerTest(OptimizerTestBase): + + opt_params = { + "learning_rate": 0.001, + "beta_1": 0.9, + "beta_2": 0.999, + "epsilon": 1e-8, + "weight_decay": WEIGHT_DECAY + } + callable_opt_params = {k: (lambda: v) for k, v in opt_params.items()} + optimizer = weight_decay_optimizers.AdamWOptimizer + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testSparse(self): + self.doTest( + self.optimizer, + adamw_update_numpy, + self.opt_params, + do_sparse=True) + + @test_util.run_in_graph_and_eager_modes + def testSparseRepeatedIndices(self): + self.doTestSparseRepeatedIndices(self.optimizer, self.opt_params) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testBasic(self): + self.doTest(self.optimizer, adamw_update_numpy, self.opt_params) + + def testBasicCallableParams(self): + self.doTest(self.optimizer, adamw_update_numpy, + self.callable_opt_params) + + +class SGDWOptimizerTest(OptimizerTestBase): + + opt_params = { + "learning_rate": 0.001, + "momentum": 0.9, + "weight_decay": WEIGHT_DECAY + } + callable_opt_params = {k: (lambda: v) for k, v in opt_params.items()} + optimizer = weight_decay_optimizers.SGDWOptimizer + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testSparse(self): + self.doTest( + self.optimizer, sgdw_update_numpy, self.opt_params, do_sparse=True) + + @test_util.run_in_graph_and_eager_modes + def testSparseRepeatedIndices(self): + self.doTestSparseRepeatedIndices(self.optimizer, self.opt_params) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testBasic(self): + self.doTest(self.optimizer, sgdw_update_numpy, self.opt_params) + + def testBasicCallableParams(self): + self.doTest(self.optimizer, sgdw_update_numpy, + self.callable_opt_params) + + +class ExtendWithWeightDecayTest(SGDWOptimizerTest): + """Verify that the factory function SGDW is the same as SGDW.""" + + optimizer = weight_decay_optimizers.extend_with_decoupled_weight_decay( + tf.optimizers.SGD) + + +if __name__ == "__main__": + tf.test.main() From 821ac697beea62eb19639e5d91b38d4a038bcef9 Mon Sep 17 00:00:00 2001 From: Philipp Jund Date: Thu, 11 Apr 2019 14:48:46 +0200 Subject: [PATCH 2/8] Adapt README.md. Fix broken link to keras_utils. --- tensorflow_addons/optimizers/README.md | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md index 0331e8c55c..27b5054668 100644 --- a/tensorflow_addons/optimizers/README.md +++ b/tensorflow_addons/optimizers/README.md @@ -1,14 +1,16 @@ # Addons - Optimizers ## Maintainers -| Submodule | Maintainers | Contact Info | -|:---------- |:------------- |:--------------| -| lazy_adam | SIG-Addons | addons@tensorflow.org | +| Submodule | Maintainers | Contact Info | +|:----------------------- |:----------- |:-------------- | +| lazy_adam | SIG-Addons | addons@tensorflow.org | +| weight_decay_optimizers | SIG-Addons | addons@tensorflow.org | ## Components | Submodule | Optimizer | Reference | -|:----------------------- |:---------------------- |:---------| +|:--------- |:---------- |:---------| | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 | +| weight_decay_optimizers | SGDW, AdamW, extend_with_decoupled_weight_decay | https://arxiv.org/pdf/1711.05101.pdf | ## Contribution Guidelines @@ -16,7 +18,7 @@ In order to conform with the current API standard, all optimizers must: * Inherit from either `keras.optimizer_v2.OptimizerV2` or its subclasses. - * [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/python/keras_utils.py) + * [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/keras_utils.py) so it can be serialized properly. * Add the addon to the `py_library` in this sub-package's BUILD file. @@ -25,6 +27,7 @@ must: `@run_in_graph_and_eager_modes` (for test method) or `run_all_in_graph_and_eager_modes` (for TestCase subclass) decorator. + * Consider inheriting from `OptimizerTestBase`. * Add a `py_test` to this sub-package's BUILD file. #### Documentation Requirements From ec55b294624e6d88e0af783c164de8aa3d103c90 Mon Sep 17 00:00:00 2001 From: Philipp Jund Date: Thu, 11 Apr 2019 17:02:32 +0200 Subject: [PATCH 3/8] Remove TF private API calls. --- .../optimizers/optimizer_test_base.py | 7 ++----- .../optimizers/weight_decay_optimizers_test.py | 14 +++++++------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/tensorflow_addons/optimizers/optimizer_test_base.py b/tensorflow_addons/optimizers/optimizer_test_base.py index b33f63ee6f..9f7575d166 100644 --- a/tensorflow_addons/optimizers/optimizer_test_base.py +++ b/tensorflow_addons/optimizers/optimizer_test_base.py @@ -21,9 +21,6 @@ import numpy as np import tensorflow as tf -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variables - class OptimizerTestBase(tf.test.TestCase): """Base class for optimizer tests. @@ -81,7 +78,7 @@ def doTest(self, optimizer, update_fn, params, do_sparse=False): if not tf.executing_eagerly(): update = opt.apply_gradients( zip([grads0, grads1], [var0, var1])) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(tf.compat.v1.global_variables_initializer()) self.assertAllClose([1.0, 2.0], self.evaluate(var0)) self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Create the update op. @@ -132,7 +129,7 @@ def doTestSparseRepeatedIndices(self, optimizer, params): opt_aggregated = optimizer(**params) aggregated_update = opt_aggregated.apply_gradients( [(grad_aggregated, aggregated_update_var)]) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(tf.compat.v1.global_variables_initializer()) self.assertAllClose( self.evaluate(aggregated_update_var), self.evaluate(repeated_index_update_var)) diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers_test.py b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py index 612a2f4316..c209f5788a 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers_test.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py @@ -20,8 +20,8 @@ import numpy as np import tensorflow as tf -from tensorflow.python.framework import test_util +from tensorflow_addons.utils import test_utils from tensorflow_addons.optimizers.optimizer_test_base import OptimizerTestBase from tensorflow_addons.optimizers import weight_decay_optimizers @@ -70,7 +70,7 @@ class AdamWOptimizerTest(OptimizerTestBase): callable_opt_params = {k: (lambda: v) for k, v in opt_params.items()} optimizer = weight_decay_optimizers.AdamWOptimizer - @test_util.run_in_graph_and_eager_modes(reset_test=True) + @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparse(self): self.doTest( self.optimizer, @@ -78,11 +78,11 @@ def testSparse(self): self.opt_params, do_sparse=True) - @test_util.run_in_graph_and_eager_modes + @test_utils.run_in_graph_and_eager_modes def testSparseRepeatedIndices(self): self.doTestSparseRepeatedIndices(self.optimizer, self.opt_params) - @test_util.run_in_graph_and_eager_modes(reset_test=True) + @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testBasic(self): self.doTest(self.optimizer, adamw_update_numpy, self.opt_params) @@ -101,16 +101,16 @@ class SGDWOptimizerTest(OptimizerTestBase): callable_opt_params = {k: (lambda: v) for k, v in opt_params.items()} optimizer = weight_decay_optimizers.SGDWOptimizer - @test_util.run_in_graph_and_eager_modes(reset_test=True) + @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparse(self): self.doTest( self.optimizer, sgdw_update_numpy, self.opt_params, do_sparse=True) - @test_util.run_in_graph_and_eager_modes + @test_utils.run_in_graph_and_eager_modes def testSparseRepeatedIndices(self): self.doTestSparseRepeatedIndices(self.optimizer, self.opt_params) - @test_util.run_in_graph_and_eager_modes(reset_test=True) + @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testBasic(self): self.doTest(self.optimizer, sgdw_update_numpy, self.opt_params) From aa3d7b0f065abc13c0cf25dce31f34f54fd5ae90 Mon Sep 17 00:00:00 2001 From: Philipp Jund Date: Wed, 17 Apr 2019 11:16:36 +0200 Subject: [PATCH 4/8] Add imports to __init__.py --- tensorflow_addons/optimizers/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py index 543774e8c7..b22abe2224 100644 --- a/tensorflow_addons/optimizers/__init__.py +++ b/tensorflow_addons/optimizers/__init__.py @@ -19,3 +19,7 @@ from __future__ import print_function from tensorflow_addons.optimizers.lazy_adam import LazyAdam +from tensorflow_addons.optimizers.weight_decay_optimizers import AdamWOptimizer +from tensorflow_addons.optimizers.weight_decay_optimizers import SGDWOptimizer +from tensorflow_addons.optimizers.weight_decay_optimizers import ( + extend_with_decoupled_weight_decay) From 9ebc02bff8ced347efc24288d066978c3883103c Mon Sep 17 00:00:00 2001 From: Philipp Jund Date: Thu, 18 Apr 2019 15:53:16 +0200 Subject: [PATCH 5/8] Fix indentation of comments, remove call to tf.test.main from optimizer_test_base. --- .../optimizers/optimizer_test_base.py | 62 ++- .../optimizers/weight_decay_optimizers.py | 423 +++++++++--------- 2 files changed, 248 insertions(+), 237 deletions(-) diff --git a/tensorflow_addons/optimizers/optimizer_test_base.py b/tensorflow_addons/optimizers/optimizer_test_base.py index 9f7575d166..6bef3e9480 100644 --- a/tensorflow_addons/optimizers/optimizer_test_base.py +++ b/tensorflow_addons/optimizers/optimizer_test_base.py @@ -24,32 +24,34 @@ class OptimizerTestBase(tf.test.TestCase): """Base class for optimizer tests. - - Optimizer tests may inherit from this class and define test functions - using doTest. Usually this should include the functions testSparse, - testBasic, and testBasicCallableParams. See weight_decay_optimizers_test for - an example. - """ + + Optimizer tests may inherit from this class and define test + functions using doTest. Usually this should include the functions + testSparse, testBasic, and testBasicCallableParams. See + weight_decay_optimizers_test for an example. + """ def doTest(self, optimizer, update_fn, params, do_sparse=False): """The major test function. - - Args: - optimizer: The tensorflow optimizer class to be tested. - update_fn: The numpy update function of the optimizer, the function - signature must be - update_fn(var: np.array, - grad_t: np.array, - slot_vars: dict, - optimizer_params: dict) -> updated_var, updated_slot_vars - Note that slot_vars will be initialized to an empty dictionary for - each variable, initial values should be handled in the update_fn. - params: A dict, the parameters to pass to the construcor of the - optimizer. Either a constant or a callable. This also passed to the - optimizer_params in the update_fn. - do_sparse: If True, test sparse update. Defaults to False, i.e., dense - update. - """ + + Args: + optimizer: The tensorflow optimizer class to be tested. + update_fn: The numpy update function of the optimizer, the function + signature must be + update_fn(var: np.array, + grad_t: np.array, + slot_vars: dict, + optimizer_params: dict) -> (updated_var, + updated_slot_vars) + Note that slot_vars will be initialized to an empty dictionary + for each variable, initial values should be handled in the + update_fn. + params: A dict, the parameters to pass to the construcor of the + optimizer. Either a constant or a callable. This also passed to + the optimizer_params in the update_fn. + do_sparse: If True, test sparse update. Defaults to False, i.e., + dense update. + """ for i, dtype in enumerate([tf.half, tf.float32, tf.float64]): with self.session(graph=tf.Graph()): # Initialize variables for numpy implementation. @@ -101,15 +103,15 @@ def doTest(self, optimizer, update_fn, params, do_sparse=False): def doTestSparseRepeatedIndices(self, optimizer, params): """Test for repeated indices in sparse updates. - + This test verifies that an update with repeated indices is the same as an update with two times the gradient. Args: - optimizer: The tensorflow optimizer class to be tested. - params: A dict, the parameters to pass to the construcor of the - optimizer. Either a constant or a callable. This also passed to the - optimizer_params in the update_fn. + optimizer: The tensorflow optimizer class to be tested. + params: A dict, the parameters to pass to the construcor of the + optimizer. Either a constant or a callable. This also passed to + the optimizer_params in the update_fn. """ for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]: with self.cached_session(): @@ -145,7 +147,3 @@ def doTestSparseRepeatedIndices(self, optimizer, params): self.assertAllClose( self.evaluate(aggregated_update_var), self.evaluate(repeated_index_update_var)) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py index ff16c2c355..913243e59b 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py @@ -24,60 +24,60 @@ class DecoupledWeightDecayExtension(object): """This class allows to extend optimizers with decoupled weight decay. - It implements the decoupled weight decay described by Loshchilov & Hutter - (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is - decoupled from the optimization steps w.r.t. to the loss function. - For SGD variants, this simplifies hyperparameter search since it decouples - the settings of weight decay and learning rate. - For adaptive gradient algorithms, it regularizes variables with large - gradients more than L2 regularization would, which was shown to yield better - training loss and generalization error in the paper above. - - This class alone is not an optimizer but rather extends existing - optimizers with decoupled weight decay. We explicitly define the two examples - used in the above paper (SGDW and AdamW), but in general this can extend - any OptimizerX by using - `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`. - In order for it to work, it must be the first class the Optimizer with - weight decay inherits from, e.g. - - ```python - class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): - def __init__(self, weight_decay, *args, **kwargs): - super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs). - ``` - - Note that this extension decays weights BEFORE applying the update based - on the gradient, i.e. this extension only has the desired behaviour for - optimizers which do not depend on the value of'var' in the update step! - - Note: when applying a decay to the learning rate, be sure to manually apply - the decay to the `weight_decay` as well. For example: - - ```python - schedule = tf.train.piecewise_constant(tf.train.get_global_step(), - [10000, 15000], [1e-0, 1e-1, 1e-2]) - lr = 1e-1 * schedule() - wd = lambda: 1e-4 * schedule() - - # ... - - optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr, - weight_decay=wd, - momentum=0.9, - use_nesterov=True) - ``` - """ - - def __init__(self, weight_decay, **kwargs): - """Construct the extension class that adds weight decay to an optimizer. + It implements the decoupled weight decay described by Loshchilov & Hutter + (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is + decoupled from the optimization steps w.r.t. to the loss function. + For SGD variants, this simplifies hyperparameter search since it decouples + the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. - Args: - weight_decay: A `Tensor` or a floating point value, the factor by which - a variable is decayed in the update step. - **kwargs: Optional list or tuple or set of `Variable` objects to - decay. + This class alone is not an optimizer but rather extends existing + optimizers with decoupled weight decay. We explicitly define the two + examples used in the above paper (SGDW and AdamW), but in general this + can extend any OptimizerX by using + `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`. + In order for it to work, it must be the first class the Optimizer with + weight decay inherits from, e.g. + + ```python + class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + def __init__(self, weight_decay, *args, **kwargs): + super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs). + ``` + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + schedule = tf.train.piecewise_constant( + tf.train.get_global_step(), [10000, 15000], [1e-0, 1e-1, 1e-2]) + lr = 1e-1 * schedule() + wd = lambda: 1e-4 * schedule() + + # ... + + optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr, + weight_decay=wd, + momentum=0.9, + use_nesterov=True) + ``` """ + + def __init__(self, weight_decay, **kwargs): + """Extension class that adds weight decay to an optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value, the factor by + which a variable is decayed in the update step. + **kwargs: Optional list or tuple or set of `Variable` objects to + decay. + """ super(DecoupledWeightDecayExtension, self).__init__(**kwargs) self._decay_var_list = None # is set in minimize or apply_gradients self._set_hyper('weight_decay', kwargs.get('weight_decay', @@ -98,50 +98,56 @@ def minimize(self, name=None, decay_var_list=None): """Minimize `loss` by updating `var_list`. - - This method simply computes gradient using `tf.GradientTape` and calls - `apply_gradients()`. If you want to process the gradient before applying - then call `tf.GradientTape` and `apply_gradients()` explicitly instead - of using this function. - Args: - loss: A callable taking no arguments which returns the value to minimize. - var_list: list or tuple of `Variable` objects to update to minimize - `loss`, or a callable returning the list or tuple of `Variable` objects. - Use callable when the variable list would otherwise be incomplete before - `minimize` since the variables are created at the first time `loss` is - called. - grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. - decay_var_list: Optional list of variables to be decayed. Defaults to all - variables in var_list. - name: Optional name for the returned operation. - Returns: - An Operation that updates the variables in `var_list`. If `global_step` - was not `None`, that operation also increments `global_step`. - Raises: - ValueError: If some of the variables are not `Variable` objects. - """ + This method simply computes gradient using `tf.GradientTape` and calls + `apply_gradients()`. If you want to process the gradient before + applying then call `tf.GradientTape` and `apply_gradients()` explicitly + instead of using this function. + + Args: + loss: A callable taking no arguments which returns the value to + minimize. + var_list: list or tuple of `Variable` objects to update to + minimize `loss`, or a callable returning the list or tuple of + `Variable` objects. Use callable when the variable list would + otherwise be incomplete before `minimize` since the variables + are created at the first time `loss` is called. + grad_loss: Optional. A `Tensor` holding the gradient computed for + `loss`. + decay_var_list: Optional list of variables to be decayed. Defaults + to all variables in var_list. + name: Optional name for the returned operation. + Returns: + An Operation that updates the variables in `var_list`. If + `global_step` was not `None`, that operation also increments + `global_step`. + Raises: + ValueError: If some of the variables are not `Variable` objects. + """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).minimize( loss, var_list=var_list, grad_loss=grad_loss, name=name) def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None): """Apply gradients to variables. - This is the second part of `minimize()`. It returns an `Operation` that - applies gradients. - Args: - grads_and_vars: List of (gradient, variable) pairs. - name: Optional name for the returned operation. Default to the name - passed to the `Optimizer` constructor. - decay_var_list: Optional list of variables to be decayed. Defaults to all - variables in var_list. - Returns: - An `Operation` that applies the specified gradients. If `global_step` - was not None, that operation also increments `global_step`. - Raises: - TypeError: If `grads_and_vars` is malformed. - ValueError: If none of the variables have gradients. - """ + + This is the second part of `minimize()`. It returns an `Operation` that + applies gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + name: Optional name for the returned operation. Default to the + name passed to the `Optimizer` constructor. + decay_var_list: Optional list of variables to be decayed. Defaults + to all variables in var_list. + Returns: + An `Operation` that applies the specified gradients. If + `global_step` was not None, that operation also increments + `global_step`. + Raises: + TypeError: If `grads_and_vars` is malformed. + ValueError: If none of the variables have gradients. + """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).apply_gradients( grads_and_vars, name=name) @@ -176,57 +182,59 @@ def _resource_apply_sparse(self, grad, var, indices): def extend_with_decoupled_weight_decay(base_optimizer): - """Factory function returning an optimizer class with decoupled weight decay. - - Returns an optimizer class. An instance of the returned class computes the - update step of `base_optimizer` and additionally decays the weights. - E.g., the class returned by - `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to - `tf.contrib.opt.AdamWOptimizer`. - - The API of the new optimizer class slightly differs from the API of the - base optimizer: - - The first argument to the constructor is the weight decay rate. - - `minimize` and `apply_gradients` accept the optional keyword argument - `decay_var_list`, which specifies the variables that should be decayed. - If `None`, all variables that are optimized are decayed. - - Usage example: - ```python - # MyAdamW is a new class - MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer) - # Create a MyAdamW object - optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) - sess.run(optimizer.minimize(loss, decay_variables=[var1, var2])) - - Note that this extension decays weights BEFORE applying the update based - on the gradient, i.e. this extension only has the desired behaviour for - optimizers which do not depend on the value of 'var' in the update step! - ``` - - Args: - base_optimizer: An optimizer class that inherits from tf.train.Optimizer. - - Returns: - A new optimizer class that inherits from DecoupledWeightDecayExtension - and base_optimizer. - """ + """Factory function returning an optimizer class with decoupled weight + decay. + + Returns an optimizer class. An instance of the returned class computes the + update step of `base_optimizer` and additionally decays the weights. + E.g., the class returned by + `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent + to `tf.contrib.opt.AdamWOptimizer`. + + The API of the new optimizer class slightly differs from the API of the + base optimizer: + - The first argument to the constructor is the weight decay rate. + - `minimize` and `apply_gradients` accept the optional keyword argument + `decay_var_list`, which specifies the variables that should be decayed. + If `None`, all variables that are optimized are decayed. + + Usage example: + ```python + # MyAdamW is a new class + MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer) + # Create a MyAdamW object + optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) + sess.run(optimizer.minimize(loss, decay_variables=[var1, var2])) + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of 'var' in the update step! + ``` + + Args: + base_optimizer: An optimizer class that inherits from tf.train.Optimizer. + + Returns: + A new optimizer class that inherits from DecoupledWeightDecayExtension + and base_optimizer. + """ @keras_utils.register_keras_custom_object class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, base_optimizer): """Base_optimizer with decoupled weight decay. - This class computes the update step of `base_optimizer` and - additionally decays the variable with the weight decay being decoupled from - the optimization steps w.r.t. to the loss function, as described by - Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf). - For SGD variants, this simplifies hyperparameter search since - it decouples the settings of weight decay and learning rate. - For adaptive gradient algorithms, it regularizes variables with large - gradients more than L2 regularization would, which was shown to yield - better training loss and generalization error in the paper above. - """ + This class computes the update step of `base_optimizer` and + additionally decays the variable with the weight decay being + decoupled from the optimization steps w.r.t. to the loss + function, as described by Loshchilov & Hutter + (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this + simplifies hyperparameter search since it decouples the settings + of weight decay and learning rate. For adaptive gradient + algorithms, it regularizes variables with large gradients more + than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + """ def __init__(self, weight_decay, *args, **kwargs): # super delegation is necessary here @@ -242,24 +250,24 @@ def __init__(self, weight_decay, *args, **kwargs): class SGDWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD): """Optimizer that implements the Momentum algorithm with weight_decay. - This is an implementation of the SGDW optimizer described in "Fixing - Weight Decay Regularization in Adam" by Loshchilov & Hutter - (https://arxiv.org/abs/1711.05101) - ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). - It computes the update step of `train.MomentumOptimizer` and additionally - decays the variable. Note that this is different from adding - L2 regularization on the variables to the loss. Decoupling the weight decay - from other hyperparameters (in particular the learning rate) simplifies - hyperparameter search. - - For further information see the documentation of the SGD Optimizer. - - Note that this optimizer can also be instantiated as - ```python - extend_with_weight_decay(tf.keras.optimizers.SGD, - weight_decay=weight_decay) - ``` - """ + This is an implementation of the SGDW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + It computes the update step of `train.MomentumOptimizer` and additionally + decays the variable. Note that this is different from adding + L2 regularization on the variables to the loss. Decoupling the weight decay + from other hyperparameters (in particular the learning rate) simplifies + hyperparameter search. + + For further information see the documentation of the SGD Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.keras.optimizers.SGD, + weight_decay=weight_decay) + ``` + """ def __init__(self, weight_decay, @@ -270,22 +278,22 @@ def __init__(self, **kwargs): """Construct a new SGDW optimizer. - For further information see the documentation of the SGD Optimizer. - - Args: - Arguments: - learning_rate: float hyperparameter >= 0. Learning rate. - momentum: float hyperparameter >= 0 that accelerates SGD in the relevant - direction and dampens oscillations. - nesterov: boolean. Whether to apply Nesterov momentum. - name: Optional name prefix for the operations created when applying - gradients. Defaults to 'SGD'. - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, - `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip - gradients by value, `decay` is included for backward compatibility to - allow time inverse decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. - """ + For further information see the documentation of the SGD Optimizer. + + Args: + learning_rate: float hyperparameter >= 0. Learning rate. + momentum: float hyperparameter >= 0 that accelerates SGD in the + relevant direction and dampens oscillations. + nesterov: boolean. Whether to apply Nesterov momentum. + name: Optional name prefix for the operations created when applying + gradients. Defaults to 'SGD'. + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, + `lr`, `decay`}. `clipnorm` is clip gradients by norm; + `clipvalue` is clip gradients by value, `decay` is included for + backward compatibility to allow time inverse decay of learning + rate. `lr` is included for backward compatibility, recommended + to use `learning_rate` instead. + """ super(SGDWOptimizer, self).__init__( weight_decay, learning_rate=learning_rate, @@ -299,24 +307,25 @@ def __init__(self, class AdamWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): """Optimizer that implements the Adam algorithm with weight decay. - This is an implementation of the AdamW optimizer described in "Fixing - Weight Decay Regularization in Adam" by Loshchilov & Hutter - (https://arxiv.org/abs/1711.05101) - ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + This is an implementation of the AdamW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). - It computes the update step of `train.AdamOptimizer` and additionally decays - the variable. Note that this is different from adding L2 regularization on - the variables to the loss: it regularizes variables with large - gradients more than L2 regularization would, which was shown to yield better - training loss and generalization error in the paper above. + It computes the update step of `train.AdamOptimizer` and additionally + decays the variable. Note that this is different from adding L2 + regularization on the variables to the loss: it regularizes variables with + large gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. - For further information see the documentation of the Adam Optimizer. + For further information see the documentation of the Adam Optimizer. - Note that this optimizer can also be instantiated as - ```python - extend_with_weight_decay(tf.keras.optimizers.SGD, weight_decay=weight_decay) - ``` - """ + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.keras.optimizers.SGD, + weight_decay=weight_decay) + ``` + """ def __init__(self, weight_decay, @@ -329,28 +338,32 @@ def __init__(self, **kwargs): """Construct a new AdamW optimizer. - For further information see the documentation of the Adam Optimizer. - - Args: - weight_decay: A Tensor or a floating point value. The weight decay. - learning_rate: A Tensor or a floating point value. The learning rate. - beta_1: A float value or a constant float tensor. The exponential decay - rate for the 1st moment estimates. - beta_2: A float value or a constant float tensor. The exponential decay - rate for the 2nd moment estimates. - epsilon: A small constant for numerical stability. This epsilon is - "epsilon hat" in the Kingma and Ba paper (in the formula just before - Section 2.1), not the epsilon in Algorithm 1 of the paper. - amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from - the paper "On the Convergence of Adam and beyond". - name: Optional name for the operations created when applying gradients. - Defaults to "AdamW". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, - `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip - gradients by value, `decay` is included for backward compatibility to - allow time inverse decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. - """ + For further information see the documentation of the Adam Optimizer. + + Args: + weight_decay: A Tensor or a floating point value. The weight decay. + learning_rate: A Tensor or a floating point value. The learning + rate. + beta_1: A float value or a constant float tensor. The exponential + decay rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential + decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just + before Section 2.1), not the epsilon in Algorithm 1 of the + paper. + amsgrad: boolean. Whether to apply AMSGrad variant of this + algorithm from the paper "On the Convergence of Adam and + beyond". + name: Optional name for the operations created when applying + gradients. Defaults to "AdamW". + **kwargs: keyword arguments. Allowed to be {`clipnorm`, + `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by + norm; `clipvalue` is clip gradients by value, `decay` is + included for backward compatibility to allow time inverse decay + of learning rate. `lr` is included for backward compatibility, + recommended to use `learning_rate` instead. + """ super(AdamWOptimizer, self).__init__( weight_decay, learning_rate=learning_rate, From 98a42c64b240db81559f5756f99ed5cde40e8f17 Mon Sep 17 00:00:00 2001 From: Philipp Jund Date: Thu, 25 Apr 2019 14:57:34 +0200 Subject: [PATCH 6/8] Move optimizer_test_base into weight_decay_test for now. In the optimizer tests, optimizer params are now keywords instead of a dict. Fix code in comments to support tf-2.0, naming errors, line length. --- tensorflow_addons/optimizers/BUILD | 2 - tensorflow_addons/optimizers/__init__.py | 4 +- .../optimizers/weight_decay_optimizers.py | 138 ++++++++---- .../weight_decay_optimizers_test.py | 205 ++++++++++++++---- 4 files changed, 263 insertions(+), 86 deletions(-) diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD index 3e12a01edb..87bf51cac6 100644 --- a/tensorflow_addons/optimizers/BUILD +++ b/tensorflow_addons/optimizers/BUILD @@ -28,13 +28,11 @@ py_test( ], ) - py_test( name = "weight_decay_optimizers_test", size = "small", srcs = [ "weight_decay_optimizers_test.py", - "optimizer_test_base.py", ], main = "weight_decay_optimizers_test.py", srcs_version = "PY2AND3", diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py index b22abe2224..88c551a5ca 100644 --- a/tensorflow_addons/optimizers/__init__.py +++ b/tensorflow_addons/optimizers/__init__.py @@ -19,7 +19,7 @@ from __future__ import print_function from tensorflow_addons.optimizers.lazy_adam import LazyAdam -from tensorflow_addons.optimizers.weight_decay_optimizers import AdamWOptimizer -from tensorflow_addons.optimizers.weight_decay_optimizers import SGDWOptimizer +from tensorflow_addons.optimizers.weight_decay_optimizers import AdamW +from tensorflow_addons.optimizers.weight_decay_optimizers import SGDW from tensorflow_addons.optimizers.weight_decay_optimizers import ( extend_with_decoupled_weight_decay) diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py index 913243e59b..a0b3d2f1c7 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py @@ -37,17 +37,18 @@ class DecoupledWeightDecayExtension(object): optimizers with decoupled weight decay. We explicitly define the two examples used in the above paper (SGDW and AdamW), but in general this can extend any OptimizerX by using - `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`. + `extend_with_decoupled_weight_decay( + OptimizerX, weight_decay=weight_decay)`. In order for it to work, it must be the first class the Optimizer with weight decay inherits from, e.g. ```python - class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): def __init__(self, weight_decay, *args, **kwargs): - super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs). + super(AdamW, self).__init__(weight_decay, *args, **kwargs). ``` - Note that this extension decays weights BEFORE applying the update based + Note: this extension decays weights BEFORE applying the update based on the gradient, i.e. this extension only has the desired behaviour for optimizers which do not depend on the value of'var' in the update step! @@ -55,17 +56,16 @@ def __init__(self, weight_decay, *args, **kwargs): the decay to the `weight_decay` as well. For example: ```python - schedule = tf.train.piecewise_constant( - tf.train.get_global_step(), [10000, 15000], [1e-0, 1e-1, 1e-2]) - lr = 1e-1 * schedule() - wd = lambda: 1e-4 * schedule() + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) - # ... + # ... - optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr, - weight_decay=wd, - momentum=0.9, - use_nesterov=True) + optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) ``` """ @@ -78,10 +78,10 @@ def __init__(self, weight_decay, **kwargs): **kwargs: Optional list or tuple or set of `Variable` objects to decay. """ + wd = kwargs.pop('weight_decay', weight_decay) super(DecoupledWeightDecayExtension, self).__init__(**kwargs) self._decay_var_list = None # is set in minimize or apply_gradients - self._set_hyper('weight_decay', kwargs.get('weight_decay', - weight_decay)) + self._set_hyper('weight_decay', wd) def get_config(self): config = super(DecoupledWeightDecayExtension, self).get_config() @@ -188,8 +188,8 @@ def extend_with_decoupled_weight_decay(base_optimizer): Returns an optimizer class. An instance of the returned class computes the update step of `base_optimizer` and additionally decays the weights. E.g., the class returned by - `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent - to `tf.contrib.opt.AdamWOptimizer`. + `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is + equivalent to `tfa.optimizers.AdamW`. The API of the new optimizer class slightly differs from the API of the base optimizer: @@ -201,18 +201,35 @@ def extend_with_decoupled_weight_decay(base_optimizer): Usage example: ```python # MyAdamW is a new class - MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer) + MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam) # Create a MyAdamW object optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) - sess.run(optimizer.minimize(loss, decay_variables=[var1, var2])) + # update var1, var2 but only decay var1 + optimizer.minimize(loss, var_list=[var1, var2], decay_variables=[var1]) - Note that this extension decays weights BEFORE applying the update based + Note: this extension decays weights BEFORE applying the update based on the gradient, i.e. this extension only has the desired behaviour for optimizers which do not depend on the value of 'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) ``` Args: - base_optimizer: An optimizer class that inherits from tf.train.Optimizer. + base_optimizer: An optimizer class that inherits from + tf.optimizers.Optimizer. Returns: A new optimizer class that inherits from DecoupledWeightDecayExtension @@ -238,23 +255,21 @@ class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, def __init__(self, weight_decay, *args, **kwargs): # super delegation is necessary here - # pylint: disable=useless-super-delegation super(OptimizerWithDecoupledWeightDecay, self).__init__( weight_decay, *args, **kwargs) - # pylint: enable=useless-super-delegation return OptimizerWithDecoupledWeightDecay @keras_utils.register_keras_custom_object -class SGDWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD): +class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD): """Optimizer that implements the Momentum algorithm with weight_decay. - This is an implementation of the SGDW optimizer described in "Fixing - Weight Decay Regularization in Adam" by Loshchilov & Hutter + This is an implementation of the SGDW optimizer described in "Decoupled + Weight Decay Regularization" by Loshchilov & Hutter (https://arxiv.org/abs/1711.05101) ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). - It computes the update step of `train.MomentumOptimizer` and additionally + It computes the update step of `tf.keras.optimizers.SGD` and additionally decays the variable. Note that this is different from adding L2 regularization on the variables to the loss. Decoupling the weight decay from other hyperparameters (in particular the learning rate) simplifies @@ -262,10 +277,27 @@ class SGDWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD): For further information see the documentation of the SGD Optimizer. - Note that this optimizer can also be instantiated as + This optimizer can also be instantiated as ```python - extend_with_weight_decay(tf.keras.optimizers.SGD, - weight_decay=weight_decay) + extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD, + weight_decay=weight_decay) + ``` + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.SGDW( + learning_rate=lr, weight_decay=wd, momentum=0.9) ``` """ @@ -287,14 +319,14 @@ def __init__(self, nesterov: boolean. Whether to apply Nesterov momentum. name: Optional name prefix for the operations created when applying gradients. Defaults to 'SGD'. - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, - `lr`, `decay`}. `clipnorm` is clip gradients by norm; - `clipvalue` is clip gradients by value, `decay` is included for - backward compatibility to allow time inverse decay of learning - rate. `lr` is included for backward compatibility, recommended - to use `learning_rate` instead. + **kwargs: keyword arguments. Allowed to be {`clipnorm`, + `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by + norm; `clipvalue` is clip gradients by value, `decay` is + included for backward compatibility to allow time inverse decay + of learning rate. `lr` is included for backward compatibility, + recommended to use `learning_rate` instead. """ - super(SGDWOptimizer, self).__init__( + super(SGDW, self).__init__( weight_decay, learning_rate=learning_rate, momentum=momentum, @@ -304,15 +336,15 @@ def __init__(self, @keras_utils.register_keras_custom_object -class AdamWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): +class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): """Optimizer that implements the Adam algorithm with weight decay. - This is an implementation of the AdamW optimizer described in "Fixing - Weight Decay Regularization in Adam" by Loshchilov & Hutter + This is an implementation of the AdamW optimizer described in "Decoupled + Weight Decay Regularization" by Loshchilov & Hutter (https://arxiv.org/abs/1711.05101) ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). - It computes the update step of `train.AdamOptimizer` and additionally + It computes the update step of `tf.keras.optimizers.Adam` and additionally decays the variable. Note that this is different from adding L2 regularization on the variables to the loss: it regularizes variables with large gradients more than L2 regularization would, which was shown to yield @@ -320,10 +352,26 @@ class AdamWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): For further information see the documentation of the Adam Optimizer. - Note that this optimizer can also be instantiated as + This optimizer can also be instantiated as + ```python + extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam, + weight_decay=weight_decay) + ``` + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + ```python - extend_with_weight_decay(tf.keras.optimizers.SGD, - weight_decay=weight_decay) + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) ``` """ @@ -364,7 +412,7 @@ def __init__(self, of learning rate. `lr` is included for backward compatibility, recommended to use `learning_rate` instead. """ - super(AdamWOptimizer, self).__init__( + super(AdamW, self).__init__( weight_decay, learning_rate=learning_rate, beta_1=beta_1, diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers_test.py b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py index c209f5788a..ce65757b29 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers_test.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py @@ -22,18 +22,140 @@ import tensorflow as tf from tensorflow_addons.utils import test_utils -from tensorflow_addons.optimizers.optimizer_test_base import OptimizerTestBase from tensorflow_addons.optimizers import weight_decay_optimizers WEIGHT_DECAY = 0.01 -def adamw_update_numpy(param, grad_t, slot_vars, optimizer_params): +class OptimizerTestBase(tf.test.TestCase): + """Base class for optimizer tests. + + Optimizer tests may inherit from this class and define test + functions using doTest. Usually this should include the functions + testSparse, testBasic, and testBasicCallableParams. See + weight_decay_optimizers_test for an example. + """ + + def doTest(self, optimizer, update_fn, do_sparse=False, + **optimizer_kwargs): + """The major test function. + + Args: + optimizer: The tensorflow optimizer class to be tested. + update_fn: The numpy update function of the optimizer, the function + signature must be + update_fn(var: np.array, + grad_t: np.array, + slot_vars: dict, + **kwargs) -> (updated_var, updated_slot_vars) + Note that slot_vars will be initialized to an empty dictionary + for each variable, initial values should be handled in the + update_fn. + do_sparse: If True, test sparse update. Defaults to False, i.e., + dense update. + **optimizer_kwargs:The parameters to pass to the construcor of the + optimizer. Either a constant or a callable. This also passed to + the optimizer_params in the update_fn. + """ + for i, dtype in enumerate([tf.half, tf.float32, tf.float64]): + # Initialize variables for numpy implementation. + np_slot_vars0, np_slot_vars1 = {}, {} + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + # Create Tensorflow variables. + var0 = tf.Variable(var0_np, name="var0_%d" % i) + var1 = tf.Variable(var1_np, name="var1_%d" % i) + if do_sparse: + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = tf.IndexedSlices( + tf.constant(grads0_np), tf.constant(grads0_np_indices), + tf.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = tf.IndexedSlices( + tf.constant(grads1_np), tf.constant(grads1_np_indices), + tf.constant([2])) + else: + grads0 = tf.constant(grads0_np) + grads1 = tf.constant(grads1_np) + opt = optimizer(**optimizer_kwargs) + # Validate initial values. + if not tf.executing_eagerly(): + update = opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + self.evaluate(tf.compat.v1.global_variables_initializer()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + # Create the update op. + # Run 3 steps of the optimizer + for _ in range(3): + if tf.executing_eagerly(): + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + else: + self.evaluate(update) + var0_np, np_slot_vars0 = update_fn( + var0_np, grads0_np, np_slot_vars0, **optimizer_kwargs) + var1_np, np_slot_vars1 = update_fn( + var1_np, grads1_np, np_slot_vars1, **optimizer_kwargs) + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, + self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, + self.evaluate(var1)) + + def doTestSparseRepeatedIndices(self, optimizer, **optimizer_kwargs): + """Test for repeated indices in sparse updates. + + This test verifies that an update with repeated indices is the same as + an update with two times the gradient. + + Args: + optimizer: The tensorflow optimizer class to be tested. + **optimizer_kwargs: The parameters to pass to the construcor of the + optimizer. Either a constant or a callable. This also passed to + the optimizer_params in the update_fn. + """ + for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]: + repeated_index_update_var = tf.Variable([[1.0], [2.0]], + dtype=dtype) + aggregated_update_var = tf.Variable([[1.0], [2.0]], dtype=dtype) + grad_repeated_index = tf.IndexedSlices( + tf.constant([0.1, 0.1], shape=[2, 1], dtype=dtype), + tf.constant([1, 1]), tf.constant([2, 1])) + grad_aggregated = tf.IndexedSlices( + tf.constant([0.2], shape=[1, 1], dtype=dtype), + tf.constant([1]), tf.constant([2, 1])) + opt_repeated = optimizer(**optimizer_kwargs) + repeated_update = opt_repeated.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) + opt_aggregated = optimizer(**optimizer_kwargs) + aggregated_update = opt_aggregated.apply_gradients( + [(grad_aggregated, aggregated_update_var)]) + self.evaluate(tf.compat.v1.global_variables_initializer()) + self.assertAllClose( + self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) + for _ in range(3): + if not tf.executing_eagerly(): + self.evaluate(repeated_update) + self.evaluate(aggregated_update) + else: + opt_repeated.apply_gradients([(grad_repeated_index, + repeated_index_update_var)]) + opt_aggregated.apply_gradients([(grad_aggregated, + aggregated_update_var)]) + self.assertAllClose( + self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) + + +def adamw_update_numpy(param, grad_t, slot_vars, learning_rate, beta_1, beta_2, + epsilon, weight_decay): """Numpy update function for AdamW.""" - opt_params = ( - optimizer_params[k] for k in - ["learning_rate", "beta_1", "beta_2", "epsilon", "weight_decay"]) - lr, beta1, beta2, eps, wd = (v() if callable(v) else v for v in opt_params) + lr, beta1, beta2, eps, wd = (v() if callable(v) else v + for v in (learning_rate, beta_1, beta_2, + epsilon, weight_decay)) t = slot_vars.get("t", 0) + 1 lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) slot_vars["m"] = beta1 * slot_vars.get("m", 0) + (1 - beta1) * grad_t @@ -44,21 +166,18 @@ def adamw_update_numpy(param, grad_t, slot_vars, optimizer_params): return param_t, slot_vars -def sgdw_update_numpy(param, grad_t, slot_vars, optimizer_params): +def sgdw_update_numpy(param, grad_t, slot_vars, learning_rate, momentum, + weight_decay): """Numpy update function for SGDW.""" m = slot_vars.get("m", 0) - optimizer_params = { - k: v() if callable(v) else v - for k, v in optimizer_params.items() - } - slot_vars["m"] = optimizer_params["momentum"] * m + grad_t - lr = optimizer_params["learning_rate"] - wd = optimizer_params["weight_decay"] + lr, momentum, wd = (v() if callable(v) else v + for v in (learning_rate, momentum, weight_decay)) + slot_vars["m"] = momentum * m + grad_t param_t = param * (1 - wd) - lr * slot_vars["m"] return param_t, slot_vars -class AdamWOptimizerTest(OptimizerTestBase): +class AdamWTest(OptimizerTestBase): opt_params = { "learning_rate": 0.001, @@ -68,58 +187,70 @@ class AdamWOptimizerTest(OptimizerTestBase): "weight_decay": WEIGHT_DECAY } callable_opt_params = {k: (lambda: v) for k, v in opt_params.items()} - optimizer = weight_decay_optimizers.AdamWOptimizer + optimizer = weight_decay_optimizers.AdamW @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparse(self): self.doTest( self.optimizer, adamw_update_numpy, - self.opt_params, - do_sparse=True) + do_sparse=True, + **self.opt_params) - @test_utils.run_in_graph_and_eager_modes + @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparseRepeatedIndices(self): - self.doTestSparseRepeatedIndices(self.optimizer, self.opt_params) + self.doTestSparseRepeatedIndices(self.optimizer, **self.opt_params) @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testBasic(self): - self.doTest(self.optimizer, adamw_update_numpy, self.opt_params) + self.doTest(self.optimizer, adamw_update_numpy, **self.opt_params) def testBasicCallableParams(self): self.doTest(self.optimizer, adamw_update_numpy, - self.callable_opt_params) + **self.callable_opt_params) -class SGDWOptimizerTest(OptimizerTestBase): +class SGDWTest(OptimizerTestBase): - opt_params = { - "learning_rate": 0.001, - "momentum": 0.9, - "weight_decay": WEIGHT_DECAY - } - callable_opt_params = {k: (lambda: v) for k, v in opt_params.items()} - optimizer = weight_decay_optimizers.SGDWOptimizer + optimizer = weight_decay_optimizers.SGDW @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparse(self): self.doTest( - self.optimizer, sgdw_update_numpy, self.opt_params, do_sparse=True) + self.optimizer, + sgdw_update_numpy, + do_sparse=True, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY) - @test_utils.run_in_graph_and_eager_modes + @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparseRepeatedIndices(self): - self.doTestSparseRepeatedIndices(self.optimizer, self.opt_params) + self.doTestSparseRepeatedIndices( + self.optimizer, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY) @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testBasic(self): - self.doTest(self.optimizer, sgdw_update_numpy, self.opt_params) + self.doTest( + self.optimizer, + sgdw_update_numpy, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY) def testBasicCallableParams(self): - self.doTest(self.optimizer, sgdw_update_numpy, - self.callable_opt_params) + self.doTest( + self.optimizer, + sgdw_update_numpy, + learning_rate=lambda: 0.001, + momentum=lambda: 0.9, + weight_decay=lambda: WEIGHT_DECAY) -class ExtendWithWeightDecayTest(SGDWOptimizerTest): +class ExtendWithWeightDecayTest(SGDWTest): """Verify that the factory function SGDW is the same as SGDW.""" optimizer = weight_decay_optimizers.extend_with_decoupled_weight_decay( From 077924c684802cb89bbd0ce5d5dcb89c14e7e18d Mon Sep 17 00:00:00 2001 From: Philipp Jund Date: Mon, 29 Apr 2019 09:56:49 +0200 Subject: [PATCH 7/8] Delete optimizer_test_base.py Remove keras object registration in the factory function. --- tensorflow_addons/optimizers/README.md | 1 - .../optimizers/optimizer_test_base.py | 149 ------------------ .../optimizers/weight_decay_optimizers.py | 61 +++---- .../weight_decay_optimizers_test.py | 135 ++++++++-------- 4 files changed, 104 insertions(+), 242 deletions(-) delete mode 100644 tensorflow_addons/optimizers/optimizer_test_base.py diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md index 27b5054668..268b9f6f6a 100644 --- a/tensorflow_addons/optimizers/README.md +++ b/tensorflow_addons/optimizers/README.md @@ -27,7 +27,6 @@ must: `@run_in_graph_and_eager_modes` (for test method) or `run_all_in_graph_and_eager_modes` (for TestCase subclass) decorator. - * Consider inheriting from `OptimizerTestBase`. * Add a `py_test` to this sub-package's BUILD file. #### Documentation Requirements diff --git a/tensorflow_addons/optimizers/optimizer_test_base.py b/tensorflow_addons/optimizers/optimizer_test_base.py deleted file mode 100644 index 6bef3e9480..0000000000 --- a/tensorflow_addons/optimizers/optimizer_test_base.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Base class for optimizer tests.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf - - -class OptimizerTestBase(tf.test.TestCase): - """Base class for optimizer tests. - - Optimizer tests may inherit from this class and define test - functions using doTest. Usually this should include the functions - testSparse, testBasic, and testBasicCallableParams. See - weight_decay_optimizers_test for an example. - """ - - def doTest(self, optimizer, update_fn, params, do_sparse=False): - """The major test function. - - Args: - optimizer: The tensorflow optimizer class to be tested. - update_fn: The numpy update function of the optimizer, the function - signature must be - update_fn(var: np.array, - grad_t: np.array, - slot_vars: dict, - optimizer_params: dict) -> (updated_var, - updated_slot_vars) - Note that slot_vars will be initialized to an empty dictionary - for each variable, initial values should be handled in the - update_fn. - params: A dict, the parameters to pass to the construcor of the - optimizer. Either a constant or a callable. This also passed to - the optimizer_params in the update_fn. - do_sparse: If True, test sparse update. Defaults to False, i.e., - dense update. - """ - for i, dtype in enumerate([tf.half, tf.float32, tf.float64]): - with self.session(graph=tf.Graph()): - # Initialize variables for numpy implementation. - np_slot_vars0, np_slot_vars1 = {}, {} - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - # Create Tensorflow variables. - var0 = tf.Variable(var0_np, name="var0_%d" % i) - var1 = tf.Variable(var1_np, name="var1_%d" % i) - if do_sparse: - grads0_np_indices = np.array([0, 1], dtype=np.int32) - grads0 = tf.IndexedSlices( - tf.constant(grads0_np), tf.constant(grads0_np_indices), - tf.constant([2])) - grads1_np_indices = np.array([0, 1], dtype=np.int32) - grads1 = tf.IndexedSlices( - tf.constant(grads1_np), tf.constant(grads1_np_indices), - tf.constant([2])) - else: - grads0 = tf.constant(grads0_np) - grads1 = tf.constant(grads1_np) - opt = optimizer(**params) - # Validate initial values. - if not tf.executing_eagerly(): - update = opt.apply_gradients( - zip([grads0, grads1], [var0, var1])) - self.evaluate(tf.compat.v1.global_variables_initializer()) - self.assertAllClose([1.0, 2.0], self.evaluate(var0)) - self.assertAllClose([3.0, 4.0], self.evaluate(var1)) - # Create the update op. - # Run 3 steps of the optimizer - for _ in range(3): - if tf.executing_eagerly(): - opt.apply_gradients( - zip([grads0, grads1], [var0, var1])) - else: - self.evaluate(update) - var0_np, np_slot_vars0 = update_fn(var0_np, grads0_np, - np_slot_vars0, params) - var1_np, np_slot_vars1 = update_fn(var1_np, grads1_np, - np_slot_vars1, params) - # Validate updated params - self.assertAllCloseAccordingToType(var0_np, - self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, - self.evaluate(var1)) - - def doTestSparseRepeatedIndices(self, optimizer, params): - """Test for repeated indices in sparse updates. - - This test verifies that an update with repeated indices is the same as - an update with two times the gradient. - - Args: - optimizer: The tensorflow optimizer class to be tested. - params: A dict, the parameters to pass to the construcor of the - optimizer. Either a constant or a callable. This also passed to - the optimizer_params in the update_fn. - """ - for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]: - with self.cached_session(): - repeated_index_update_var = tf.Variable([[1.0], [2.0]], - dtype=dtype) - aggregated_update_var = tf.Variable([[1.0], [2.0]], - dtype=dtype) - grad_repeated_index = tf.IndexedSlices( - tf.constant([0.1, 0.1], shape=[2, 1], dtype=dtype), - tf.constant([1, 1]), tf.constant([2, 1])) - grad_aggregated = tf.IndexedSlices( - tf.constant([0.2], shape=[1, 1], dtype=dtype), - tf.constant([1]), tf.constant([2, 1])) - opt_repeated = optimizer(**params) - repeated_update = opt_repeated.apply_gradients( - [(grad_repeated_index, repeated_index_update_var)]) - opt_aggregated = optimizer(**params) - aggregated_update = opt_aggregated.apply_gradients( - [(grad_aggregated, aggregated_update_var)]) - self.evaluate(tf.compat.v1.global_variables_initializer()) - self.assertAllClose( - self.evaluate(aggregated_update_var), - self.evaluate(repeated_index_update_var)) - for _ in range(3): - if not tf.executing_eagerly(): - self.evaluate(repeated_update) - self.evaluate(aggregated_update) - else: - opt_repeated.apply_gradients( - [(grad_repeated_index, repeated_index_update_var)]) - opt_aggregated.apply_gradients( - [(grad_aggregated, aggregated_update_var)]) - self.assertAllClose( - self.evaluate(aggregated_update_var), - self.evaluate(repeated_index_update_var)) diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py index a0b3d2f1c7..2c222e1024 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py @@ -125,8 +125,11 @@ def minimize(self, ValueError: If some of the variables are not `Variable` objects. """ self._decay_var_list = set(decay_var_list) if decay_var_list else False - return super(DecoupledWeightDecayExtension, self).minimize( - loss, var_list=var_list, grad_loss=grad_loss, name=name) + return super(DecoupledWeightDecayExtension, + self).minimize(loss, + var_list=var_list, + grad_loss=grad_loss, + name=name) def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None): """Apply gradients to variables. @@ -149,8 +152,8 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None): ValueError: If none of the variables have gradients. """ self._decay_var_list = set(decay_var_list) if decay_var_list else False - return super(DecoupledWeightDecayExtension, self).apply_gradients( - grads_and_vars, name=name) + return super(DecoupledWeightDecayExtension, + self).apply_gradients(grads_and_vars, name=name) def _decay_weights_op(self, var): if not self._decay_var_list or var in self._decay_var_list: @@ -161,8 +164,8 @@ def _decay_weights_op(self, var): def _decay_weights_sparse_op(self, var, indices): if not self._decay_var_list or var in self._decay_var_list: - update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather( - var, indices)) + update = (-self._get_hyper('weight_decay', var.dtype) * + tf.gather(var, indices)) return self._resource_scatter_add(var, indices, update) return tf.no_op() @@ -226,17 +229,19 @@ def extend_with_decoupled_weight_decay(base_optimizer): optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) ``` + + Note: you might want to register your own custom optimizer using + `tf.keras.utils.get_custom_objects()`. Args: - base_optimizer: An optimizer class that inherits from - tf.optimizers.Optimizer. + base_optimizer: An optimizer class that inherits from + tf.optimizers.Optimizer. Returns: - A new optimizer class that inherits from DecoupledWeightDecayExtension - and base_optimizer. + A new optimizer class that inherits from DecoupledWeightDecayExtension + and base_optimizer. """ - @keras_utils.register_keras_custom_object class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, base_optimizer): """Base_optimizer with decoupled weight decay. @@ -255,8 +260,8 @@ class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, def __init__(self, weight_decay, *args, **kwargs): # super delegation is necessary here - super(OptimizerWithDecoupledWeightDecay, self).__init__( - weight_decay, *args, **kwargs) + super(OptimizerWithDecoupledWeightDecay, + self).__init__(weight_decay, *args, **kwargs) return OptimizerWithDecoupledWeightDecay @@ -326,13 +331,12 @@ def __init__(self, of learning rate. `lr` is included for backward compatibility, recommended to use `learning_rate` instead. """ - super(SGDW, self).__init__( - weight_decay, - learning_rate=learning_rate, - momentum=momentum, - nesterov=nesterov, - name=name, - **kwargs) + super(SGDW, self).__init__(weight_decay, + learning_rate=learning_rate, + momentum=momentum, + nesterov=nesterov, + name=name, + **kwargs) @keras_utils.register_keras_custom_object @@ -412,12 +416,11 @@ def __init__(self, of learning rate. `lr` is included for backward compatibility, recommended to use `learning_rate` instead. """ - super(AdamW, self).__init__( - weight_decay, - learning_rate=learning_rate, - beta_1=beta_1, - beta_2=beta_2, - epsilon=epsilon, - amsgrad=amsgrad, - name=name, - **kwargs) + super(AdamW, self).__init__(weight_decay, + learning_rate=learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon, + amsgrad=amsgrad, + name=name, + **kwargs) diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers_test.py b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py index ce65757b29..0d1b7e824f 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers_test.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py @@ -69,13 +69,13 @@ def doTest(self, optimizer, update_fn, do_sparse=False, var1 = tf.Variable(var1_np, name="var1_%d" % i) if do_sparse: grads0_np_indices = np.array([0, 1], dtype=np.int32) - grads0 = tf.IndexedSlices( - tf.constant(grads0_np), tf.constant(grads0_np_indices), - tf.constant([2])) + grads0 = tf.IndexedSlices(tf.constant(grads0_np), + tf.constant(grads0_np_indices), + tf.constant([2])) grads1_np_indices = np.array([0, 1], dtype=np.int32) - grads1 = tf.IndexedSlices( - tf.constant(grads1_np), tf.constant(grads1_np_indices), - tf.constant([2])) + grads1 = tf.IndexedSlices(tf.constant(grads1_np), + tf.constant(grads1_np_indices), + tf.constant([2])) else: grads0 = tf.constant(grads0_np) grads1 = tf.constant(grads1_np) @@ -94,10 +94,12 @@ def doTest(self, optimizer, update_fn, do_sparse=False, opt.apply_gradients(zip([grads0, grads1], [var0, var1])) else: self.evaluate(update) - var0_np, np_slot_vars0 = update_fn( - var0_np, grads0_np, np_slot_vars0, **optimizer_kwargs) - var1_np, np_slot_vars1 = update_fn( - var1_np, grads1_np, np_slot_vars1, **optimizer_kwargs) + var0_np, np_slot_vars0 = update_fn(var0_np, grads0_np, + np_slot_vars0, + **optimizer_kwargs) + var1_np, np_slot_vars1 = update_fn(var1_np, grads1_np, + np_slot_vars1, + **optimizer_kwargs) # Validate updated params self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) @@ -127,15 +129,16 @@ def doTestSparseRepeatedIndices(self, optimizer, **optimizer_kwargs): tf.constant([0.2], shape=[1, 1], dtype=dtype), tf.constant([1]), tf.constant([2, 1])) opt_repeated = optimizer(**optimizer_kwargs) - repeated_update = opt_repeated.apply_gradients( - [(grad_repeated_index, repeated_index_update_var)]) + repeated_update = opt_repeated.apply_gradients([ + (grad_repeated_index, repeated_index_update_var) + ]) opt_aggregated = optimizer(**optimizer_kwargs) - aggregated_update = opt_aggregated.apply_gradients( - [(grad_aggregated, aggregated_update_var)]) + aggregated_update = opt_aggregated.apply_gradients([ + (grad_aggregated, aggregated_update_var) + ]) self.evaluate(tf.compat.v1.global_variables_initializer()) - self.assertAllClose( - self.evaluate(aggregated_update_var), - self.evaluate(repeated_index_update_var)) + self.assertAllClose(self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) for _ in range(3): if not tf.executing_eagerly(): self.evaluate(repeated_update) @@ -145,9 +148,8 @@ def doTestSparseRepeatedIndices(self, optimizer, **optimizer_kwargs): repeated_index_update_var)]) opt_aggregated.apply_gradients([(grad_aggregated, aggregated_update_var)]) - self.assertAllClose( - self.evaluate(aggregated_update_var), - self.evaluate(repeated_index_update_var)) + self.assertAllClose(self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) def adamw_update_numpy(param, grad_t, slot_vars, learning_rate, beta_1, beta_2, @@ -160,8 +162,8 @@ def adamw_update_numpy(param, grad_t, slot_vars, learning_rate, beta_1, beta_2, lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) slot_vars["m"] = beta1 * slot_vars.get("m", 0) + (1 - beta1) * grad_t slot_vars["v"] = beta2 * slot_vars.get("v", 0) + (1 - beta2) * grad_t**2 - param_t = (param * (1 - wd) - - lr_t * slot_vars["m"] / (np.sqrt(slot_vars["v"]) + eps)) + param_t = (param * (1 - wd) - lr_t * slot_vars["m"] / + (np.sqrt(slot_vars["v"]) + eps)) slot_vars["t"] = t return param_t, slot_vars @@ -179,35 +181,46 @@ def sgdw_update_numpy(param, grad_t, slot_vars, learning_rate, momentum, class AdamWTest(OptimizerTestBase): - opt_params = { - "learning_rate": 0.001, - "beta_1": 0.9, - "beta_2": 0.999, - "epsilon": 1e-8, - "weight_decay": WEIGHT_DECAY - } - callable_opt_params = {k: (lambda: v) for k, v in opt_params.items()} optimizer = weight_decay_optimizers.AdamW @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparse(self): - self.doTest( - self.optimizer, - adamw_update_numpy, - do_sparse=True, - **self.opt_params) + self.doTest(self.optimizer, + adamw_update_numpy, + do_sparse=True, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY) @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparseRepeatedIndices(self): - self.doTestSparseRepeatedIndices(self.optimizer, **self.opt_params) + self.doTestSparseRepeatedIndices(self.optimizer, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY) @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testBasic(self): - self.doTest(self.optimizer, adamw_update_numpy, **self.opt_params) + self.doTest(self.optimizer, + adamw_update_numpy, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY) def testBasicCallableParams(self): - self.doTest(self.optimizer, adamw_update_numpy, - **self.callable_opt_params) + self.doTest(self.optimizer, + adamw_update_numpy, + learning_rate=lambda: 0.001, + beta_1=lambda: 0.9, + beta_2=lambda: 0.999, + epsilon=lambda: 1e-8, + weight_decay=lambda: WEIGHT_DECAY) class SGDWTest(OptimizerTestBase): @@ -216,38 +229,34 @@ class SGDWTest(OptimizerTestBase): @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparse(self): - self.doTest( - self.optimizer, - sgdw_update_numpy, - do_sparse=True, - learning_rate=0.001, - momentum=0.9, - weight_decay=WEIGHT_DECAY) + self.doTest(self.optimizer, + sgdw_update_numpy, + do_sparse=True, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY) @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparseRepeatedIndices(self): - self.doTestSparseRepeatedIndices( - self.optimizer, - learning_rate=0.001, - momentum=0.9, - weight_decay=WEIGHT_DECAY) + self.doTestSparseRepeatedIndices(self.optimizer, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY) @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testBasic(self): - self.doTest( - self.optimizer, - sgdw_update_numpy, - learning_rate=0.001, - momentum=0.9, - weight_decay=WEIGHT_DECAY) + self.doTest(self.optimizer, + sgdw_update_numpy, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY) def testBasicCallableParams(self): - self.doTest( - self.optimizer, - sgdw_update_numpy, - learning_rate=lambda: 0.001, - momentum=lambda: 0.9, - weight_decay=lambda: WEIGHT_DECAY) + self.doTest(self.optimizer, + sgdw_update_numpy, + learning_rate=lambda: 0.001, + momentum=lambda: 0.9, + weight_decay=lambda: WEIGHT_DECAY) class ExtendWithWeightDecayTest(SGDWTest): From d6ccae6b03e989cead45e33f25aa7c185a5be9e9 Mon Sep 17 00:00:00 2001 From: Philipp Jund Date: Tue, 30 Apr 2019 10:33:25 +0200 Subject: [PATCH 8/8] Fix code formatting via patch file. --- .../optimizers/weight_decay_optimizers.py | 49 +++--- .../weight_decay_optimizers_test.py | 150 +++++++++--------- 2 files changed, 102 insertions(+), 97 deletions(-) diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py index 2c222e1024..4759538d9a 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py @@ -125,11 +125,8 @@ def minimize(self, ValueError: If some of the variables are not `Variable` objects. """ self._decay_var_list = set(decay_var_list) if decay_var_list else False - return super(DecoupledWeightDecayExtension, - self).minimize(loss, - var_list=var_list, - grad_loss=grad_loss, - name=name) + return super(DecoupledWeightDecayExtension, self).minimize( + loss, var_list=var_list, grad_loss=grad_loss, name=name) def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None): """Apply gradients to variables. @@ -152,8 +149,8 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None): ValueError: If none of the variables have gradients. """ self._decay_var_list = set(decay_var_list) if decay_var_list else False - return super(DecoupledWeightDecayExtension, - self).apply_gradients(grads_and_vars, name=name) + return super(DecoupledWeightDecayExtension, self).apply_gradients( + grads_and_vars, name=name) def _decay_weights_op(self, var): if not self._decay_var_list or var in self._decay_var_list: @@ -164,8 +161,8 @@ def _decay_weights_op(self, var): def _decay_weights_sparse_op(self, var, indices): if not self._decay_var_list or var in self._decay_var_list: - update = (-self._get_hyper('weight_decay', var.dtype) * - tf.gather(var, indices)) + update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather( + var, indices)) return self._resource_scatter_add(var, indices, update) return tf.no_op() @@ -260,8 +257,8 @@ class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, def __init__(self, weight_decay, *args, **kwargs): # super delegation is necessary here - super(OptimizerWithDecoupledWeightDecay, - self).__init__(weight_decay, *args, **kwargs) + super(OptimizerWithDecoupledWeightDecay, self).__init__( + weight_decay, *args, **kwargs) return OptimizerWithDecoupledWeightDecay @@ -331,12 +328,13 @@ def __init__(self, of learning rate. `lr` is included for backward compatibility, recommended to use `learning_rate` instead. """ - super(SGDW, self).__init__(weight_decay, - learning_rate=learning_rate, - momentum=momentum, - nesterov=nesterov, - name=name, - **kwargs) + super(SGDW, self).__init__( + weight_decay, + learning_rate=learning_rate, + momentum=momentum, + nesterov=nesterov, + name=name, + **kwargs) @keras_utils.register_keras_custom_object @@ -416,11 +414,12 @@ def __init__(self, of learning rate. `lr` is included for backward compatibility, recommended to use `learning_rate` instead. """ - super(AdamW, self).__init__(weight_decay, - learning_rate=learning_rate, - beta_1=beta_1, - beta_2=beta_2, - epsilon=epsilon, - amsgrad=amsgrad, - name=name, - **kwargs) + super(AdamW, self).__init__( + weight_decay, + learning_rate=learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon, + amsgrad=amsgrad, + name=name, + **kwargs) diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers_test.py b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py index 0d1b7e824f..e265eecb3c 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers_test.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers_test.py @@ -69,13 +69,13 @@ def doTest(self, optimizer, update_fn, do_sparse=False, var1 = tf.Variable(var1_np, name="var1_%d" % i) if do_sparse: grads0_np_indices = np.array([0, 1], dtype=np.int32) - grads0 = tf.IndexedSlices(tf.constant(grads0_np), - tf.constant(grads0_np_indices), - tf.constant([2])) + grads0 = tf.IndexedSlices( + tf.constant(grads0_np), tf.constant(grads0_np_indices), + tf.constant([2])) grads1_np_indices = np.array([0, 1], dtype=np.int32) - grads1 = tf.IndexedSlices(tf.constant(grads1_np), - tf.constant(grads1_np_indices), - tf.constant([2])) + grads1 = tf.IndexedSlices( + tf.constant(grads1_np), tf.constant(grads1_np_indices), + tf.constant([2])) else: grads0 = tf.constant(grads0_np) grads1 = tf.constant(grads1_np) @@ -94,12 +94,10 @@ def doTest(self, optimizer, update_fn, do_sparse=False, opt.apply_gradients(zip([grads0, grads1], [var0, var1])) else: self.evaluate(update) - var0_np, np_slot_vars0 = update_fn(var0_np, grads0_np, - np_slot_vars0, - **optimizer_kwargs) - var1_np, np_slot_vars1 = update_fn(var1_np, grads1_np, - np_slot_vars1, - **optimizer_kwargs) + var0_np, np_slot_vars0 = update_fn( + var0_np, grads0_np, np_slot_vars0, **optimizer_kwargs) + var1_np, np_slot_vars1 = update_fn( + var1_np, grads1_np, np_slot_vars1, **optimizer_kwargs) # Validate updated params self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) @@ -129,16 +127,15 @@ def doTestSparseRepeatedIndices(self, optimizer, **optimizer_kwargs): tf.constant([0.2], shape=[1, 1], dtype=dtype), tf.constant([1]), tf.constant([2, 1])) opt_repeated = optimizer(**optimizer_kwargs) - repeated_update = opt_repeated.apply_gradients([ - (grad_repeated_index, repeated_index_update_var) - ]) + repeated_update = opt_repeated.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) opt_aggregated = optimizer(**optimizer_kwargs) - aggregated_update = opt_aggregated.apply_gradients([ - (grad_aggregated, aggregated_update_var) - ]) + aggregated_update = opt_aggregated.apply_gradients( + [(grad_aggregated, aggregated_update_var)]) self.evaluate(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(self.evaluate(aggregated_update_var), - self.evaluate(repeated_index_update_var)) + self.assertAllClose( + self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) for _ in range(3): if not tf.executing_eagerly(): self.evaluate(repeated_update) @@ -148,8 +145,9 @@ def doTestSparseRepeatedIndices(self, optimizer, **optimizer_kwargs): repeated_index_update_var)]) opt_aggregated.apply_gradients([(grad_aggregated, aggregated_update_var)]) - self.assertAllClose(self.evaluate(aggregated_update_var), - self.evaluate(repeated_index_update_var)) + self.assertAllClose( + self.evaluate(aggregated_update_var), + self.evaluate(repeated_index_update_var)) def adamw_update_numpy(param, grad_t, slot_vars, learning_rate, beta_1, beta_2, @@ -162,8 +160,8 @@ def adamw_update_numpy(param, grad_t, slot_vars, learning_rate, beta_1, beta_2, lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) slot_vars["m"] = beta1 * slot_vars.get("m", 0) + (1 - beta1) * grad_t slot_vars["v"] = beta2 * slot_vars.get("v", 0) + (1 - beta2) * grad_t**2 - param_t = (param * (1 - wd) - lr_t * slot_vars["m"] / - (np.sqrt(slot_vars["v"]) + eps)) + param_t = (param * (1 - wd) - + lr_t * slot_vars["m"] / (np.sqrt(slot_vars["v"]) + eps)) slot_vars["t"] = t return param_t, slot_vars @@ -185,42 +183,46 @@ class AdamWTest(OptimizerTestBase): @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparse(self): - self.doTest(self.optimizer, - adamw_update_numpy, - do_sparse=True, - learning_rate=0.001, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-8, - weight_decay=WEIGHT_DECAY) + self.doTest( + self.optimizer, + adamw_update_numpy, + do_sparse=True, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY) @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparseRepeatedIndices(self): - self.doTestSparseRepeatedIndices(self.optimizer, - learning_rate=0.001, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-8, - weight_decay=WEIGHT_DECAY) + self.doTestSparseRepeatedIndices( + self.optimizer, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY) @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testBasic(self): - self.doTest(self.optimizer, - adamw_update_numpy, - learning_rate=0.001, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-8, - weight_decay=WEIGHT_DECAY) + self.doTest( + self.optimizer, + adamw_update_numpy, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY) def testBasicCallableParams(self): - self.doTest(self.optimizer, - adamw_update_numpy, - learning_rate=lambda: 0.001, - beta_1=lambda: 0.9, - beta_2=lambda: 0.999, - epsilon=lambda: 1e-8, - weight_decay=lambda: WEIGHT_DECAY) + self.doTest( + self.optimizer, + adamw_update_numpy, + learning_rate=lambda: 0.001, + beta_1=lambda: 0.9, + beta_2=lambda: 0.999, + epsilon=lambda: 1e-8, + weight_decay=lambda: WEIGHT_DECAY) class SGDWTest(OptimizerTestBase): @@ -229,34 +231,38 @@ class SGDWTest(OptimizerTestBase): @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparse(self): - self.doTest(self.optimizer, - sgdw_update_numpy, - do_sparse=True, - learning_rate=0.001, - momentum=0.9, - weight_decay=WEIGHT_DECAY) + self.doTest( + self.optimizer, + sgdw_update_numpy, + do_sparse=True, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY) @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testSparseRepeatedIndices(self): - self.doTestSparseRepeatedIndices(self.optimizer, - learning_rate=0.001, - momentum=0.9, - weight_decay=WEIGHT_DECAY) + self.doTestSparseRepeatedIndices( + self.optimizer, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY) @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testBasic(self): - self.doTest(self.optimizer, - sgdw_update_numpy, - learning_rate=0.001, - momentum=0.9, - weight_decay=WEIGHT_DECAY) + self.doTest( + self.optimizer, + sgdw_update_numpy, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY) def testBasicCallableParams(self): - self.doTest(self.optimizer, - sgdw_update_numpy, - learning_rate=lambda: 0.001, - momentum=lambda: 0.9, - weight_decay=lambda: WEIGHT_DECAY) + self.doTest( + self.optimizer, + sgdw_update_numpy, + learning_rate=lambda: 0.001, + momentum=lambda: 0.9, + weight_decay=lambda: WEIGHT_DECAY) class ExtendWithWeightDecayTest(SGDWTest):