From 01262fa6d2404eaec9229f6995994479b7fb7acc Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Fri, 13 Sep 2019 14:43:35 +0800 Subject: [PATCH 01/13] Add Rectified Adam optimizer --- tensorflow_addons/optimizers/BUILD | 14 + tensorflow_addons/optimizers/README.md | 1 + tensorflow_addons/optimizers/__init__.py | 1 + .../optimizers/rectified_adam.py | 250 ++++++++++++++++++ .../optimizers/rectified_adam_test.py | 89 +++++++ 5 files changed, 355 insertions(+) create mode 100644 tensorflow_addons/optimizers/rectified_adam.py create mode 100644 tensorflow_addons/optimizers/rectified_adam_test.py diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD index 1e49a0f1bf..fab80cc6a7 100644 --- a/tensorflow_addons/optimizers/BUILD +++ b/tensorflow_addons/optimizers/BUILD @@ -9,6 +9,7 @@ py_library( "lazy_adam.py", "moving_average.py", "weight_decay_optimizers.py", + "rectified_adam.py", ], srcs_version = "PY2AND3", deps = [ @@ -54,3 +55,16 @@ py_test( ":optimizers", ], ) + +py_test( + name = "rectified_adam_test", + size = "small", + srcs = [ + "rectified_adam_test.py", + ], + main = "rectified_adam_test.py", + srcs_version = "PY2AND3", + deps = [ + ":optimizers", + ], +) diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md index f45cb5fb1c..f8db199dde 100644 --- a/tensorflow_addons/optimizers/README.md +++ b/tensorflow_addons/optimizers/README.md @@ -14,6 +14,7 @@ | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 | | moving_average | MovingAverage | | | weight_decay_optimizers | SGDW, AdamW, extend_with_decoupled_weight_decay | https://arxiv.org/pdf/1711.05101.pdf | +| rectified_adam | RectifiedAdam | https://arxiv.org/pdf/1908.03265v1.pdf | ## Contribution Guidelines diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py index ccb5eda3cc..cff8d9517b 100644 --- a/tensorflow_addons/optimizers/__init__.py +++ b/tensorflow_addons/optimizers/__init__.py @@ -24,3 +24,4 @@ from tensorflow_addons.optimizers.weight_decay_optimizers import SGDW from tensorflow_addons.optimizers.weight_decay_optimizers import ( extend_with_decoupled_weight_decay) +from tensorflow_addons.optimizers.rectified_adam import RectifiedAdam diff --git a/tensorflow_addons/optimizers/rectified_adam.py b/tensorflow_addons/optimizers/rectified_adam.py new file mode 100644 index 0000000000..4beee6c5ad --- /dev/null +++ b/tensorflow_addons/optimizers/rectified_adam.py @@ -0,0 +1,250 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Rectified Adam (RAdam) optimizer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python import ops +from tensorflow.python.ops import math_ops, state_ops, control_flow_ops +from tensorflow_addons.utils import keras_utils + + +@keras_utils.register_keras_custom_object +class RectifiedAdam(tf.keras.optimizers.Optimizer): + """Variant of the Adam optimizer whose adaptive learning rate is rectified so as to + have a consistent variance. + + It implements the Rectified Adam (a.k.a. RAdam) proposed by Liyuan Liu et al. in + [On The Variance Of The Adaptive Learning Rate And Beyond] + (https://arxiv.org/pdf/1908.03265v1.pdf). + + Note: `amsgrad` is not described in the original paper. Use it with caution. + """ + + def __init__(self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7, + weight_decay=0., + amsgrad=False, + total_steps=0, + warmup_proportion=0.1, + min_lr=0., + name='RectifiedAdam', + **kwargs): + r"""Construct a new RAdam optimizer. + Args: + learning_rate: A Tensor or a floating point value. The learning rate. + beta_1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential decay + rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + weight_decay: A floating point value. Weight decay for each param. + amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from + the paper "On the Convergence of Adam and beyond". + total_steps: An integer. Total number of training steps. + Enable warmup by setting a positive value. + warmup_proportion: A floating point value. The proportion of increasing steps. + min_lr: A floating point value. Minimum learning rate after warmup. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". @compatibility(eager) When eager execution is + enabled, `learning_rate`, `beta_1`, `beta_2`, and `epsilon` can each be + a callable that takes no arguments and returns the actual value to use. + This can be useful for changing these values across different + invocations of optimizer functions. @end_compatibility + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, + `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip + gradients by value, `decay` is included for backward compatibility to + allow time inverse decay of learning rate. `lr` is included for backward + compatibility, recommended to use `learning_rate` instead. 
+ """ + super(RectifiedAdam, self).__init__(name, **kwargs) + self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) + self._set_hyper('beta_1', beta_1) + self._set_hyper('beta_2', beta_2) + self._set_hyper('decay', self._initial_decay) + self._set_hyper('weight_decay', weight_decay) + self._set_hyper('total_steps', float(total_steps)) + self._set_hyper('warmup_proportion', warmup_proportion) + self._set_hyper('min_lr', min_lr) + self.epsilon = epsilon or tf.keras.backend.epsilon() + self.amsgrad = amsgrad + self._initial_weight_decay = weight_decay + self._initial_total_steps = total_steps + + def _create_slots(self, var_list): + for var in var_list: + self.add_slot(var, 'm') + for var in var_list: + self.add_slot(var, 'v') + if self.amsgrad: + for var in var_list: + self.add_slot(var, 'vhat') + + def set_weights(self, weights): + params = self.weights + num_vars = int((len(params) - 1) / 2) + if len(weights) == 3 * num_vars + 1: + weights = weights[:len(params)] + super(RectifiedAdam, self).set_weights(weights) + + def _resource_apply_dense(self, grad, var): + var_dtype = var.dtype.base_dtype + lr_t = self._decayed_lr(var_dtype) + m = self.get_slot(var, 'm') + v = self.get_slot(var, 'v') + beta_1_t = self._get_hyper('beta_1', var_dtype) + beta_2_t = self._get_hyper('beta_2', var_dtype) + epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype) + local_step = math_ops.cast(self.iterations + 1, var_dtype) + beta_1_power = math_ops.pow(beta_1_t, local_step) + beta_2_power = math_ops.pow(beta_2_t, local_step) + + if self._initial_total_steps > 0: + total_steps = self._get_hyper('total_steps', var_dtype) + warmup_steps = total_steps * self._get_hyper('warmup_proportion', var_dtype) + min_lr = self._get_hyper('min_lr', var_dtype) + decay_steps = math_ops.maximum(total_steps - warmup_steps, 1) + decay_rate = (min_lr - lr_t) / decay_steps + lr_t = tf.where( + local_step <= warmup_steps, + lr_t * (local_step / warmup_steps), + lr_t + decay_rate * math_ops.minimum(local_step - warmup_steps, decay_steps), + ) + + sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0 + sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power) + + m_t = state_ops.assign(m, + beta_1_t * m + (1.0 - beta_1_t) * grad, + use_locking=self._use_locking) + m_corr_t = m_t / (1.0 - beta_1_power) + + v_t = state_ops.assign(v, + beta_2_t * v + (1.0 - beta_2_t) * math_ops.square(grad), + use_locking=self._use_locking) + if self.amsgrad: + vhat = self.get_slot(var, 'vhat') + vhat_t = state_ops.assign(vhat, + math_ops.maximum(vhat, v_t), + use_locking=self._use_locking) + v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta_2_power)) + else: + v_corr_t = math_ops.sqrt(v_t / (1.0 - beta_2_power)) + + r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * + (sma_t - 2.0) / (sma_inf - 2.0) * + sma_inf / sma_t) + + var_t = tf.where(sma_t >= 5.0, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t) + + if self._initial_weight_decay > 0.0: + var_t += self._get_hyper('weight_decay', var_dtype) * var + + var_update = state_ops.assign_sub(var, + lr_t * var_t, + use_locking=self._use_locking) + + updates = [var_update, m_t, v_t] + if self.amsgrad: + updates.append(vhat_t) + return control_flow_ops.group(*updates) + + def _resource_apply_sparse(self, grad, var, indices): + var_dtype = var.dtype.base_dtype + lr_t = self._decayed_lr(var_dtype) + beta_1_t = self._get_hyper('beta_1', var_dtype) + beta_2_t = self._get_hyper('beta_2', var_dtype) + epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype) + local_step = 
math_ops.cast(self.iterations + 1, var_dtype) + beta_1_power = math_ops.pow(beta_1_t, local_step) + beta_2_power = math_ops.pow(beta_2_t, local_step) + + if self._initial_total_steps > 0: + total_steps = self._get_hyper('total_steps', var_dtype) + warmup_steps = total_steps * self._get_hyper('warmup_proportion', var_dtype) + min_lr = self._get_hyper('min_lr', var_dtype) + decay_steps = math_ops.maximum(total_steps - warmup_steps, 1) + decay_rate = (min_lr - lr_t) / decay_steps + lr_t = tf.where( + local_step <= warmup_steps, + lr_t * (local_step / warmup_steps), + lr_t + decay_rate * math_ops.minimum(local_step - warmup_steps, decay_steps), + ) + + sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0 + sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power) + + m = self.get_slot(var, 'm') + m_scaled_g_values = grad * (1 - beta_1_t) + m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = self._resource_scatter_add(m, indices, m_scaled_g_values) + m_corr_t = m_t / (1.0 - beta_1_power) + + v = self.get_slot(var, 'v') + v_scaled_g_values = (grad * grad) * (1 - beta_2_t) + v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking) + with ops.control_dependencies([v_t]): + v_t = self._resource_scatter_add(v, indices, v_scaled_g_values) + + if self.amsgrad: + vhat = self.get_slot(var, 'vhat') + vhat_t = state_ops.assign(vhat, + math_ops.maximum(vhat, v_t), + use_locking=self._use_locking) + v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta_2_power)) + else: + v_corr_t = math_ops.sqrt(v_t / (1.0 - beta_2_power)) + + r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * + (sma_t - 2.0) / (sma_inf - 2.0) * + sma_inf / sma_t) + + var_t = tf.where(sma_t >= 5.0, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t) + + if self._initial_weight_decay > 0.0: + var_t += self._get_hyper('weight_decay', var_dtype) * var + + var_update = state_ops.assign_sub(var, + lr_t * var_t, + use_locking=self._use_locking) + + updates = [var_update, m_t, v_t] + if self.amsgrad: + updates.append(vhat_t) + return control_flow_ops.group(*updates) + + def get_config(self): + config = super(RectifiedAdam, self).get_config() + config.update({ + 'learning_rate': self._serialize_hyperparameter('learning_rate'), + 'beta_1': self._serialize_hyperparameter('beta_1'), + 'beta_2': self._serialize_hyperparameter('beta_2'), + 'decay': self._serialize_hyperparameter('decay'), + 'weight_decay': self._serialize_hyperparameter('weight_decay'), + 'epsilon': self.epsilon, + 'amsgrad': self.amsgrad, + 'total_steps': self._serialize_hyperparameter('total_steps'), + 'warmup_proportion': self._serialize_hyperparameter('warmup_proportion'), + 'min_lr': self._serialize_hyperparameter('min_lr'), + }) + return config diff --git a/tensorflow_addons/optimizers/rectified_adam_test.py b/tensorflow_addons/optimizers/rectified_adam_test.py new file mode 100644 index 0000000000..9976c0a0c1 --- /dev/null +++ b/tensorflow_addons/optimizers/rectified_adam_test.py @@ -0,0 +1,89 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Rectified Adam optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow_addons.utils import test_utils +from tensorflow_addons.optimizers import RectifiedAdam + + +@test_utils.run_all_in_graph_and_eager_modes +class RectifiedAdamTest(tf.test.TestCase): + + def test_dense_sample(self): + var_0 = tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32) + var_1 = tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32) + + grad_0 = tf.constant([0.1, 0.2], dtype=tf.dtypes.float32) + grad_1 = tf.constant([0.03, 0.04], dtype=tf.dtypes.float32) + + grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) + + opt = RectifiedAdam(lr=1e-3) + + if tf.executing_eagerly(): + for _ in range(1000): + opt.apply_gradients(grads_and_vars) + else: + update = opt.apply_gradients(grads_and_vars) + self.evaluate(tf.compat.v1.global_variables_initializer()) + for _ in range(1000): + self.evaluate(update) + + # Expected values are obtained from the official implementation + self.assertAllClose(var_0.read_value(), [0.5554, 1.5549], atol=1e-4) + self.assertAllClose(var_1.read_value(), [2.5557, 3.5557], atol=1e-4) + + def test_sparse_sample(self): + var_0 = tf.Variable([1.0, 2.0]) + var_1 = tf.Variable([3.0, 4.0]) + + grad_0 = tf.IndexedSlices( + tf.constant([0.1]), + tf.constant([0]), + tf.constant([2]) + ) + grad_1 = tf.IndexedSlices( + tf.constant([0.04]), + tf.constant([1]), + tf.constant([2]) + ) + + grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) + + opt = RectifiedAdam(lr=1e-3) + + if tf.executing_eagerly(): + for _ in range(5000): + opt.apply_gradients(grads_and_vars) + else: + update = opt.apply_gradients(grads_and_vars) + self.evaluate(tf.compat.v1.global_variables_initializer()) + for _ in range(5000): + self.evaluate(update) + + # Expected values are obtained from the official implementation + # Dense results should be: [-2.9875, -1.9880], [-0.9871, 0.0128] + self.assertAllClose(var_0.read_value(), [-2.9875, 2.0], atol=1e-4) + self.assertAllClose(var_1.read_value(), [3.0, 0.0128], atol=1e-4) + + +if __name__ == '__main__': + tf.test.main() From 22bd67395d4631e33d6391591ed0838b698aa5f4 Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Fri, 13 Sep 2019 16:01:12 +0800 Subject: [PATCH 02/13] Add tests for amsgrad and weight decay --- .../optimizers/rectified_adam.py | 11 ++- .../optimizers/rectified_adam_test.py | 81 +++++++++++++++---- 2 files changed, 74 insertions(+), 18 deletions(-) diff --git a/tensorflow_addons/optimizers/rectified_adam.py b/tensorflow_addons/optimizers/rectified_adam.py index 4beee6c5ad..66b2161f7f 100644 --- a/tensorflow_addons/optimizers/rectified_adam.py +++ b/tensorflow_addons/optimizers/rectified_adam.py @@ -19,7 +19,7 @@ import tensorflow as tf from tensorflow.python import ops -from tensorflow.python.ops import math_ops, state_ops, control_flow_ops +from tensorflow.python.ops import math_ops, state_ops, array_ops, control_flow_ops from tensorflow_addons.utils import keras_utils @@ -224,9 +224,12 @@ def _resource_apply_sparse(self, grad, var, indices): if self._initial_weight_decay > 0.0: var_t += self._get_hyper('weight_decay', var_dtype) * var - var_update = state_ops.assign_sub(var, - lr_t * var_t, - 
use_locking=self._use_locking) + var_t *= lr_t + with ops.control_dependencies([var_t]): + var_update = state_ops.scatter_sub(var, + indices, + array_ops.gather(var_t, indices), + use_locking=self._use_locking) updates = [var_update, m_t, v_t] if self.amsgrad: diff --git a/tensorflow_addons/optimizers/rectified_adam_test.py b/tensorflow_addons/optimizers/rectified_adam_test.py index 9976c0a0c1..51a2ccf845 100644 --- a/tensorflow_addons/optimizers/rectified_adam_test.py +++ b/tensorflow_addons/optimizers/rectified_adam_test.py @@ -27,7 +27,7 @@ @test_utils.run_all_in_graph_and_eager_modes class RectifiedAdamTest(tf.test.TestCase): - def test_dense_sample(self): + def run_dense_sample(self, iterations, expected, **opt_kwargs): var_0 = tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32) var_1 = tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32) @@ -36,22 +36,21 @@ def test_dense_sample(self): grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) - opt = RectifiedAdam(lr=1e-3) + opt = RectifiedAdam(**opt_kwargs) if tf.executing_eagerly(): - for _ in range(1000): + for _ in range(iterations): opt.apply_gradients(grads_and_vars) else: update = opt.apply_gradients(grads_and_vars) self.evaluate(tf.compat.v1.global_variables_initializer()) - for _ in range(1000): + for _ in range(iterations): self.evaluate(update) - # Expected values are obtained from the official implementation - self.assertAllClose(var_0.read_value(), [0.5554, 1.5549], atol=1e-4) - self.assertAllClose(var_1.read_value(), [2.5557, 3.5557], atol=1e-4) + self.assertAllClose(var_0.read_value(), expected[0], atol=1e-4) + self.assertAllClose(var_1.read_value(), expected[1], atol=1e-4) - def test_sparse_sample(self): + def run_sparse_sample(self, iterations, expected, **opt_kwargs): var_0 = tf.Variable([1.0, 2.0]) var_1 = tf.Variable([3.0, 4.0]) @@ -68,21 +67,75 @@ def test_sparse_sample(self): grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) - opt = RectifiedAdam(lr=1e-3) + opt = RectifiedAdam(**opt_kwargs) if tf.executing_eagerly(): - for _ in range(5000): + for _ in range(iterations): opt.apply_gradients(grads_and_vars) else: update = opt.apply_gradients(grads_and_vars) self.evaluate(tf.compat.v1.global_variables_initializer()) - for _ in range(5000): + for _ in range(iterations): self.evaluate(update) + self.assertAllClose(var_0.read_value(), expected[0], atol=1e-4) + self.assertAllClose(var_1.read_value(), expected[1], atol=1e-4) + + def test_dense_sample(self): + # Expected values are obtained from the official implementation + self.run_dense_sample( + iterations=1000, + expected=[[0.5554, 1.5549], [2.5557, 3.5557]], + lr=1e-3, + ) + + def test_sparse_sample(self): + # Expected values are obtained from the official implementation + # Dense results should be: [-0.1929, 0.8066], [1.8075, 2.8074] + self.run_sparse_sample( + iterations=2000, + expected=[[-0.1929, 2.0], [3.0, 2.8074]], + lr=1e-3, + ) + + def test_dense_sample_with_amsgrad(self): + # Expected values are obtained from the official implementation + # `amsgrad` has no effect because the gradient is fixed + self.run_dense_sample( + iterations=1000, + expected=[[0.5554, 1.5549], [2.5557, 3.5557]], + lr=1e-3, + amsgrad=True, + ) + + def test_sparse_sample_with_amsgrad(self): # Expected values are obtained from the official implementation - # Dense results should be: [-2.9875, -1.9880], [-0.9871, 0.0128] - self.assertAllClose(var_0.read_value(), [-2.9875, 2.0], atol=1e-4) - self.assertAllClose(var_1.read_value(), [3.0, 0.0128], atol=1e-4) + # `amsgrad` has no 
effect because the gradient is fixed + self.run_sparse_sample( + iterations=2000, + expected=[[-0.1929, 2.0], [3.0, 2.8074]], + lr=1e-3, + amsgrad=True, + ) + + def test_dense_sample_with_weight_decay(self): + # Expected values are obtained from the official implementation + self.run_dense_sample( + iterations=1000, + expected=[[0.5472, 1.5368], [2.5276, 3.5176]], + lr=1e-3, + weight_decay=0.01, + ) + + def test_sparse_sample_with_weight_decay(self): + # Expected values are obtained from the official implementation + # Dense results should be: [-0.2029, 0.7768], [1.7578, 2.7380] + self.run_sparse_sample( + iterations=2000, + expected=[[-0.2029, 2.0], [3.0, 2.7380]], + lr=1e-3, + weight_decay=0.01, + ) if __name__ == '__main__': From 8a8a217d48b880c4e5078bc53e62736c756b3b71 Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Fri, 13 Sep 2019 17:04:06 +0800 Subject: [PATCH 03/13] Add tests of warmup for RAdam optimizer --- tensorflow_addons/optimizers/README.md | 1 + .../optimizers/rectified_adam.py | 108 ++++++++++++------ .../optimizers/rectified_adam_test.py | 27 ++++- 3 files changed, 99 insertions(+), 37 deletions(-) diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md index f8db199dde..d33ca541b7 100644 --- a/tensorflow_addons/optimizers/README.md +++ b/tensorflow_addons/optimizers/README.md @@ -6,6 +6,7 @@ | lazy_adam | Saishruthi Swaminathan | saishruthi.tn@gmail.com | | moving_average | Dheeraj R. Reddy | dheeraj98reddy@gmail.com | | weight_decay_optimizers | Phil Jund | ijund.phil@googlemail.com | +| rectified_adam | Zhao Hanguang | cyberzhg@gmail.com | ## Components diff --git a/tensorflow_addons/optimizers/rectified_adam.py b/tensorflow_addons/optimizers/rectified_adam.py index 66b2161f7f..d9edbe638e 100644 --- a/tensorflow_addons/optimizers/rectified_adam.py +++ b/tensorflow_addons/optimizers/rectified_adam.py @@ -19,20 +19,44 @@ import tensorflow as tf from tensorflow.python import ops -from tensorflow.python.ops import math_ops, state_ops, array_ops, control_flow_ops +from tensorflow.python.ops import (math_ops, state_ops, + array_ops, control_flow_ops) from tensorflow_addons.utils import keras_utils @keras_utils.register_keras_custom_object class RectifiedAdam(tf.keras.optimizers.Optimizer): - """Variant of the Adam optimizer whose adaptive learning rate is rectified so as to - have a consistent variance. + """Variant of the Adam optimizer whose adaptive learning rate is rectified + so as to have a consistent variance. - It implements the Rectified Adam (a.k.a. RAdam) proposed by Liyuan Liu et al. in - [On The Variance Of The Adaptive Learning Rate And Beyond] - (https://arxiv.org/pdf/1908.03265v1.pdf). + It implements the Rectified Adam (a.k.a. RAdam) proposed by + Liyuan Liu et al. in [On The Variance Of The Adaptive Learning Rate + And Beyond](https://arxiv.org/pdf/1908.03265v1.pdf). + + Example of usage: + + ```python + opt = tfa.optimizers.RectifiedAdam(lr=1e-3) + ``` Note: `amsgrad` is not described in the original paper. Use it with caution. + + RAdam is not a placement of the heuristic warmup, the settings should be + kept if warmup has already been employed and tuned in the baseline method. 
+ You can enable warmup by setting `total_steps` and `warmup_proportion`: + + ```python + opt = tfa.optimizers.RectifiedAdam( + lr=1e-3, + total_steps=10000, + warmup_proportion=0.1, + min_lr=1e-5, + ) + ``` + + In the above example, the learning rate will increase linearly + from 0 to `lr` in 1000 steps, then decrease linearly from `lr` to `min_lr` + in 9000 steps. """ def __init__(self, @@ -48,33 +72,31 @@ def __init__(self, name='RectifiedAdam', **kwargs): r"""Construct a new RAdam optimizer. + Args: - learning_rate: A Tensor or a floating point value. The learning rate. - beta_1: A float value or a constant float tensor. The exponential decay - rate for the 1st moment estimates. - beta_2: A float value or a constant float tensor. The exponential decay - rate for the 2nd moment estimates. - epsilon: A small constant for numerical stability. This epsilon is - "epsilon hat" in the Kingma and Ba paper (in the formula just before - Section 2.1), not the epsilon in Algorithm 1 of the paper. + learning_rate: A Tensor or a floating point value. + The learning rate. + beta_1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. weight_decay: A floating point value. Weight decay for each param. - amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from - the paper "On the Convergence of Adam and beyond". + amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm + from the paper "On the Convergence of Adam and beyond". total_steps: An integer. Total number of training steps. Enable warmup by setting a positive value. - warmup_proportion: A floating point value. The proportion of increasing steps. + warmup_proportion: A floating point value. + The proportion of increasing steps. min_lr: A floating point value. Minimum learning rate after warmup. - name: Optional name for the operations created when applying gradients. - Defaults to "Adam". @compatibility(eager) When eager execution is - enabled, `learning_rate`, `beta_1`, `beta_2`, and `epsilon` can each be - a callable that takes no arguments and returns the actual value to use. - This can be useful for changing these values across different - invocations of optimizer functions. @end_compatibility - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, - `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip - gradients by value, `decay` is included for backward compatibility to - allow time inverse decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. + name: Optional name for the operations created when applying + gradients. Defaults to "RectifiedAdam". + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, + `lr`, `decay`}. `clipnorm` is clip gradients by norm; + `clipvalue` is clip gradients by value, `decay` is included for + backward compatibility to allow time inverse decay of learning + rate. `lr` is included for backward compatibility, recommended + to use `learning_rate` instead. 
""" super(RectifiedAdam, self).__init__(name, **kwargs) self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) @@ -120,14 +142,17 @@ def _resource_apply_dense(self, grad, var): if self._initial_total_steps > 0: total_steps = self._get_hyper('total_steps', var_dtype) - warmup_steps = total_steps * self._get_hyper('warmup_proportion', var_dtype) + warmup_steps = total_steps *\ + self._get_hyper('warmup_proportion', var_dtype) min_lr = self._get_hyper('min_lr', var_dtype) decay_steps = math_ops.maximum(total_steps - warmup_steps, 1) decay_rate = (min_lr - lr_t) / decay_steps lr_t = tf.where( local_step <= warmup_steps, lr_t * (local_step / warmup_steps), - lr_t + decay_rate * math_ops.minimum(local_step - warmup_steps, decay_steps), + lr_t + decay_rate * math_ops.minimum( + local_step - warmup_steps, + decay_steps), ) sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0 @@ -139,7 +164,8 @@ def _resource_apply_dense(self, grad, var): m_corr_t = m_t / (1.0 - beta_1_power) v_t = state_ops.assign(v, - beta_2_t * v + (1.0 - beta_2_t) * math_ops.square(grad), + beta_2_t * v + + (1.0 - beta_2_t) * math_ops.square(grad), use_locking=self._use_locking) if self.amsgrad: vhat = self.get_slot(var, 'vhat') @@ -154,7 +180,10 @@ def _resource_apply_dense(self, grad, var): (sma_t - 2.0) / (sma_inf - 2.0) * sma_inf / sma_t) - var_t = tf.where(sma_t >= 5.0, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t) + var_t = tf.where( + sma_t >= 5.0, + r_t * m_corr_t / (v_corr_t + epsilon_t), + m_corr_t) if self._initial_weight_decay > 0.0: var_t += self._get_hyper('weight_decay', var_dtype) * var @@ -180,14 +209,17 @@ def _resource_apply_sparse(self, grad, var, indices): if self._initial_total_steps > 0: total_steps = self._get_hyper('total_steps', var_dtype) - warmup_steps = total_steps * self._get_hyper('warmup_proportion', var_dtype) + warmup_steps = total_steps *\ + self._get_hyper('warmup_proportion', var_dtype) min_lr = self._get_hyper('min_lr', var_dtype) decay_steps = math_ops.maximum(total_steps - warmup_steps, 1) decay_rate = (min_lr - lr_t) / decay_steps lr_t = tf.where( local_step <= warmup_steps, lr_t * (local_step / warmup_steps), - lr_t + decay_rate * math_ops.minimum(local_step - warmup_steps, decay_steps), + lr_t + decay_rate * math_ops.minimum( + local_step - warmup_steps, + decay_steps), ) sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0 @@ -219,7 +251,10 @@ def _resource_apply_sparse(self, grad, var, indices): (sma_t - 2.0) / (sma_inf - 2.0) * sma_inf / sma_t) - var_t = tf.where(sma_t >= 5.0, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t) + var_t = tf.where( + sma_t >= 5.0, + r_t * m_corr_t / (v_corr_t + epsilon_t), + m_corr_t) if self._initial_weight_decay > 0.0: var_t += self._get_hyper('weight_decay', var_dtype) * var @@ -247,7 +282,8 @@ def get_config(self): 'epsilon': self.epsilon, 'amsgrad': self.amsgrad, 'total_steps': self._serialize_hyperparameter('total_steps'), - 'warmup_proportion': self._serialize_hyperparameter('warmup_proportion'), + 'warmup_proportion': + self._serialize_hyperparameter('warmup_proportion'), 'min_lr': self._serialize_hyperparameter('min_lr'), }) return config diff --git a/tensorflow_addons/optimizers/rectified_adam_test.py b/tensorflow_addons/optimizers/rectified_adam_test.py index 51a2ccf845..87f301bfbe 100644 --- a/tensorflow_addons/optimizers/rectified_adam_test.py +++ b/tensorflow_addons/optimizers/rectified_adam_test.py @@ -46,7 +46,6 @@ def run_dense_sample(self, iterations, expected, **opt_kwargs): self.evaluate(tf.compat.v1.global_variables_initializer()) for _ in 
range(iterations): self.evaluate(update) - self.assertAllClose(var_0.read_value(), expected[0], atol=1e-4) self.assertAllClose(var_1.read_value(), expected[1], atol=1e-4) @@ -137,6 +136,32 @@ def test_sparse_sample_with_weight_decay(self): weight_decay=0.01, ) + def test_dense_sample_with_warmup(self): + self.run_dense_sample( + iterations=1000, + expected=[[0.8041, 1.8041], [2.8041, 3.8041]], + lr=1e-3, + total_steps=1000, + warmup_proportion=0.1, + min_lr=1e-5, + ) + + def test_sparse_sample_with_warmup(self): + self.run_sparse_sample( + iterations=2000, + expected=[[0.4653, 2.0], [3.0, 3.4653]], + lr=1e-3, + total_steps=2000, + warmup_proportion=0.1, + min_lr=1e-5, + ) + + def test_get_config(self): + opt = RectifiedAdam(lr=1e-4) + config = opt.get_config() + self.assertEqual(config['learning_rate'], 1e-4) + self.assertEqual(config['total_steps'], 0) + if __name__ == '__main__': tf.test.main() From be08a4a88a6e0c7b00a9da591ab47fc1b89e4cf7 Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Fri, 13 Sep 2019 20:48:27 +0800 Subject: [PATCH 04/13] Add lookahead for RAdam --- .../optimizers/rectified_adam.py | 116 ++++++++++++++++-- .../optimizers/rectified_adam_test.py | 25 ++++ 2 files changed, 129 insertions(+), 12 deletions(-) diff --git a/tensorflow_addons/optimizers/rectified_adam.py b/tensorflow_addons/optimizers/rectified_adam.py index d9edbe638e..29fefd8bb9 100644 --- a/tensorflow_addons/optimizers/rectified_adam.py +++ b/tensorflow_addons/optimizers/rectified_adam.py @@ -57,6 +57,23 @@ class RectifiedAdam(tf.keras.optimizers.Optimizer): In the above example, the learning rate will increase linearly from 0 to `lr` in 1000 steps, then decrease linearly from `lr` to `min_lr` in 9000 steps. + + Lookahead, proposed by Michael R. Zhang et.al in the paper + [Lookahead Optimizer: k steps forward, 1 step back] + (https://arxiv.org/abs/1907.08610v1), can be integrated with RAdam, + which is announced by Less Wright and the new combined optimizer can also be + called "Ranger". By setting `lookahead_step` and `lookahead_ratio`, the + lookahead mechanism will be used. For example: + + ```python + opt = tfa.optimizers.RectifiedAdam( + lr=1e-3, + lookahead_step=5, + lookahead_ratio=0.5, + ) + ``` + + Note: more memory will be used since all the variables will be duplicated. """ def __init__(self, @@ -66,9 +83,12 @@ def __init__(self, epsilon=1e-7, weight_decay=0., amsgrad=False, + sma_threshold=5.0, total_steps=0, warmup_proportion=0.1, min_lr=0., + lookahead_step=0, + lookahead_ratio=0.5, name='RectifiedAdam', **kwargs): r"""Construct a new RAdam optimizer. @@ -84,11 +104,15 @@ def __init__(self, weight_decay: A floating point value. Weight decay for each param. amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from the paper "On the Convergence of Adam and beyond". + sma_threshold. A float value. The threshold for simple mean average. total_steps: An integer. Total number of training steps. Enable warmup by setting a positive value. warmup_proportion: A floating point value. The proportion of increasing steps. min_lr: A floating point value. Minimum learning rate after warmup. + lookahead_step: An integer. Synchronization period of lookhead. + Enable lookahead mechanism by setting it with a positive value. + lookahead_ratio: A float value. Slow weights step size. name: Optional name for the operations created when applying gradients. Defaults to "RectifiedAdam". **kwargs: keyword arguments. 
Allowed to be {`clipnorm`, `clipvalue`, @@ -104,13 +128,17 @@ def __init__(self, self._set_hyper('beta_2', beta_2) self._set_hyper('decay', self._initial_decay) self._set_hyper('weight_decay', weight_decay) + self._set_hyper('sma_threshold', sma_threshold) self._set_hyper('total_steps', float(total_steps)) self._set_hyper('warmup_proportion', warmup_proportion) self._set_hyper('min_lr', min_lr) + self._set_hyper('lookahead_step', lookahead_step) + self._set_hyper('lookahead_ratio', lookahead_ratio) self.epsilon = epsilon or tf.keras.backend.epsilon() self.amsgrad = amsgrad self._initial_weight_decay = weight_decay self._initial_total_steps = total_steps + self._initial_lookahead_step = lookahead_step def _create_slots(self, var_list): for var in var_list: @@ -120,6 +148,9 @@ def _create_slots(self, var_list): if self.amsgrad: for var in var_list: self.add_slot(var, 'vhat') + if self._initial_lookahead_step: + for var in var_list: + self.add_slot(var, 'slow') def set_weights(self, weights): params = self.weights @@ -180,19 +211,45 @@ def _resource_apply_dense(self, grad, var): (sma_t - 2.0) / (sma_inf - 2.0) * sma_inf / sma_t) + sma_threshold = self._get_hyper('sma_threshold', var_dtype) var_t = tf.where( - sma_t >= 5.0, + sma_t >= sma_threshold, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t) if self._initial_weight_decay > 0.0: var_t += self._get_hyper('weight_decay', var_dtype) * var - var_update = state_ops.assign_sub(var, - lr_t * var_t, - use_locking=self._use_locking) + if self._initial_lookahead_step > 0: + slow_var = self.get_slot(var, 'slow') + lookahead_step = self._get_hyper('lookahead_step', local_step.dtype) + lookahead_ratio = self._get_hyper('lookahead_ratio', var_dtype) + sync_cond = math_ops.equal(local_step % lookahead_step, 0) + new_var = var - lr_t * var_t + slow_init = tf.where( + tf.equal(local_step, tf.constant(1, dtype=local_step.dtype)), + var, + slow_var, + ) + slow_t = slow_init + (new_var - slow_var) * lookahead_ratio + var_updates = [ + state_ops.assign( + slow_var, + tf.where(sync_cond, slow_t, slow_init), + use_locking=self._use_locking, + ), + state_ops.assign( + var, + tf.where(sync_cond, slow_t, new_var), + use_locking=self._use_locking, + ), + ] + else: + var_updates = [state_ops.assign_sub(var, + lr_t * var_t, + use_locking=self._use_locking)] - updates = [var_update, m_t, v_t] + updates = var_updates + [m_t, v_t] if self.amsgrad: updates.append(vhat_t) return control_flow_ops.group(*updates) @@ -251,8 +308,9 @@ def _resource_apply_sparse(self, grad, var, indices): (sma_t - 2.0) / (sma_inf - 2.0) * sma_inf / sma_t) + sma_threshold = self._get_hyper('sma_threshold', var_dtype) var_t = tf.where( - sma_t >= 5.0, + sma_t >= sma_threshold, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t) @@ -261,12 +319,41 @@ def _resource_apply_sparse(self, grad, var, indices): var_t *= lr_t with ops.control_dependencies([var_t]): - var_update = state_ops.scatter_sub(var, - indices, - array_ops.gather(var_t, indices), - use_locking=self._use_locking) - - updates = [var_update, m_t, v_t] + if self._initial_lookahead_step > 0: + slow_var = self.get_slot(var, 'slow') + lookahead_step = self._get_hyper('lookahead_step', local_step.dtype) + lookahead_ratio = self._get_hyper('lookahead_ratio', var_dtype) + sync_cond = math_ops.equal(local_step % lookahead_step, 0) + new_var = var - var_t + slow_init = tf.where( + tf.equal(local_step, tf.constant(1, dtype=local_step.dtype)), + var, + slow_var, + ) + slow_t = slow_init + (new_var - slow_var) * lookahead_ratio + var_updates = [ 
+ state_ops.assign( + slow_var, + tf.where(sync_cond, slow_t, slow_init), + use_locking=self._use_locking, + ), + state_ops.scatter_update( + var, + indices, + array_ops.gather( + tf.where(sync_cond, slow_t, new_var), indices, + ), + use_locking=self._use_locking, + ), + ] + else: + var_updates = [state_ops.scatter_sub( + var, + indices, + array_ops.gather(var_t, indices), + use_locking=self._use_locking)] + + updates = var_updates + [m_t, v_t] if self.amsgrad: updates.append(vhat_t) return control_flow_ops.group(*updates) @@ -279,11 +366,16 @@ def get_config(self): 'beta_2': self._serialize_hyperparameter('beta_2'), 'decay': self._serialize_hyperparameter('decay'), 'weight_decay': self._serialize_hyperparameter('weight_decay'), + 'sma_threshold': self._serialize_hyperparameter('sma_threshold'), 'epsilon': self.epsilon, 'amsgrad': self.amsgrad, 'total_steps': self._serialize_hyperparameter('total_steps'), 'warmup_proportion': self._serialize_hyperparameter('warmup_proportion'), + 'lookahead_step': + self._serialize_hyperparameter('lookahead_step'), + 'lookahead_ratio': + self._serialize_hyperparameter('lookahead_ratio'), 'min_lr': self._serialize_hyperparameter('min_lr'), }) return config diff --git a/tensorflow_addons/optimizers/rectified_adam_test.py b/tensorflow_addons/optimizers/rectified_adam_test.py index 87f301bfbe..bc4f61c976 100644 --- a/tensorflow_addons/optimizers/rectified_adam_test.py +++ b/tensorflow_addons/optimizers/rectified_adam_test.py @@ -156,6 +156,31 @@ def test_sparse_sample_with_warmup(self): min_lr=1e-5, ) + def test_dense_sample_with_lookahead(self): + # Expected values are obtained from the original implementation + # of Ranger + self.run_dense_sample( + iterations=1000, + expected=[[0.7985, 1.7983], [2.7987, 3.7986]], + lr=1e-3, + beta_1=0.95, + lookahead_step=6, + lookahead_ratio=0.45, + ) + + def test_sparse_sample_with_lookahead(self): + # Expected values are obtained from the original implementation + # of Ranger. 
+ # Dense results should be: [0.6417, 1.6415], [2.6419, 3.6418] + self.run_sparse_sample( + iterations=1500, + expected=[[0.6417, 2.0], [3.0, 3.6418]], + lr=1e-3, + beta_1=0.95, + lookahead_step=6, + lookahead_ratio=0.45, + ) + def test_get_config(self): opt = RectifiedAdam(lr=1e-4) config = opt.get_config() From 289f0d846af2cec9b8e40885c2f817ed25633022 Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Sat, 14 Sep 2019 15:08:00 +0800 Subject: [PATCH 05/13] Decouple lookahead optimizer --- tensorflow_addons/optimizers/BUILD | 14 ++ tensorflow_addons/optimizers/README.md | 2 + tensorflow_addons/optimizers/__init__.py | 1 + tensorflow_addons/optimizers/lookahead.py | 177 ++++++++++++++++++ .../optimizers/lookahead_test.py | 127 +++++++++++++ .../optimizers/rectified_adam.py | 103 ++-------- .../optimizers/rectified_adam_test.py | 79 ++++---- 7 files changed, 378 insertions(+), 125 deletions(-) create mode 100644 tensorflow_addons/optimizers/lookahead.py create mode 100644 tensorflow_addons/optimizers/lookahead_test.py diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD index fab80cc6a7..44f0b5be93 100644 --- a/tensorflow_addons/optimizers/BUILD +++ b/tensorflow_addons/optimizers/BUILD @@ -10,6 +10,7 @@ py_library( "moving_average.py", "weight_decay_optimizers.py", "rectified_adam.py", + "lookahead.py", ], srcs_version = "PY2AND3", deps = [ @@ -68,3 +69,16 @@ py_test( ":optimizers", ], ) + +py_test( + name = "lookahead_test", + size = "small", + srcs = [ + "lookahead_test.py", + ], + main = "lookahead_test.py", + srcs_version = "PY2AND3", + deps = [ + ":optimizers", + ], +) diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md index d33ca541b7..7c194d8784 100644 --- a/tensorflow_addons/optimizers/README.md +++ b/tensorflow_addons/optimizers/README.md @@ -7,6 +7,7 @@ | moving_average | Dheeraj R. Reddy | dheeraj98reddy@gmail.com | | weight_decay_optimizers | Phil Jund | ijund.phil@googlemail.com | | rectified_adam | Zhao Hanguang | cyberzhg@gmail.com | +| lookahead | Zhao Hanguang | cyberzhg@gmail.com | ## Components @@ -16,6 +17,7 @@ | moving_average | MovingAverage | | | weight_decay_optimizers | SGDW, AdamW, extend_with_decoupled_weight_decay | https://arxiv.org/pdf/1711.05101.pdf | | rectified_adam | RectifiedAdam | https://arxiv.org/pdf/1908.03265v1.pdf | +| lookahead | Lookahead | https://arxiv.org/abs/1907.08610v1 | ## Contribution Guidelines diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py index cff8d9517b..e1a45cd430 100644 --- a/tensorflow_addons/optimizers/__init__.py +++ b/tensorflow_addons/optimizers/__init__.py @@ -25,3 +25,4 @@ from tensorflow_addons.optimizers.weight_decay_optimizers import ( extend_with_decoupled_weight_decay) from tensorflow_addons.optimizers.rectified_adam import RectifiedAdam +from tensorflow_addons.optimizers.lookahead import Lookahead diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py new file mode 100644 index 0000000000..f15cc5b992 --- /dev/null +++ b/tensorflow_addons/optimizers/lookahead.py @@ -0,0 +1,177 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.python import ops
+from tensorflow.python.ops import math_ops, state_ops, control_flow_ops
+from tensorflow.python.keras import optimizers
+from tensorflow_addons.utils import keras_utils
+
+
+@keras_utils.register_keras_custom_object
+class Lookahead(tf.keras.optimizers.Optimizer):
+    """This class allows extending optimizers with the lookahead mechanism.
+
+    The mechanism is proposed by Michael R. Zhang et al. in the paper
+    [Lookahead Optimizer: k steps forward, 1 step back]
+    (https://arxiv.org/abs/1907.08610v1). The optimizer iteratively updates two
+    sets of weights: the search directions for weights are chosen by the inner
+    optimizer, while the "slow weights" are updated each `k` steps based on the
+    directions of the "fast weights" and the two sets of weights are
+    synchronized. This method improves the learning stability and lowers the
+    variance of its inner optimizer.
+
+    Example of usage:
+
+    ```python
+    opt = tf.keras.optimizers.SGD(learning_rate)
+    opt = tfa.optimizers.Lookahead(opt)
+    ```
+    """
+
+    def __init__(self,
+                 optimizer,
+                 k=6,
+                 alpha=0.5,
+                 name="Lookahead",
+                 **kwargs):
+        r"""Wrap optimizer with the lookahead mechanism.
+
+        Args:
+            optimizer: The optimizer to wrap with the lookahead mechanism;
+                a `tf.keras.optimizers.Optimizer` instance or a string name.
+            k: An integer. The synchronization period of the lookahead,
+                i.e. the slow weights are updated every `k` steps.
+            alpha: A float value. Slow weights step size.
+            name: Optional name for the operations created when applying
+                gradients. Defaults to "Lookahead".
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+                `lr`, `decay`}. `clipnorm` is clip gradients by norm;
+                `clipvalue` is clip gradients by value, `decay` is included for
+                backward compatibility to allow time inverse decay of learning
+                rate. `lr` is included for backward compatibility, recommended
+                to use `learning_rate` instead.
+ """ + super(Lookahead, self).__init__(name, **kwargs) + + if isinstance(optimizer, str): + optimizer = optimizers.get(optimizer) + if not isinstance(optimizer, tf.keras.optimizers.Optimizer): + raise TypeError( + "optimizer is not an object of tf.keras.optimizers.Optimizer") + + self._optimizer = optimizer + self._set_hyper('k', k) + self._set_hyper('alpha', alpha) + self._initialized = False + + def _create_slots(self, var_list): + for var in var_list: + self.add_slot(var, 'slow') + + def apply_gradients(self, grads_and_vars, name=None): + var_list = [v for (_, v) in grads_and_vars] + + with tf.keras.backend.name_scope(self._scope_ctx): + with ops.init_scope(): + _ = self.iterations + self._create_hypers() + self._create_slots(var_list) + self._prepare(var_list) + + if self._initialized: + init_op = tf.no_op() + else: + self._initialized = True + init_op = self._init_op(var_list) + + with tf.control_dependencies([init_op]): + train_op = self._optimizer.apply_gradients(grads_and_vars, name=name) + with tf.control_dependencies([train_op]): + lookahead_updates = [self._look_ahead_op(var) for var in var_list] + lookahead_op = control_flow_ops.group(lookahead_updates) + + return control_flow_ops.group(init_op, train_op, lookahead_op) + + def _init_op(self, var_list): + updates = [] + iterations = self._optimizer.iterations + for var in var_list: + slow_var = self.get_slot(var, 'slow') + updates.append(state_ops.assign( + slow_var, + tf.where( + math_ops.equal(iterations, + tf.constant(0, dtype=iterations.dtype)), + var, + slow_var, + ), + use_locking=self._use_locking)) + return control_flow_ops.group(*updates) + + def _look_ahead_op(self, var): + var_dtype = var.dtype.base_dtype + slow_var = self.get_slot(var, 'slow') + local_step = math_ops.cast(self._optimizer.iterations, var_dtype) + k = self._get_hyper('k', local_step.dtype) + alpha = self._get_hyper('alpha', var_dtype) + step_back = slow_var + alpha * (var - slow_var) + sync_cond = math_ops.equal(local_step % k, 0) + slow_update = state_ops.assign(slow_var, tf.where( + sync_cond, + step_back, + slow_var, + ), use_locking=self._use_locking) + var_update = state_ops.assign(var, tf.where( + sync_cond, + step_back, + var, + ), use_locking=self._use_locking) + return control_flow_ops.group(slow_update, var_update) + + @property + def weights(self): + return self._optimizer.weights + + def _resource_apply_dense(self, grad, var): + return self._optimizer._resource_apply_dense(grad, var) # pylint: disable=protected-access + + def _resource_apply_sparse_duplicate_indices(self, grad, var, indices): + return self._optimizer._resource_apply_sparse_duplicate_indices( # pylint: disable=protected-access + grad, var, indices) + + def _resource_apply_sparse(self, grad, var, indices): + return self._optimizer._resource_apply_sparse(grad, var, indices) # pylint: disable=protected-access + + def get_config(self): + config = { + 'optimizer': optimizers.serialize(self._optimizer), + 'k': self._serialize_hyperparameter('k'), + 'alpha': self._serialize_hyperparameter('alpha'), + } + base_config = super(Lookahead, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + optimizer = optimizers.deserialize( + config.pop('optimizer'), + custom_objects=custom_objects, + ) + return cls(optimizer, **config) diff --git a/tensorflow_addons/optimizers/lookahead_test.py b/tensorflow_addons/optimizers/lookahead_test.py new file mode 100644 index 0000000000..ddabc30232 --- 
/dev/null +++ b/tensorflow_addons/optimizers/lookahead_test.py @@ -0,0 +1,127 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Lookahead optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.python.keras import optimizers + +from tensorflow_addons.utils import test_utils +from tensorflow_addons.optimizers import Lookahead + + +@test_utils.run_all_in_graph_and_eager_modes +class LookaheadTest(tf.test.TestCase): + + def run_dense_sample(self, iterations, optimizer, seed=0x2019): + np.random.seed(seed) + + val_0 = np.random.random((2,)) + val_1 = np.random.random((2,)) + + var_0 = tf.Variable(val_0, dtype=tf.dtypes.float32) + var_1 = tf.Variable(val_1, dtype=tf.dtypes.float32) + + grad_0 = tf.constant( + np.random.standard_normal((2,)), dtype=tf.dtypes.float32) + grad_1 = tf.constant( + np.random.standard_normal((2,)), dtype=tf.dtypes.float32) + + grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) + + if tf.executing_eagerly(): + for _ in range(iterations): + optimizer.apply_gradients(grads_and_vars) + else: + update = optimizer.apply_gradients(grads_and_vars) + self.evaluate(tf.compat.v1.global_variables_initializer()) + for _ in range(iterations): + self.evaluate(update) + + return [val_0, val_1], [self.evaluate(var_0), self.evaluate(var_1)] + + def run_sparse_sample(self, iterations, optimizer, seed=0x2019): + np.random.seed(seed) + + val_0 = np.random.random((2,)) + val_1 = np.random.random((2,)) + + var_0 = tf.Variable(val_0, dtype=tf.dtypes.float32) + var_1 = tf.Variable(val_1, dtype=tf.dtypes.float32) + + grad_0 = tf.IndexedSlices( + tf.constant([np.random.standard_normal()]), + tf.constant([0]), + tf.constant([2]) + ) + grad_1 = tf.IndexedSlices( + tf.constant([np.random.standard_normal()]), + tf.constant([1]), + tf.constant([2]) + ) + + grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) + + if tf.executing_eagerly(): + for _ in range(iterations): + optimizer.apply_gradients(grads_and_vars) + else: + update = optimizer.apply_gradients(grads_and_vars) + self.evaluate(tf.compat.v1.global_variables_initializer()) + for _ in range(iterations): + self.evaluate(update) + + return [val_0, val_1], [self.evaluate(var_0), self.evaluate(var_1)] + + def test_dense_exact_ratio(self): + for k in [5, 10, 100, 500]: + for alpha in [0.1, 0.5, 0.9]: + optimizer = optimizers.get('adam') + vals, quick_vars = self.run_dense_sample(k, optimizer) + optimizer = Lookahead('adam', k=k, alpha=alpha) + _, slow_vars = self.run_dense_sample(k, optimizer) + for val, quick, slow in zip(vals, quick_vars, slow_vars): + expected = val + (quick - val) * alpha + self.assertAllClose(expected, slow) + + def test_sparse_exact_ratio(self): + for k in [5, 10, 100, 500]: + for alpha in [0.1, 0.5, 0.9]: + optimizer = 
optimizers.get('adam') + vals, quick_vars = self.run_sparse_sample(k, optimizer) + optimizer = Lookahead('adam', k=k, alpha=alpha) + _, slow_vars = self.run_sparse_sample(k, optimizer) + for val, quick, slow in zip(vals, quick_vars, slow_vars): + expected = val + (quick - val) * alpha + self.assertAllClose(expected, slow) + + def test_invalid_optimizer_type(self): + with self.assertRaises(TypeError): + Lookahead(optimizers.Adam()) + + def test_get_config(self): + opt = Lookahead('adam', k=10, alpha=0.4) + opt = optimizers.deserialize(optimizers.serialize(opt)) + config = opt.get_config() + self.assertEqual(config['k'], 10) + self.assertEqual(config['alpha'], 0.4) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_addons/optimizers/rectified_adam.py b/tensorflow_addons/optimizers/rectified_adam.py index 29fefd8bb9..309de40052 100644 --- a/tensorflow_addons/optimizers/rectified_adam.py +++ b/tensorflow_addons/optimizers/rectified_adam.py @@ -62,18 +62,13 @@ class RectifiedAdam(tf.keras.optimizers.Optimizer): [Lookahead Optimizer: k steps forward, 1 step back] (https://arxiv.org/abs/1907.08610v1), can be integrated with RAdam, which is announced by Less Wright and the new combined optimizer can also be - called "Ranger". By setting `lookahead_step` and `lookahead_ratio`, the - lookahead mechanism will be used. For example: + called "Ranger". The mechanism can be enabled by using the lookahead + wrapper. For example: ```python - opt = tfa.optimizers.RectifiedAdam( - lr=1e-3, - lookahead_step=5, - lookahead_ratio=0.5, - ) + radam = tfa.optimizers.RectifiedAdam() + ranger = tfa.optimizers.Lookahead(radam, k=6, alpha=0.5) ``` - - Note: more memory will be used since all the variables will be duplicated. """ def __init__(self, @@ -87,8 +82,6 @@ def __init__(self, total_steps=0, warmup_proportion=0.1, min_lr=0., - lookahead_step=0, - lookahead_ratio=0.5, name='RectifiedAdam', **kwargs): r"""Construct a new RAdam optimizer. @@ -110,9 +103,6 @@ def __init__(self, warmup_proportion: A floating point value. The proportion of increasing steps. min_lr: A floating point value. Minimum learning rate after warmup. - lookahead_step: An integer. Synchronization period of lookhead. - Enable lookahead mechanism by setting it with a positive value. - lookahead_ratio: A float value. Slow weights step size. name: Optional name for the operations created when applying gradients. Defaults to "RectifiedAdam". **kwargs: keyword arguments. 
Allowed to be {`clipnorm`, `clipvalue`, @@ -132,13 +122,10 @@ def __init__(self, self._set_hyper('total_steps', float(total_steps)) self._set_hyper('warmup_proportion', warmup_proportion) self._set_hyper('min_lr', min_lr) - self._set_hyper('lookahead_step', lookahead_step) - self._set_hyper('lookahead_ratio', lookahead_ratio) self.epsilon = epsilon or tf.keras.backend.epsilon() self.amsgrad = amsgrad self._initial_weight_decay = weight_decay self._initial_total_steps = total_steps - self._initial_lookahead_step = lookahead_step def _create_slots(self, var_list): for var in var_list: @@ -148,9 +135,6 @@ def _create_slots(self, var_list): if self.amsgrad: for var in var_list: self.add_slot(var, 'vhat') - if self._initial_lookahead_step: - for var in var_list: - self.add_slot(var, 'slow') def set_weights(self, weights): params = self.weights @@ -220,36 +204,11 @@ def _resource_apply_dense(self, grad, var): if self._initial_weight_decay > 0.0: var_t += self._get_hyper('weight_decay', var_dtype) * var - if self._initial_lookahead_step > 0: - slow_var = self.get_slot(var, 'slow') - lookahead_step = self._get_hyper('lookahead_step', local_step.dtype) - lookahead_ratio = self._get_hyper('lookahead_ratio', var_dtype) - sync_cond = math_ops.equal(local_step % lookahead_step, 0) - new_var = var - lr_t * var_t - slow_init = tf.where( - tf.equal(local_step, tf.constant(1, dtype=local_step.dtype)), - var, - slow_var, - ) - slow_t = slow_init + (new_var - slow_var) * lookahead_ratio - var_updates = [ - state_ops.assign( - slow_var, - tf.where(sync_cond, slow_t, slow_init), - use_locking=self._use_locking, - ), - state_ops.assign( - var, - tf.where(sync_cond, slow_t, new_var), - use_locking=self._use_locking, - ), - ] - else: - var_updates = [state_ops.assign_sub(var, - lr_t * var_t, - use_locking=self._use_locking)] + var_update = state_ops.assign_sub(var, + lr_t * var_t, + use_locking=self._use_locking) - updates = var_updates + [m_t, v_t] + updates = [var_update + m_t, v_t] if self.amsgrad: updates.append(vhat_t) return control_flow_ops.group(*updates) @@ -319,41 +278,13 @@ def _resource_apply_sparse(self, grad, var, indices): var_t *= lr_t with ops.control_dependencies([var_t]): - if self._initial_lookahead_step > 0: - slow_var = self.get_slot(var, 'slow') - lookahead_step = self._get_hyper('lookahead_step', local_step.dtype) - lookahead_ratio = self._get_hyper('lookahead_ratio', var_dtype) - sync_cond = math_ops.equal(local_step % lookahead_step, 0) - new_var = var - var_t - slow_init = tf.where( - tf.equal(local_step, tf.constant(1, dtype=local_step.dtype)), - var, - slow_var, - ) - slow_t = slow_init + (new_var - slow_var) * lookahead_ratio - var_updates = [ - state_ops.assign( - slow_var, - tf.where(sync_cond, slow_t, slow_init), - use_locking=self._use_locking, - ), - state_ops.scatter_update( - var, - indices, - array_ops.gather( - tf.where(sync_cond, slow_t, new_var), indices, - ), - use_locking=self._use_locking, - ), - ] - else: - var_updates = [state_ops.scatter_sub( - var, - indices, - array_ops.gather(var_t, indices), - use_locking=self._use_locking)] - - updates = var_updates + [m_t, v_t] + var_update = state_ops.scatter_sub( + var, + indices, + array_ops.gather(var_t, indices), + use_locking=self._use_locking) + + updates = [var_update, m_t, v_t] if self.amsgrad: updates.append(vhat_t) return control_flow_ops.group(*updates) @@ -372,10 +303,6 @@ def get_config(self): 'total_steps': self._serialize_hyperparameter('total_steps'), 'warmup_proportion': 
self._serialize_hyperparameter('warmup_proportion'), - 'lookahead_step': - self._serialize_hyperparameter('lookahead_step'), - 'lookahead_ratio': - self._serialize_hyperparameter('lookahead_ratio'), 'min_lr': self._serialize_hyperparameter('min_lr'), }) return config diff --git a/tensorflow_addons/optimizers/rectified_adam_test.py b/tensorflow_addons/optimizers/rectified_adam_test.py index bc4f61c976..3f0a2a3741 100644 --- a/tensorflow_addons/optimizers/rectified_adam_test.py +++ b/tensorflow_addons/optimizers/rectified_adam_test.py @@ -21,13 +21,13 @@ import tensorflow as tf from tensorflow_addons.utils import test_utils -from tensorflow_addons.optimizers import RectifiedAdam +from tensorflow_addons.optimizers import RectifiedAdam, Lookahead @test_utils.run_all_in_graph_and_eager_modes class RectifiedAdamTest(tf.test.TestCase): - def run_dense_sample(self, iterations, expected, **opt_kwargs): + def run_dense_sample(self, iterations, expected, optimizer): var_0 = tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32) var_1 = tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32) @@ -36,20 +36,19 @@ def run_dense_sample(self, iterations, expected, **opt_kwargs): grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) - opt = RectifiedAdam(**opt_kwargs) - if tf.executing_eagerly(): for _ in range(iterations): - opt.apply_gradients(grads_and_vars) + optimizer.apply_gradients(grads_and_vars) else: - update = opt.apply_gradients(grads_and_vars) + update = optimizer.apply_gradients(grads_and_vars) self.evaluate(tf.compat.v1.global_variables_initializer()) for _ in range(iterations): self.evaluate(update) + self.assertAllClose(var_0.read_value(), expected[0], atol=1e-4) self.assertAllClose(var_1.read_value(), expected[1], atol=1e-4) - def run_sparse_sample(self, iterations, expected, **opt_kwargs): + def run_sparse_sample(self, iterations, expected, optimizer): var_0 = tf.Variable([1.0, 2.0]) var_1 = tf.Variable([3.0, 4.0]) @@ -66,13 +65,11 @@ def run_sparse_sample(self, iterations, expected, **opt_kwargs): grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) - opt = RectifiedAdam(**opt_kwargs) - if tf.executing_eagerly(): for _ in range(iterations): - opt.apply_gradients(grads_and_vars) + optimizer.apply_gradients(grads_and_vars) else: - update = opt.apply_gradients(grads_and_vars) + update = optimizer.apply_gradients(grads_and_vars) self.evaluate(tf.compat.v1.global_variables_initializer()) for _ in range(iterations): self.evaluate(update) @@ -85,7 +82,7 @@ def test_dense_sample(self): self.run_dense_sample( iterations=1000, expected=[[0.5554, 1.5549], [2.5557, 3.5557]], - lr=1e-3, + optimizer=RectifiedAdam(lr=1e-3), ) def test_sparse_sample(self): @@ -94,7 +91,7 @@ def test_sparse_sample(self): self.run_sparse_sample( iterations=2000, expected=[[-0.1929, 2.0], [3.0, 2.8074]], - lr=1e-3, + optimizer=RectifiedAdam(lr=1e-3), ) def test_dense_sample_with_amsgrad(self): @@ -103,8 +100,7 @@ def test_dense_sample_with_amsgrad(self): self.run_dense_sample( iterations=1000, expected=[[0.5554, 1.5549], [2.5557, 3.5557]], - lr=1e-3, - amsgrad=True, + optimizer=RectifiedAdam(lr=1e-3, amsgrad=True), ) def test_sparse_sample_with_amsgrad(self): @@ -113,8 +109,7 @@ def test_sparse_sample_with_amsgrad(self): self.run_sparse_sample( iterations=2000, expected=[[-0.1929, 2.0], [3.0, 2.8074]], - lr=1e-3, - amsgrad=True, + optimizer=RectifiedAdam(lr=1e-3, amsgrad=True), ) def test_dense_sample_with_weight_decay(self): @@ -122,8 +117,7 @@ def test_dense_sample_with_weight_decay(self): self.run_dense_sample( 
iterations=1000, expected=[[0.5472, 1.5368], [2.5276, 3.5176]], - lr=1e-3, - weight_decay=0.01, + optimizer=RectifiedAdam(lr=1e-3, weight_decay=0.01), ) def test_sparse_sample_with_weight_decay(self): @@ -132,28 +126,31 @@ def test_sparse_sample_with_weight_decay(self): self.run_sparse_sample( iterations=2000, expected=[[-0.2029, 2.0], [3.0, 2.7380]], - lr=1e-3, - weight_decay=0.01, + optimizer=RectifiedAdam(lr=1e-3, weight_decay=0.01), ) def test_dense_sample_with_warmup(self): self.run_dense_sample( iterations=1000, expected=[[0.8041, 1.8041], [2.8041, 3.8041]], - lr=1e-3, - total_steps=1000, - warmup_proportion=0.1, - min_lr=1e-5, + optimizer=RectifiedAdam( + lr=1e-3, + total_steps=1000, + warmup_proportion=0.1, + min_lr=1e-5, + ), ) def test_sparse_sample_with_warmup(self): self.run_sparse_sample( iterations=2000, expected=[[0.4653, 2.0], [3.0, 3.4653]], - lr=1e-3, - total_steps=2000, - warmup_proportion=0.1, - min_lr=1e-5, + optimizer=RectifiedAdam( + lr=1e-3, + total_steps=2000, + warmup_proportion=0.1, + min_lr=1e-5, + ), ) def test_dense_sample_with_lookahead(self): @@ -162,10 +159,14 @@ def test_dense_sample_with_lookahead(self): self.run_dense_sample( iterations=1000, expected=[[0.7985, 1.7983], [2.7987, 3.7986]], - lr=1e-3, - beta_1=0.95, - lookahead_step=6, - lookahead_ratio=0.45, + optimizer=Lookahead( + RectifiedAdam( + lr=1e-3, + beta_1=0.95, + ), + k=6, + alpha=0.45, + ), ) def test_sparse_sample_with_lookahead(self): @@ -175,10 +176,14 @@ def test_sparse_sample_with_lookahead(self): self.run_sparse_sample( iterations=1500, expected=[[0.6417, 2.0], [3.0, 3.6418]], - lr=1e-3, - beta_1=0.95, - lookahead_step=6, - lookahead_ratio=0.45, + optimizer=Lookahead( + RectifiedAdam( + lr=1e-3, + beta_1=0.95, + ), + k=6, + alpha=0.45, + ), ) def test_get_config(self): From cba8e4a95f2ca4472cdd9145ddee66f90bf806a6 Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Sat, 14 Sep 2019 15:32:59 +0800 Subject: [PATCH 06/13] Fix compatibility of Lookahead --- tensorflow_addons/optimizers/lookahead.py | 9 +++++---- tensorflow_addons/optimizers/lookahead_test.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index f15cc5b992..cf6b40ddb3 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -88,7 +88,7 @@ def _create_slots(self, var_list): def apply_gradients(self, grads_and_vars, name=None): var_list = [v for (_, v) in grads_and_vars] - with tf.keras.backend.name_scope(self._scope_ctx): + with tf.keras.backend.name_scope(self._name): with ops.init_scope(): _ = self.iterations self._create_hypers() @@ -102,10 +102,11 @@ def apply_gradients(self, grads_and_vars, name=None): init_op = self._init_op(var_list) with tf.control_dependencies([init_op]): - train_op = self._optimizer.apply_gradients(grads_and_vars, name=name) + train_op = self._optimizer.apply_gradients( + grads_and_vars, name=name) with tf.control_dependencies([train_op]): - lookahead_updates = [self._look_ahead_op(var) for var in var_list] - lookahead_op = control_flow_ops.group(lookahead_updates) + lookahead_op = control_flow_ops.group([ + self._look_ahead_op(var) for var in var_list]) return control_flow_ops.group(init_op, train_op, lookahead_op) diff --git a/tensorflow_addons/optimizers/lookahead_test.py b/tensorflow_addons/optimizers/lookahead_test.py index ddabc30232..986b318e23 100644 --- a/tensorflow_addons/optimizers/lookahead_test.py 
+++ b/tensorflow_addons/optimizers/lookahead_test.py @@ -91,7 +91,7 @@ def run_sparse_sample(self, iterations, optimizer, seed=0x2019): def test_dense_exact_ratio(self): for k in [5, 10, 100, 500]: - for alpha in [0.1, 0.5, 0.9]: + for alpha in [0.1, 0.5, 0.8]: optimizer = optimizers.get('adam') vals, quick_vars = self.run_dense_sample(k, optimizer) optimizer = Lookahead('adam', k=k, alpha=alpha) From 3c3a0f544acc936fccda8dce49c04d239076c405 Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Sat, 14 Sep 2019 18:45:06 +0800 Subject: [PATCH 07/13] Add test case for training a simple linear model with Lookahead --- tensorflow_addons/optimizers/lookahead.py | 32 +++++++++++-------- .../optimizers/lookahead_test.py | 20 ++++++++++++ 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index cf6b40ddb3..0edc33ee3c 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -112,14 +112,13 @@ def apply_gradients(self, grads_and_vars, name=None): def _init_op(self, var_list): updates = [] - iterations = self._optimizer.iterations for var in var_list: slow_var = self.get_slot(var, 'slow') updates.append(state_ops.assign( slow_var, tf.where( - math_ops.equal(iterations, - tf.constant(0, dtype=iterations.dtype)), + math_ops.equal(self.iterations, + tf.constant(0, dtype=self.iterations.dtype)), var, slow_var, ), @@ -129,23 +128,28 @@ def _init_op(self, var_list): def _look_ahead_op(self, var): var_dtype = var.dtype.base_dtype slow_var = self.get_slot(var, 'slow') - local_step = math_ops.cast(self._optimizer.iterations, var_dtype) + local_step = math_ops.cast(self.iterations, var_dtype) k = self._get_hyper('k', local_step.dtype) alpha = self._get_hyper('alpha', var_dtype) step_back = slow_var + alpha * (var - slow_var) sync_cond = math_ops.equal(local_step % k, 0) - slow_update = state_ops.assign(slow_var, tf.where( - sync_cond, - step_back, - slow_var, - ), use_locking=self._use_locking) - var_update = state_ops.assign(var, tf.where( - sync_cond, - step_back, - var, - ), use_locking=self._use_locking) + with tf.control_dependencies([step_back]): + slow_update = state_ops.assign(slow_var, tf.where( + sync_cond, + step_back, + slow_var, + ), use_locking=self._use_locking) + var_update = state_ops.assign(var, tf.where( + sync_cond, + step_back, + var, + ), use_locking=self._use_locking) return control_flow_ops.group(slow_update, var_update) + @property + def iterations(self): + return self._optimizer.iterations + @property def weights(self): return self._optimizer.weights diff --git a/tensorflow_addons/optimizers/lookahead_test.py b/tensorflow_addons/optimizers/lookahead_test.py index 986b318e23..2657928634 100644 --- a/tensorflow_addons/optimizers/lookahead_test.py +++ b/tensorflow_addons/optimizers/lookahead_test.py @@ -111,6 +111,26 @@ def test_sparse_exact_ratio(self): expected = val + (quick - val) * alpha self.assertAllClose(expected, slow) + def test_fit_simple_linear_model(self): + np.random.seed(0x2019) + + x = np.random.standard_normal((100000, 3)) + w = np.random.standard_normal((3, 1)) + y = np.dot(x, w) + np.random.standard_normal((100000, 1)) * 1e-4 + + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Dense(input_shape=(3,), units=1)) + model.compile(Lookahead('adam'), loss='mse') + + model.fit(x, y, epochs=3) + + x = np.random.standard_normal((100, 3)) + y = np.dot(x, w) + predicted = 
model.predict(x) + + max_abs_diff = np.max(np.abs(predicted - y)) + self.assertLess(max_abs_diff, 1e-4) + def test_invalid_optimizer_type(self): with self.assertRaises(TypeError): Lookahead(optimizers.Adam()) From 24176af34965e753658d5a15b261a23bb0d59f19 Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Sat, 14 Sep 2019 19:36:39 +0800 Subject: [PATCH 08/13] Fix Lookahead when executing eagerly --- tensorflow_addons/optimizers/lookahead.py | 81 +++++++++-------------- 1 file changed, 33 insertions(+), 48 deletions(-) diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 0edc33ee3c..018c6b37d6 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -18,7 +18,6 @@ from __future__ import print_function import tensorflow as tf -from tensorflow.python import ops from tensorflow.python.ops import math_ops, state_ops, control_flow_ops from tensorflow.python.keras import optimizers from tensorflow_addons.utils import keras_utils @@ -82,53 +81,36 @@ def __init__(self, self._initialized = False def _create_slots(self, var_list): + self._optimizer._create_slots(var_list=var_list) # pylint: disable=protected-access for var in var_list: self.add_slot(var, 'slow') - def apply_gradients(self, grads_and_vars, name=None): - var_list = [v for (_, v) in grads_and_vars] - - with tf.keras.backend.name_scope(self._name): - with ops.init_scope(): - _ = self.iterations - self._create_hypers() - self._create_slots(var_list) - self._prepare(var_list) + def _create_hypers(self): + self._optimizer._create_hypers() # pylint: disable=protected-access - if self._initialized: - init_op = tf.no_op() - else: - self._initialized = True - init_op = self._init_op(var_list) - - with tf.control_dependencies([init_op]): - train_op = self._optimizer.apply_gradients( - grads_and_vars, name=name) - with tf.control_dependencies([train_op]): - lookahead_op = control_flow_ops.group([ - self._look_ahead_op(var) for var in var_list]) + def _prepare(self, var_list): + self._optimizer._prepare(var_list=var_list) # pylint: disable=protected-access - return control_flow_ops.group(init_op, train_op, lookahead_op) + def apply_gradients(self, grads_and_vars, name=None): + self._optimizer._iterations = self.iterations # pylint: disable=protected-access + return super(Lookahead, self).apply_gradients(grads_and_vars, name) - def _init_op(self, var_list): - updates = [] - for var in var_list: - slow_var = self.get_slot(var, 'slow') - updates.append(state_ops.assign( + def _init_op(self, var): + slow_var = self.get_slot(var, 'slow') + return state_ops.assign( + slow_var, + tf.where( + math_ops.equal(self.iterations, + tf.constant(0, dtype=self.iterations.dtype)), + var, slow_var, - tf.where( - math_ops.equal(self.iterations, - tf.constant(0, dtype=self.iterations.dtype)), - var, - slow_var, - ), - use_locking=self._use_locking)) - return control_flow_ops.group(*updates) + ), + use_locking=self._use_locking) def _look_ahead_op(self, var): var_dtype = var.dtype.base_dtype slow_var = self.get_slot(var, 'slow') - local_step = math_ops.cast(self.iterations, var_dtype) + local_step = math_ops.cast(self.iterations + 1, var_dtype) k = self._get_hyper('k', local_step.dtype) alpha = self._get_hyper('alpha', var_dtype) step_back = slow_var + alpha * (var - slow_var) @@ -146,23 +128,26 @@ def _look_ahead_op(self, var): ), use_locking=self._use_locking) return control_flow_ops.group(slow_update, var_update) - @property - def 
iterations(self): - return self._optimizer.iterations - @property def weights(self): - return self._optimizer.weights + return self._weights + self._optimizer.weights def _resource_apply_dense(self, grad, var): - return self._optimizer._resource_apply_dense(grad, var) # pylint: disable=protected-access - - def _resource_apply_sparse_duplicate_indices(self, grad, var, indices): - return self._optimizer._resource_apply_sparse_duplicate_indices( # pylint: disable=protected-access - grad, var, indices) + init_op = self._init_op(var) + with tf.control_dependencies([init_op]): + train_op = self._optimizer._resource_apply_dense(grad, var) # pylint: disable=protected-access + with tf.control_dependencies([train_op]): + look_ahead_op = self._look_ahead_op(var) + return tf.group(init_op, train_op, look_ahead_op) def _resource_apply_sparse(self, grad, var, indices): - return self._optimizer._resource_apply_sparse(grad, var, indices) # pylint: disable=protected-access + init_op = self._init_op(var) + with tf.control_dependencies([init_op]): + train_op = self._optimizer._resource_apply_sparse( # pylint: disable=protected-access + grad, var, indices) + with tf.control_dependencies([train_op]): + look_ahead_op = self._look_ahead_op(var) + return tf.group(init_op, train_op, look_ahead_op) def get_config(self): config = { From e627d75710724cbcbb93c9c4a42c3bfa14617492 Mon Sep 17 00:00:00 2001 From: Zhao Hanguang <853842+CyberZHG@users.noreply.github.com> Date: Sun, 15 Sep 2019 22:32:36 +0800 Subject: [PATCH 09/13] Fix orders and use public TensorFlow API --- tensorflow_addons/optimizers/BUILD | 22 ++--- tensorflow_addons/optimizers/README.md | 8 +- tensorflow_addons/optimizers/__init__.py | 4 +- tensorflow_addons/optimizers/lookahead.py | 50 +++++----- .../optimizers/lookahead_test.py | 34 +++---- .../optimizers/rectified_adam.py | 97 +++++++++---------- .../optimizers/rectified_adam_test.py | 8 +- 7 files changed, 106 insertions(+), 117 deletions(-) diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD index 44f0b5be93..f158f58a74 100644 --- a/tensorflow_addons/optimizers/BUILD +++ b/tensorflow_addons/optimizers/BUILD @@ -7,10 +7,10 @@ py_library( srcs = [ "__init__.py", "lazy_adam.py", + "lookahead.py", "moving_average.py", - "weight_decay_optimizers.py", "rectified_adam.py", - "lookahead.py", + "weight_decay_optimizers.py", ], srcs_version = "PY2AND3", deps = [ @@ -32,12 +32,12 @@ py_test( ) py_test( - name = "moving_average_test", + name = "lookahead_test", size = "small", srcs = [ - "moving_average_test.py", + "lookahead_test.py", ], - main = "moving_average_test.py", + main = "lookahead_test.py", srcs_version = "PY2AND3", deps = [ ":optimizers", @@ -45,12 +45,12 @@ py_test( ) py_test( - name = "weight_decay_optimizers_test", + name = "moving_average_test", size = "small", srcs = [ - "weight_decay_optimizers_test.py", + "moving_average_test.py", ], - main = "weight_decay_optimizers_test.py", + main = "moving_average_test.py", srcs_version = "PY2AND3", deps = [ ":optimizers", @@ -71,12 +71,12 @@ py_test( ) py_test( - name = "lookahead_test", + name = "weight_decay_optimizers_test", size = "small", srcs = [ - "lookahead_test.py", + "weight_decay_optimizers_test.py", ], - main = "lookahead_test.py", + main = "weight_decay_optimizers_test.py", srcs_version = "PY2AND3", deps = [ ":optimizers", diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md index 7c194d8784..92c2d640c0 100644 --- a/tensorflow_addons/optimizers/README.md +++ 
b/tensorflow_addons/optimizers/README.md @@ -4,20 +4,20 @@ | Submodule | Maintainers | Contact Info | |:---------- |:------------- |:--------------| | lazy_adam | Saishruthi Swaminathan | saishruthi.tn@gmail.com | +| lookahead | Zhao Hanguang | cyberzhg@gmail.com | | moving_average | Dheeraj R. Reddy | dheeraj98reddy@gmail.com | -| weight_decay_optimizers | Phil Jund | ijund.phil@googlemail.com | | rectified_adam | Zhao Hanguang | cyberzhg@gmail.com | -| lookahead | Zhao Hanguang | cyberzhg@gmail.com | +| weight_decay_optimizers | Phil Jund | ijund.phil@googlemail.com | ## Components | Submodule | Optimizer | Reference | |:--------- |:---------- |:---------| | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 | +| lookahead | Lookahead | https://arxiv.org/abs/1907.08610v1 | | moving_average | MovingAverage | | -| weight_decay_optimizers | SGDW, AdamW, extend_with_decoupled_weight_decay | https://arxiv.org/pdf/1711.05101.pdf | | rectified_adam | RectifiedAdam | https://arxiv.org/pdf/1908.03265v1.pdf | -| lookahead | Lookahead | https://arxiv.org/abs/1907.08610v1 | +| weight_decay_optimizers | SGDW, AdamW, extend_with_decoupled_weight_decay | https://arxiv.org/pdf/1711.05101.pdf | ## Contribution Guidelines diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py index e1a45cd430..b2dbc44cb8 100644 --- a/tensorflow_addons/optimizers/__init__.py +++ b/tensorflow_addons/optimizers/__init__.py @@ -19,10 +19,10 @@ from __future__ import print_function from tensorflow_addons.optimizers.lazy_adam import LazyAdam +from tensorflow_addons.optimizers.lookahead import Lookahead from tensorflow_addons.optimizers.moving_average import MovingAverage +from tensorflow_addons.optimizers.rectified_adam import RectifiedAdam from tensorflow_addons.optimizers.weight_decay_optimizers import AdamW from tensorflow_addons.optimizers.weight_decay_optimizers import SGDW from tensorflow_addons.optimizers.weight_decay_optimizers import ( extend_with_decoupled_weight_decay) -from tensorflow_addons.optimizers.rectified_adam import RectifiedAdam -from tensorflow_addons.optimizers.lookahead import Lookahead diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 018c6b37d6..75a3914525 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -18,8 +18,6 @@ from __future__ import print_function import tensorflow as tf -from tensorflow.python.ops import math_ops, state_ops, control_flow_ops -from tensorflow.python.keras import optimizers from tensorflow_addons.utils import keras_utils @@ -46,8 +44,8 @@ class Lookahead(tf.keras.optimizers.Optimizer): def __init__(self, optimizer, - k=6, - alpha=0.5, + sync_period=6, + slow_step_size=0.5, name="Lookahead", **kwargs): r"""Wrap optimizer with the lookahead mechanism. @@ -55,9 +53,10 @@ def __init__(self, Args: optimizer: A Tensor or a floating point value. The learning rate. - k: An integer. Synchronization period of lookahead. + sync_period: An integer. The synchronization period of lookahead. Enable lookahead mechanism by setting it with a positive value. - alpha: A float value. Slow weights step size. + slow_step_size: A floating point value. + The ratio for updating the slow weights. name: Optional name for the operations created when applying gradients. Defaults to "RectifiedAdam". **kwargs: keyword arguments. 
Allowed to be {`clipnorm`, `clipvalue`, @@ -70,14 +69,14 @@ def __init__(self, super(Lookahead, self).__init__(name, **kwargs) if isinstance(optimizer, str): - optimizer = optimizers.get(optimizer) + optimizer = tf.keras.optimizers.get(optimizer) if not isinstance(optimizer, tf.keras.optimizers.Optimizer): raise TypeError( "optimizer is not an object of tf.keras.optimizers.Optimizer") self._optimizer = optimizer - self._set_hyper('k', k) - self._set_hyper('alpha', alpha) + self._set_hyper('sync_period', sync_period) + self._set_hyper('slow_step_size', slow_step_size) self._initialized = False def _create_slots(self, var_list): @@ -89,7 +88,7 @@ def _create_hypers(self): self._optimizer._create_hypers() # pylint: disable=protected-access def _prepare(self, var_list): - self._optimizer._prepare(var_list=var_list) # pylint: disable=protected-access + return self._optimizer._prepare(var_list=var_list) # pylint: disable=protected-access def apply_gradients(self, grads_and_vars, name=None): self._optimizer._iterations = self.iterations # pylint: disable=protected-access @@ -97,11 +96,10 @@ def apply_gradients(self, grads_and_vars, name=None): def _init_op(self, var): slow_var = self.get_slot(var, 'slow') - return state_ops.assign( - slow_var, + return slow_var.assign( tf.where( - math_ops.equal(self.iterations, - tf.constant(0, dtype=self.iterations.dtype)), + tf.equal(self.iterations, + tf.constant(0, dtype=self.iterations.dtype)), var, slow_var, ), @@ -110,23 +108,23 @@ def _init_op(self, var): def _look_ahead_op(self, var): var_dtype = var.dtype.base_dtype slow_var = self.get_slot(var, 'slow') - local_step = math_ops.cast(self.iterations + 1, var_dtype) - k = self._get_hyper('k', local_step.dtype) - alpha = self._get_hyper('alpha', var_dtype) - step_back = slow_var + alpha * (var - slow_var) - sync_cond = math_ops.equal(local_step % k, 0) + local_step = tf.cast(self.iterations + 1, var_dtype) + sync_period = self._get_hyper('sync_period', local_step.dtype) + slow_step_size = self._get_hyper('slow_step_size', var_dtype) + step_back = slow_var + slow_step_size * (var - slow_var) + sync_cond = tf.equal(local_step % sync_period, 0) with tf.control_dependencies([step_back]): - slow_update = state_ops.assign(slow_var, tf.where( + slow_update = slow_var.assign(tf.where( sync_cond, step_back, slow_var, ), use_locking=self._use_locking) - var_update = state_ops.assign(var, tf.where( + var_update = var.assign(tf.where( sync_cond, step_back, var, ), use_locking=self._use_locking) - return control_flow_ops.group(slow_update, var_update) + return tf.group(slow_update, var_update) @property def weights(self): @@ -151,16 +149,16 @@ def _resource_apply_sparse(self, grad, var, indices): def get_config(self): config = { - 'optimizer': optimizers.serialize(self._optimizer), - 'k': self._serialize_hyperparameter('k'), - 'alpha': self._serialize_hyperparameter('alpha'), + 'optimizer': tf.keras.optimizers.serialize(self._optimizer), + 'sync_period': self._serialize_hyperparameter('sync_period'), + 'slow_step_size': self._serialize_hyperparameter('slow_step_size'), } base_config = super(Lookahead, self).get_config() return dict(list(base_config.items()) + list(config.items())) @classmethod def from_config(cls, config, custom_objects=None): - optimizer = optimizers.deserialize( + optimizer = tf.keras.optimizers.deserialize( config.pop('optimizer'), custom_objects=custom_objects, ) diff --git a/tensorflow_addons/optimizers/lookahead_test.py b/tensorflow_addons/optimizers/lookahead_test.py index 2657928634..3aee621348 
100644 --- a/tensorflow_addons/optimizers/lookahead_test.py +++ b/tensorflow_addons/optimizers/lookahead_test.py @@ -20,7 +20,6 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import optimizers from tensorflow_addons.utils import test_utils from tensorflow_addons.optimizers import Lookahead @@ -90,22 +89,26 @@ def run_sparse_sample(self, iterations, optimizer, seed=0x2019): return [val_0, val_1], [self.evaluate(var_0), self.evaluate(var_1)] def test_dense_exact_ratio(self): - for k in [5, 10, 100, 500]: - for alpha in [0.1, 0.5, 0.8]: - optimizer = optimizers.get('adam') + for k in [5, 10, 100]: + for alpha in [0.3, 0.7]: + optimizer = tf.keras.optimizers.get('adam') vals, quick_vars = self.run_dense_sample(k, optimizer) - optimizer = Lookahead('adam', k=k, alpha=alpha) + optimizer = Lookahead('adam', + sync_period=k, + slow_step_size=alpha) _, slow_vars = self.run_dense_sample(k, optimizer) for val, quick, slow in zip(vals, quick_vars, slow_vars): expected = val + (quick - val) * alpha self.assertAllClose(expected, slow) def test_sparse_exact_ratio(self): - for k in [5, 10, 100, 500]: - for alpha in [0.1, 0.5, 0.9]: - optimizer = optimizers.get('adam') + for k in [5, 10, 100]: + for alpha in [0.3, 0.7]: + optimizer = tf.keras.optimizers.get('adam') vals, quick_vars = self.run_sparse_sample(k, optimizer) - optimizer = Lookahead('adam', k=k, alpha=alpha) + optimizer = Lookahead('adam', + sync_period=k, + slow_step_size=alpha) _, slow_vars = self.run_sparse_sample(k, optimizer) for val, quick, slow in zip(vals, quick_vars, slow_vars): expected = val + (quick - val) * alpha @@ -131,16 +134,13 @@ def test_fit_simple_linear_model(self): max_abs_diff = np.max(np.abs(predicted - y)) self.assertLess(max_abs_diff, 1e-4) - def test_invalid_optimizer_type(self): - with self.assertRaises(TypeError): - Lookahead(optimizers.Adam()) - def test_get_config(self): - opt = Lookahead('adam', k=10, alpha=0.4) - opt = optimizers.deserialize(optimizers.serialize(opt)) + opt = Lookahead('adam', sync_period=10, slow_step_size=0.4) + opt = tf.keras.optimizers.deserialize( + tf.keras.optimizers.serialize(opt)) config = opt.get_config() - self.assertEqual(config['k'], 10) - self.assertEqual(config['alpha'], 0.4) + self.assertEqual(config['sync_period'], 10) + self.assertEqual(config['slow_step_size'], 0.4) if __name__ == '__main__': diff --git a/tensorflow_addons/optimizers/rectified_adam.py b/tensorflow_addons/optimizers/rectified_adam.py index 309de40052..c30ad08e2e 100644 --- a/tensorflow_addons/optimizers/rectified_adam.py +++ b/tensorflow_addons/optimizers/rectified_adam.py @@ -18,9 +18,6 @@ from __future__ import print_function import tensorflow as tf -from tensorflow.python import ops -from tensorflow.python.ops import (math_ops, state_ops, - array_ops, control_flow_ops) from tensorflow_addons.utils import keras_utils @@ -67,7 +64,7 @@ class RectifiedAdam(tf.keras.optimizers.Optimizer): ```python radam = tfa.optimizers.RectifiedAdam() - ranger = tfa.optimizers.Lookahead(radam, k=6, alpha=0.5) + ranger = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5) ``` """ @@ -150,22 +147,22 @@ def _resource_apply_dense(self, grad, var): v = self.get_slot(var, 'v') beta_1_t = self._get_hyper('beta_1', var_dtype) beta_2_t = self._get_hyper('beta_2', var_dtype) - epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype) - local_step = math_ops.cast(self.iterations + 1, var_dtype) - beta_1_power = math_ops.pow(beta_1_t, local_step) - beta_2_power = math_ops.pow(beta_2_t, 
local_step) + epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype) + local_step = tf.cast(self.iterations + 1, var_dtype) + beta_1_power = tf.pow(beta_1_t, local_step) + beta_2_power = tf.pow(beta_2_t, local_step) if self._initial_total_steps > 0: total_steps = self._get_hyper('total_steps', var_dtype) warmup_steps = total_steps *\ self._get_hyper('warmup_proportion', var_dtype) min_lr = self._get_hyper('min_lr', var_dtype) - decay_steps = math_ops.maximum(total_steps - warmup_steps, 1) + decay_steps = tf.maximum(total_steps - warmup_steps, 1) decay_rate = (min_lr - lr_t) / decay_steps lr_t = tf.where( local_step <= warmup_steps, lr_t * (local_step / warmup_steps), - lr_t + decay_rate * math_ops.minimum( + lr_t + decay_rate * tf.minimum( local_step - warmup_steps, decay_steps), ) @@ -173,27 +170,25 @@ def _resource_apply_dense(self, grad, var): sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0 sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power) - m_t = state_ops.assign(m, - beta_1_t * m + (1.0 - beta_1_t) * grad, - use_locking=self._use_locking) + m_t = m.assign(beta_1_t * m + (1.0 - beta_1_t) * grad, + use_locking=self._use_locking) m_corr_t = m_t / (1.0 - beta_1_power) - v_t = state_ops.assign(v, - beta_2_t * v + - (1.0 - beta_2_t) * math_ops.square(grad), - use_locking=self._use_locking) + v_t = v.assign(beta_2_t * v + + (1.0 - beta_2_t) * tf.square(grad), + use_locking=self._use_locking) if self.amsgrad: vhat = self.get_slot(var, 'vhat') - vhat_t = state_ops.assign(vhat, - math_ops.maximum(vhat, v_t), - use_locking=self._use_locking) - v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta_2_power)) + vhat_t = vhat.assign(tf.maximum(vhat, v_t), + use_locking=self._use_locking) + v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power)) else: - v_corr_t = math_ops.sqrt(v_t / (1.0 - beta_2_power)) + vhat_t = None + v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power)) - r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * - (sma_t - 2.0) / (sma_inf - 2.0) * - sma_inf / sma_t) + r_t = tf.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * + (sma_t - 2.0) / (sma_inf - 2.0) * + sma_inf / sma_t) sma_threshold = self._get_hyper('sma_threshold', var_dtype) var_t = tf.where( @@ -204,36 +199,34 @@ def _resource_apply_dense(self, grad, var): if self._initial_weight_decay > 0.0: var_t += self._get_hyper('weight_decay', var_dtype) * var - var_update = state_ops.assign_sub(var, - lr_t * var_t, - use_locking=self._use_locking) + var_update = var.assign_sub(lr_t * var_t, use_locking=self._use_locking) updates = [var_update + m_t, v_t] if self.amsgrad: updates.append(vhat_t) - return control_flow_ops.group(*updates) + return tf.group(*updates) def _resource_apply_sparse(self, grad, var, indices): var_dtype = var.dtype.base_dtype lr_t = self._decayed_lr(var_dtype) beta_1_t = self._get_hyper('beta_1', var_dtype) beta_2_t = self._get_hyper('beta_2', var_dtype) - epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype) - local_step = math_ops.cast(self.iterations + 1, var_dtype) - beta_1_power = math_ops.pow(beta_1_t, local_step) - beta_2_power = math_ops.pow(beta_2_t, local_step) + epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype) + local_step = tf.cast(self.iterations + 1, var_dtype) + beta_1_power = tf.pow(beta_1_t, local_step) + beta_2_power = tf.pow(beta_2_t, local_step) if self._initial_total_steps > 0: total_steps = self._get_hyper('total_steps', var_dtype) warmup_steps = total_steps *\ self._get_hyper('warmup_proportion', var_dtype) min_lr = self._get_hyper('min_lr', var_dtype) - decay_steps = 
math_ops.maximum(total_steps - warmup_steps, 1) + decay_steps = tf.maximum(total_steps - warmup_steps, 1) decay_rate = (min_lr - lr_t) / decay_steps lr_t = tf.where( local_step <= warmup_steps, lr_t * (local_step / warmup_steps), - lr_t + decay_rate * math_ops.minimum( + lr_t + decay_rate * tf.minimum( local_step - warmup_steps, decay_steps), ) @@ -243,29 +236,29 @@ def _resource_apply_sparse(self, grad, var, indices): m = self.get_slot(var, 'm') m_scaled_g_values = grad * (1 - beta_1_t) - m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking) - with ops.control_dependencies([m_t]): + m_t = m.assign(m * beta_1_t, use_locking=self._use_locking) + with tf.control_dependencies([m_t]): m_t = self._resource_scatter_add(m, indices, m_scaled_g_values) m_corr_t = m_t / (1.0 - beta_1_power) v = self.get_slot(var, 'v') v_scaled_g_values = (grad * grad) * (1 - beta_2_t) - v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking) - with ops.control_dependencies([v_t]): + v_t = v.assign(v * beta_2_t, use_locking=self._use_locking) + with tf.control_dependencies([v_t]): v_t = self._resource_scatter_add(v, indices, v_scaled_g_values) if self.amsgrad: vhat = self.get_slot(var, 'vhat') - vhat_t = state_ops.assign(vhat, - math_ops.maximum(vhat, v_t), - use_locking=self._use_locking) - v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta_2_power)) + vhat_t = vhat.assign(tf.maximum(vhat, v_t), + use_locking=self._use_locking) + v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power)) else: - v_corr_t = math_ops.sqrt(v_t / (1.0 - beta_2_power)) + vhat_t = None + v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power)) - r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * - (sma_t - 2.0) / (sma_inf - 2.0) * - sma_inf / sma_t) + r_t = tf.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * + (sma_t - 2.0) / (sma_inf - 2.0) * + sma_inf / sma_t) sma_threshold = self._get_hyper('sma_threshold', var_dtype) var_t = tf.where( @@ -276,18 +269,16 @@ def _resource_apply_sparse(self, grad, var, indices): if self._initial_weight_decay > 0.0: var_t += self._get_hyper('weight_decay', var_dtype) * var - var_t *= lr_t - with ops.control_dependencies([var_t]): - var_update = state_ops.scatter_sub( + with tf.control_dependencies([var_t]): + var_update = self._resource_scatter_add( var, indices, - array_ops.gather(var_t, indices), - use_locking=self._use_locking) + tf.gather(-lr_t * var_t, indices)) updates = [var_update, m_t, v_t] if self.amsgrad: updates.append(vhat_t) - return control_flow_ops.group(*updates) + return tf.group(*updates) def get_config(self): config = super(RectifiedAdam, self).get_config() diff --git a/tensorflow_addons/optimizers/rectified_adam_test.py b/tensorflow_addons/optimizers/rectified_adam_test.py index 3f0a2a3741..4308688fd1 100644 --- a/tensorflow_addons/optimizers/rectified_adam_test.py +++ b/tensorflow_addons/optimizers/rectified_adam_test.py @@ -164,8 +164,8 @@ def test_dense_sample_with_lookahead(self): lr=1e-3, beta_1=0.95, ), - k=6, - alpha=0.45, + sync_period=6, + slow_step_size=0.45, ), ) @@ -181,8 +181,8 @@ def test_sparse_sample_with_lookahead(self): lr=1e-3, beta_1=0.95, ), - k=6, - alpha=0.45, + sync_period=6, + slow_step_size=0.45, ), ) From bd74c538b6b928566f00baf96b366331e1702455 Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Mon, 16 Sep 2019 22:25:56 +0800 Subject: [PATCH 10/13] Apply the formatting tool for RAdam and Lookahead --- tensorflow_addons/optimizers/lookahead.py | 36 ++--- .../optimizers/lookahead_test.py | 23 ++-- 
.../optimizers/rectified_adam.py | 123 +++++++++--------- .../optimizers/rectified_adam_test.py | 11 +- 4 files changed, 95 insertions(+), 98 deletions(-) diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 75a3914525..2df7db756c 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -59,12 +59,12 @@ def __init__(self, The ratio for updating the slow weights. name: Optional name for the operations created when applying gradients. Defaults to "RectifiedAdam". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, - `lr`, `decay`}. `clipnorm` is clip gradients by norm; - `clipvalue` is clip gradients by value, `decay` is included for - backward compatibility to allow time inverse decay of learning - rate. `lr` is included for backward compatibility, recommended - to use `learning_rate` instead. + **kwargs: keyword arguments. Allowed to be {`clipnorm`, + `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients + by norm; `clipvalue` is clip gradients by value, `decay` is + included for backward compatibility to allow time inverse + decay of learning rate. `lr` is included for backward + compatibility, recommended to use `learning_rate` instead. """ super(Lookahead, self).__init__(name, **kwargs) @@ -114,16 +114,20 @@ def _look_ahead_op(self, var): step_back = slow_var + slow_step_size * (var - slow_var) sync_cond = tf.equal(local_step % sync_period, 0) with tf.control_dependencies([step_back]): - slow_update = slow_var.assign(tf.where( - sync_cond, - step_back, - slow_var, - ), use_locking=self._use_locking) - var_update = var.assign(tf.where( - sync_cond, - step_back, - var, - ), use_locking=self._use_locking) + slow_update = slow_var.assign( + tf.where( + sync_cond, + step_back, + slow_var, + ), + use_locking=self._use_locking) + var_update = var.assign( + tf.where( + sync_cond, + step_back, + var, + ), + use_locking=self._use_locking) return tf.group(slow_update, var_update) @property diff --git a/tensorflow_addons/optimizers/lookahead_test.py b/tensorflow_addons/optimizers/lookahead_test.py index 3aee621348..fe57ec1ef1 100644 --- a/tensorflow_addons/optimizers/lookahead_test.py +++ b/tensorflow_addons/optimizers/lookahead_test.py @@ -27,7 +27,6 @@ @test_utils.run_all_in_graph_and_eager_modes class LookaheadTest(tf.test.TestCase): - def run_dense_sample(self, iterations, optimizer, seed=0x2019): np.random.seed(seed) @@ -65,15 +64,11 @@ def run_sparse_sample(self, iterations, optimizer, seed=0x2019): var_1 = tf.Variable(val_1, dtype=tf.dtypes.float32) grad_0 = tf.IndexedSlices( - tf.constant([np.random.standard_normal()]), - tf.constant([0]), - tf.constant([2]) - ) + tf.constant([np.random.standard_normal()]), tf.constant([0]), + tf.constant([2])) grad_1 = tf.IndexedSlices( - tf.constant([np.random.standard_normal()]), - tf.constant([1]), - tf.constant([2]) - ) + tf.constant([np.random.standard_normal()]), tf.constant([1]), + tf.constant([2])) grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) @@ -93,9 +88,8 @@ def test_dense_exact_ratio(self): for alpha in [0.3, 0.7]: optimizer = tf.keras.optimizers.get('adam') vals, quick_vars = self.run_dense_sample(k, optimizer) - optimizer = Lookahead('adam', - sync_period=k, - slow_step_size=alpha) + optimizer = Lookahead( + 'adam', sync_period=k, slow_step_size=alpha) _, slow_vars = self.run_dense_sample(k, optimizer) for val, quick, slow in zip(vals, quick_vars, slow_vars): expected = val + (quick - val) * alpha @@ 
-106,9 +100,8 @@ def test_sparse_exact_ratio(self): for alpha in [0.3, 0.7]: optimizer = tf.keras.optimizers.get('adam') vals, quick_vars = self.run_sparse_sample(k, optimizer) - optimizer = Lookahead('adam', - sync_period=k, - slow_step_size=alpha) + optimizer = Lookahead( + 'adam', sync_period=k, slow_step_size=alpha) _, slow_vars = self.run_sparse_sample(k, optimizer) for val, quick, slow in zip(vals, quick_vars, slow_vars): expected = val + (quick - val) * alpha diff --git a/tensorflow_addons/optimizers/rectified_adam.py b/tensorflow_addons/optimizers/rectified_adam.py index c30ad08e2e..ff4d8b3bf4 100644 --- a/tensorflow_addons/optimizers/rectified_adam.py +++ b/tensorflow_addons/optimizers/rectified_adam.py @@ -36,7 +36,8 @@ class RectifiedAdam(tf.keras.optimizers.Optimizer): opt = tfa.optimizers.RectifiedAdam(lr=1e-3) ``` - Note: `amsgrad` is not described in the original paper. Use it with caution. + Note: `amsgrad` is not described in the original paper. Use it with + caution. RAdam is not a placement of the heuristic warmup, the settings should be kept if warmup has already been employed and tuned in the baseline method. @@ -58,8 +59,8 @@ class RectifiedAdam(tf.keras.optimizers.Optimizer): Lookahead, proposed by Michael R. Zhang et.al in the paper [Lookahead Optimizer: k steps forward, 1 step back] (https://arxiv.org/abs/1907.08610v1), can be integrated with RAdam, - which is announced by Less Wright and the new combined optimizer can also be - called "Ranger". The mechanism can be enabled by using the lookahead + which is announced by Less Wright and the new combined optimizer can also + be called "Ranger". The mechanism can be enabled by using the lookahead wrapper. For example: ```python @@ -92,9 +93,11 @@ def __init__(self, The exponential decay rate for the 2nd moment estimates. epsilon: A small constant for numerical stability. weight_decay: A floating point value. Weight decay for each param. - amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm - from the paper "On the Convergence of Adam and beyond". - sma_threshold. A float value. The threshold for simple mean average. + amsgrad: boolean. Whether to apply AMSGrad variant of this + algorithm from the paper "On the Convergence of Adam and + beyond". + sma_threshold. A float value. + The threshold for simple mean average. total_steps: An integer. Total number of training steps. Enable warmup by setting a positive value. warmup_proportion: A floating point value. @@ -102,12 +105,12 @@ def __init__(self, min_lr: A floating point value. Minimum learning rate after warmup. name: Optional name for the operations created when applying gradients. Defaults to "RectifiedAdam". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, - `lr`, `decay`}. `clipnorm` is clip gradients by norm; - `clipvalue` is clip gradients by value, `decay` is included for - backward compatibility to allow time inverse decay of learning - rate. `lr` is included for backward compatibility, recommended - to use `learning_rate` instead. + **kwargs: keyword arguments. Allowed to be {`clipnorm`, + `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients + by norm; `clipvalue` is clip gradients by value, `decay` is + included for backward compatibility to allow time inverse + decay of learning rate. `lr` is included for backward + compatibility, recommended to use `learning_rate` instead. 
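        As a minimal sketch of the warmup-then-decay schedule that the
        `total_steps`, `warmup_proportion`, and `min_lr` arguments above
        control (plain Python; `warmup_steps = total_steps *
        warmup_proportion`, and the function name is illustrative only):

        ```python
        def radam_lr(step, lr=1e-3, total_steps=1000,
                     warmup_proportion=0.1, min_lr=1e-5):
            """Sketch of the learning rate at `step` when total_steps > 0."""
            warmup_steps = total_steps * warmup_proportion
            decay_steps = max(total_steps - warmup_steps, 1)
            decay_rate = (min_lr - lr) / decay_steps
            if step <= warmup_steps:
                # linear warmup from 0 up to the base learning rate
                return lr * step / warmup_steps
            # linear decay from the base learning rate down to min_lr
            return lr + decay_rate * min(step - warmup_steps, decay_steps)
        ```

        With the defaults shown, `radam_lr(50)` returns `5e-4` (halfway
        through warmup) and `radam_lr(1000)` returns `1e-5`, i.e. `min_lr`
        is reached at `total_steps`.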
""" super(RectifiedAdam, self).__init__(name, **kwargs) self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) @@ -162,44 +165,43 @@ def _resource_apply_dense(self, grad, var): lr_t = tf.where( local_step <= warmup_steps, lr_t * (local_step / warmup_steps), - lr_t + decay_rate * tf.minimum( - local_step - warmup_steps, - decay_steps), + lr_t + decay_rate * tf.minimum(local_step - warmup_steps, + decay_steps), ) sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0 - sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power) + sma_t = sma_inf - 2.0 * local_step * beta_2_power / ( + 1.0 - beta_2_power) - m_t = m.assign(beta_1_t * m + (1.0 - beta_1_t) * grad, - use_locking=self._use_locking) + m_t = m.assign( + beta_1_t * m + (1.0 - beta_1_t) * grad, + use_locking=self._use_locking) m_corr_t = m_t / (1.0 - beta_1_power) - v_t = v.assign(beta_2_t * v + - (1.0 - beta_2_t) * tf.square(grad), - use_locking=self._use_locking) + v_t = v.assign( + beta_2_t * v + (1.0 - beta_2_t) * tf.square(grad), + use_locking=self._use_locking) if self.amsgrad: vhat = self.get_slot(var, 'vhat') - vhat_t = vhat.assign(tf.maximum(vhat, v_t), - use_locking=self._use_locking) + vhat_t = vhat.assign( + tf.maximum(vhat, v_t), use_locking=self._use_locking) v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power)) else: vhat_t = None v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power)) - r_t = tf.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * - (sma_t - 2.0) / (sma_inf - 2.0) * - sma_inf / sma_t) + r_t = tf.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * (sma_t - 2.0) / + (sma_inf - 2.0) * sma_inf / sma_t) sma_threshold = self._get_hyper('sma_threshold', var_dtype) - var_t = tf.where( - sma_t >= sma_threshold, - r_t * m_corr_t / (v_corr_t + epsilon_t), - m_corr_t) + var_t = tf.where(sma_t >= sma_threshold, + r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t) if self._initial_weight_decay > 0.0: var_t += self._get_hyper('weight_decay', var_dtype) * var - var_update = var.assign_sub(lr_t * var_t, use_locking=self._use_locking) + var_update = var.assign_sub( + lr_t * var_t, use_locking=self._use_locking) updates = [var_update + m_t, v_t] if self.amsgrad: @@ -226,13 +228,13 @@ def _resource_apply_sparse(self, grad, var, indices): lr_t = tf.where( local_step <= warmup_steps, lr_t * (local_step / warmup_steps), - lr_t + decay_rate * tf.minimum( - local_step - warmup_steps, - decay_steps), + lr_t + decay_rate * tf.minimum(local_step - warmup_steps, + decay_steps), ) sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0 - sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power) + sma_t = sma_inf - 2.0 * local_step * beta_2_power / ( + 1.0 - beta_2_power) m = self.get_slot(var, 'm') m_scaled_g_values = grad * (1 - beta_1_t) @@ -249,31 +251,26 @@ def _resource_apply_sparse(self, grad, var, indices): if self.amsgrad: vhat = self.get_slot(var, 'vhat') - vhat_t = vhat.assign(tf.maximum(vhat, v_t), - use_locking=self._use_locking) + vhat_t = vhat.assign( + tf.maximum(vhat, v_t), use_locking=self._use_locking) v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power)) else: vhat_t = None v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power)) - r_t = tf.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * - (sma_t - 2.0) / (sma_inf - 2.0) * - sma_inf / sma_t) + r_t = tf.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * (sma_t - 2.0) / + (sma_inf - 2.0) * sma_inf / sma_t) sma_threshold = self._get_hyper('sma_threshold', var_dtype) - var_t = tf.where( - sma_t >= sma_threshold, - r_t * m_corr_t / (v_corr_t + epsilon_t), - m_corr_t) + var_t = tf.where(sma_t >= sma_threshold, + r_t * m_corr_t / 
(v_corr_t + epsilon_t), m_corr_t) if self._initial_weight_decay > 0.0: var_t += self._get_hyper('weight_decay', var_dtype) * var with tf.control_dependencies([var_t]): var_update = self._resource_scatter_add( - var, - indices, - tf.gather(-lr_t * var_t, indices)) + var, indices, tf.gather(-lr_t * var_t, indices)) updates = [var_update, m_t, v_t] if self.amsgrad: @@ -283,17 +280,27 @@ def _resource_apply_sparse(self, grad, var, indices): def get_config(self): config = super(RectifiedAdam, self).get_config() config.update({ - 'learning_rate': self._serialize_hyperparameter('learning_rate'), - 'beta_1': self._serialize_hyperparameter('beta_1'), - 'beta_2': self._serialize_hyperparameter('beta_2'), - 'decay': self._serialize_hyperparameter('decay'), - 'weight_decay': self._serialize_hyperparameter('weight_decay'), - 'sma_threshold': self._serialize_hyperparameter('sma_threshold'), - 'epsilon': self.epsilon, - 'amsgrad': self.amsgrad, - 'total_steps': self._serialize_hyperparameter('total_steps'), + 'learning_rate': + self._serialize_hyperparameter('learning_rate'), + 'beta_1': + self._serialize_hyperparameter('beta_1'), + 'beta_2': + self._serialize_hyperparameter('beta_2'), + 'decay': + self._serialize_hyperparameter('decay'), + 'weight_decay': + self._serialize_hyperparameter('weight_decay'), + 'sma_threshold': + self._serialize_hyperparameter('sma_threshold'), + 'epsilon': + self.epsilon, + 'amsgrad': + self.amsgrad, + 'total_steps': + self._serialize_hyperparameter('total_steps'), 'warmup_proportion': - self._serialize_hyperparameter('warmup_proportion'), - 'min_lr': self._serialize_hyperparameter('min_lr'), + self._serialize_hyperparameter('warmup_proportion'), + 'min_lr': + self._serialize_hyperparameter('min_lr'), }) return config diff --git a/tensorflow_addons/optimizers/rectified_adam_test.py b/tensorflow_addons/optimizers/rectified_adam_test.py index 4308688fd1..b8c2216fcc 100644 --- a/tensorflow_addons/optimizers/rectified_adam_test.py +++ b/tensorflow_addons/optimizers/rectified_adam_test.py @@ -26,7 +26,6 @@ @test_utils.run_all_in_graph_and_eager_modes class RectifiedAdamTest(tf.test.TestCase): - def run_dense_sample(self, iterations, expected, optimizer): var_0 = tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32) var_1 = tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32) @@ -53,15 +52,9 @@ def run_sparse_sample(self, iterations, expected, optimizer): var_1 = tf.Variable([3.0, 4.0]) grad_0 = tf.IndexedSlices( - tf.constant([0.1]), - tf.constant([0]), - tf.constant([2]) - ) + tf.constant([0.1]), tf.constant([0]), tf.constant([2])) grad_1 = tf.IndexedSlices( - tf.constant([0.04]), - tf.constant([1]), - tf.constant([2]) - ) + tf.constant([0.04]), tf.constant([1]), tf.constant([2])) grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) From f2da716caa70597a8818630692ecc9cf06b36c40 Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Mon, 16 Sep 2019 22:58:35 +0800 Subject: [PATCH 11/13] Use floordiv instead of mod in Lookahead --- tensorflow_addons/optimizers/lookahead.py | 8 +++++--- tensorflow_addons/optimizers/rectified_adam_test.py | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 2df7db756c..9dca155346 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -108,11 +108,13 @@ def _init_op(self, var): def _look_ahead_op(self, var): var_dtype = var.dtype.base_dtype slow_var = 
self.get_slot(var, 'slow') - local_step = tf.cast(self.iterations + 1, var_dtype) - sync_period = self._get_hyper('sync_period', local_step.dtype) + local_step = tf.cast(self.iterations + 1, tf.dtypes.int64) + sync_period = self._get_hyper('sync_period', tf.dtypes.int64) slow_step_size = self._get_hyper('slow_step_size', var_dtype) step_back = slow_var + slow_step_size * (var - slow_var) - sync_cond = tf.equal(local_step % sync_period, 0) + sync_cond = tf.equal( + tf.math.floordiv(local_step, sync_period) * sync_period, + local_step) with tf.control_dependencies([step_back]): slow_update = slow_var.assign( tf.where( diff --git a/tensorflow_addons/optimizers/rectified_adam_test.py b/tensorflow_addons/optimizers/rectified_adam_test.py index b8c2216fcc..e67bbc7be3 100644 --- a/tensorflow_addons/optimizers/rectified_adam_test.py +++ b/tensorflow_addons/optimizers/rectified_adam_test.py @@ -44,8 +44,8 @@ def run_dense_sample(self, iterations, expected, optimizer): for _ in range(iterations): self.evaluate(update) - self.assertAllClose(var_0.read_value(), expected[0], atol=1e-4) - self.assertAllClose(var_1.read_value(), expected[1], atol=1e-4) + self.assertAllClose(var_0.read_value(), expected[0], atol=2e-4) + self.assertAllClose(var_1.read_value(), expected[1], atol=2e-4) def run_sparse_sample(self, iterations, expected, optimizer): var_0 = tf.Variable([1.0, 2.0]) @@ -67,8 +67,8 @@ def run_sparse_sample(self, iterations, expected, optimizer): for _ in range(iterations): self.evaluate(update) - self.assertAllClose(var_0.read_value(), expected[0], atol=1e-4) - self.assertAllClose(var_1.read_value(), expected[1], atol=1e-4) + self.assertAllClose(var_0.read_value(), expected[0], atol=2e-4) + self.assertAllClose(var_1.read_value(), expected[1], atol=2e-4) def test_dense_sample(self): # Expected values are obtained from the official implementation From 8149ac957e330c1024750797db3e63ba27633267 Mon Sep 17 00:00:00 2001 From: Zhao HG <853842+CyberZHG@users.noreply.github.com> Date: Wed, 18 Sep 2019 00:50:07 +0800 Subject: [PATCH 12/13] Fix docstring for Lookahead --- tensorflow_addons/optimizers/lookahead.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 9dca155346..405d413059 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -51,8 +51,8 @@ def __init__(self, r"""Wrap optimizer with the lookahead mechanism. Args: - optimizer: A Tensor or a floating point value. - The learning rate. + optimizer: The original optimizer that will be used to compute + and apply the gradients. sync_period: An integer. The synchronization period of lookahead. Enable lookahead mechanism by setting it with a positive value. slow_step_size: A floating point value. From b4cea8189724f88663637daa1df82d035d97cba4 Mon Sep 17 00:00:00 2001 From: Zhao HG <853842+CyberZHG@users.noreply.github.com> Date: Wed, 18 Sep 2019 07:52:06 +0800 Subject: [PATCH 13/13] Fix docstring for Lookahead --- tensorflow_addons/optimizers/lookahead.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 405d413059..7dbb71c747 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -58,7 +58,7 @@ def __init__(self, slow_step_size: A floating point value. The ratio for updating the slow weights. 
name: Optional name for the operations created when applying - gradients. Defaults to "RectifiedAdam". + gradients. Defaults to "Lookahead". **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip gradients by value, `decay` is