diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD index 00a31f327e..87599d8654 100644 --- a/tensorflow_addons/optimizers/BUILD +++ b/tensorflow_addons/optimizers/BUILD @@ -12,6 +12,7 @@ py_library( "lazy_adam.py", "lookahead.py", "moving_average.py", + "novograd.py", "rectified_adam.py", "stochastic_weight_averaging.py", "weight_decay_optimizers.py", @@ -106,6 +107,18 @@ py_test( ], ) +py_test( + name = "novograd_test", + size = "small", + srcs = [ + "novograd_test.py", + ], + main = "novograd_test.py", + deps = [ + ":optimizers", + ], +) + py_test( name = "rectified_adam_test", size = "small", diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md index c73e49b2ed..a540be4bd5 100644 --- a/tensorflow_addons/optimizers/README.md +++ b/tensorflow_addons/optimizers/README.md @@ -9,6 +9,7 @@ | lazy_adam | Saishruthi Swaminathan | saishruthi.tn@gmail.com | | lookahead | Zhao Hanguang | cyberzhg@gmail.com | | moving_average | Dheeraj R. Reddy | dheeraj98reddy@gmail.com | +| novograd | Shreyash Patodia | patodiashreyash32@gmail.com | | rectified_adam | Zhao Hanguang | cyberzhg@gmail.com | | stochastic_weight_averaging | Shreyash Patodia | patodiashreyash32@gmail.com | | weight_decay_optimizers | Phil Jund | ijund.phil@googlemail.com | @@ -25,6 +26,7 @@ | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 | | lookahead | Lookahead | https://arxiv.org/abs/1907.08610v1 | | moving_average | MovingAverage | | +| novograd | NovoGrad | https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html | | rectified_adam | RectifiedAdam | https://arxiv.org/pdf/1908.03265v1.pdf | | stochastic_weight_averaging | SWA | https://arxiv.org/abs/1803.05407.pdf | | weight_decay_optimizers | SGDW, AdamW, extend_with_decoupled_weight_decay | https://arxiv.org/pdf/1711.05101.pdf | diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py index 2deaf5ee66..a4f49a0ea1 100644 --- a/tensorflow_addons/optimizers/__init__.py +++ b/tensorflow_addons/optimizers/__init__.py @@ -31,6 +31,7 @@ from tensorflow_addons.optimizers.lazy_adam import LazyAdam from tensorflow_addons.optimizers.lookahead import Lookahead from tensorflow_addons.optimizers.moving_average import MovingAverage +from tensorflow_addons.optimizers.novograd import NovoGrad from tensorflow_addons.optimizers.rectified_adam import RectifiedAdam from tensorflow_addons.optimizers.stochastic_weight_averaging import SWA from tensorflow_addons.optimizers.weight_decay_optimizers import AdamW diff --git a/tensorflow_addons/optimizers/novograd.py b/tensorflow_addons/optimizers/novograd.py new file mode 100644 index 0000000000..238faf678f --- /dev/null +++ b/tensorflow_addons/optimizers/novograd.py @@ -0,0 +1,246 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""NovoGrad for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+# TODO: Find public API alternatives to these
+from tensorflow.python.training import training_ops
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class NovoGrad(tf.keras.optimizers.Optimizer):
+    """The NovoGrad optimizer was first proposed in [Stochastic Gradient
+    Methods with Layer-wise Adaptive Moments for Training of Deep
+    Networks](https://arxiv.org/pdf/1905.11286.pdf).
+
+    NovoGrad is a first-order SGD-based algorithm, which computes second
+    moments per layer instead of per weight as in Adam. Compared to Adam,
+    NovoGrad takes less memory, and has been found to be more numerically
+    stable. More specifically, we compute the following (for more
+    information on the computation please refer to this
+    [link](https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html)):
+
+    Second order moment = exponential moving average of layer-wise square
+    of grads:
+        v_t <-- beta_2 * v_{t-1} + (1 - beta_2) * (g_t)^2
+    First order moment in one of four modes:
+        1. moment of grads normalized by v_t:
+            m_t <- beta_1 * m_{t-1} + [g_t / (sqrt(v_t) + epsilon)]
+        2. moment similar to Adam: exponential moving average of grads
+           normalized by v_t (set grad_averaging = True to use this):
+            m_t <- beta_1 * m_{t-1} +
+                   [(1 - beta_1) * (g_t / (sqrt(v_t) + epsilon))]
+        3. weight decay adds a w_d term after grads are rescaled by
+           1/sqrt(v_t) (set weight_decay > 0 to use this):
+            m_t <- beta_1 * m_{t-1} +
+                   [(g_t / (sqrt(v_t) + epsilon)) + (w_d * w_{t-1})]
+        4. weight decay + exponential moving average from Adam:
+            m_t <- beta_1 * m_{t-1} +
+                   [(1 - beta_1) * ((g_t / (sqrt(v_t) + epsilon)) +
+                   (w_d * w_{t-1}))]
+    Weight update:
+        w_t <- w_{t-1} - lr_t * m_t
+
+    Example of usage:
+    ```python
+    opt = tfa.optimizers.NovoGrad(
+        lr=1e-3,
+        beta_1=0.9,
+        beta_2=0.999,
+        weight_decay=0.001,
+        grad_averaging=False
+    )
+    ```
+    """
+
+    def __init__(self,
+                 learning_rate=0.001,
+                 beta_1=0.9,
+                 beta_2=0.999,
+                 epsilon=1e-7,
+                 weight_decay=0.0,
+                 grad_averaging=False,
+                 amsgrad=False,
+                 name='NovoGrad',
+                 **kwargs):
+        r"""Construct a new NovoGrad optimizer.
+
+        Args:
+            learning_rate: A `Tensor`, a floating point value, or a schedule
+                that is a
+                `tf.keras.optimizers.schedules.LearningRateSchedule`.
+                The learning rate.
+            beta_1: A float value or a constant float tensor.
+                The exponential decay rate for the 1st moment estimates.
+            beta_2: A float value or a constant float tensor.
+                The exponential decay rate for the 2nd moment estimates.
+            epsilon: A small constant for numerical stability.
+            weight_decay: A floating point value. Weight decay for each param.
+            grad_averaging: determines whether to use Adam style exponential
+                moving averaging for the first order moments.
+            amsgrad: boolean. Whether to apply the AMSGrad variant, which
+                keeps a running maximum of the second order moments.
+            name: Optional name for the operations created when applying
+                gradients. Defaults to "NovoGrad".
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
+                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients
+                by norm; `clipvalue` is clip gradients by value; `decay` is
+                included for backward compatibility to allow time inverse
+                decay of learning rate. `lr` is included for backward
+                compatibility, recommended to use `learning_rate` instead.
+ """ + super(NovoGrad, self).__init__(name, **kwargs) + if weight_decay < 0.0: + raise ValueError('Weight decay rate cannot be negative') + self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) + self._set_hyper('decay', self._initial_decay) + self._set_hyper('beta_1', beta_1) + self._set_hyper('beta_2', beta_2) + self._set_hyper('weight_decay', weight_decay) + self._set_hyper('grad_averaging', grad_averaging) + self.amsgrad = amsgrad + self.epsilon = epsilon or tf.keras.backend.epsilon() + + def _create_slots(self, var_list): + # Create slots for the first and second moments. + # Separate for-loops to respect the ordering of slot variables from v1. + for var in var_list: + self.add_slot(var=var, slot_name='m', initializer='zeros') + for var in var_list: + self.add_slot( + var=var, + slot_name='v', + initializer=tf.zeros(shape=[], dtype=var.dtype)) + if self.amsgrad: + for var in var_list: + self.add_slot(var, 'vhat') + + def _prepare_local(self, var_device, var_dtype, apply_state): + super(NovoGrad, self)._prepare_local(var_device, var_dtype, + apply_state) + beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype)) + beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype)) + apply_state[(var_device, var_dtype)].update( + dict( + epsilon=tf.convert_to_tensor(self.epsilon, var_dtype), + beta_1_t=beta_1_t, + beta_2_t=beta_2_t, + one_minus_beta_2_t=1 - beta_2_t, + one_minus_beta_1_t=1 - beta_1_t, + )) + + def set_weights(self, weights): + params = self.weights + # If the weights are generated by Keras V1 optimizer, it includes vhats + # even without amsgrad, i.e, V1 optimizer has 3x + 1 variables, while V2 + # optimizer has 2x + 1 variables. Filter vhats out for compatibility. + num_vars = int((len(params) - 1) / 2) + if len(weights) == 3 * num_vars + 1: + weights = weights[:len(params)] + super(NovoGrad, self).set_weights(weights) + + def _resource_apply_dense(self, grad, var, apply_state=None): + var_device, var_dtype = var.device, var.dtype.base_dtype + coefficients = ((apply_state or {}).get((var_device, var_dtype)) + or self._fallback_apply_state(var_device, var_dtype)) + weight_decay = self._get_hyper('weight_decay') + grad_averaging = self._get_hyper('grad_averaging') + + v = self.get_slot(var, 'v') + g_2 = tf.reduce_sum(tf.square(tf.cast(grad, tf.float32))) + v_t = tf.cond( + tf.equal(self.iterations, + 0), lambda: g_2, lambda: v * coefficients['beta_2_t'] + + g_2 * coefficients['one_minus_beta_2_t']) + v_t = v.assign(v_t, use_locking=self._use_locking) + + if self.amsgrad: + vhat = self.get_slot(var, 'vhat') + vhat_t = vhat.assign( + tf.maximum(vhat, v_t), use_locking=self._use_locking) + grad = grad / (tf.sqrt(vhat_t) + self.epsilon) + else: + grad = grad / (tf.sqrt(v_t) + self.epsilon) + grad = tf.cond( + tf.greater(weight_decay, + 0), lambda: grad + weight_decay * var, lambda: grad) + grad = tf.cond( + tf.logical_and(grad_averaging, tf.not_equal(self.iterations, 0)), + lambda: grad * coefficients['one_minus_beta_1_t'], lambda: grad) + m = self.get_slot(var, 'm') + return training_ops.resource_apply_keras_momentum( + var.handle, + m.handle, + coefficients['lr_t'], + grad, + coefficients['beta_1_t'], + use_locking=self._use_locking, + use_nesterov=False) + + def _resource_apply_sparse(self, grad, var, indices, apply_state=None): + var_device, var_dtype = var.device, var.dtype.base_dtype + coefficients = ((apply_state or {}).get((var_device, var_dtype)) + or self._fallback_apply_state(var_device, var_dtype)) + weight_decay = self._get_hyper('weight_decay') + 
grad_averaging = self._get_hyper('grad_averaging') + + v = self.get_slot(var, 'v') + g_2 = tf.reduce_sum(tf.square(tf.cast(grad, tf.float32))) + # v is just a scalar and does not need to involve sparse tensors. + v_t = tf.cond( + tf.equal(self.iterations, + 0), lambda: g_2, lambda: v * coefficients['beta_2_t'] + + g_2 * coefficients['one_minus_beta_2_t']) + v_t = v.assign(v_t, use_locking=self._use_locking) + + if self.amsgrad: + vhat = self.get_slot(var, 'vhat') + vhat_t = vhat.assign( + tf.maximum(vhat, v_t), use_locking=self._use_locking) + grad = grad / (tf.sqrt(vhat_t) + self.epsilon) + else: + grad = grad / (tf.sqrt(v_t) + self.epsilon) + grad = tf.cond( + tf.greater(weight_decay, + 0), lambda: grad + weight_decay * var, lambda: grad) + grad = tf.cond( + tf.logical_and(grad_averaging, tf.not_equal(self.iterations, 0)), + lambda: grad * coefficients['one_minus_beta_1_t'], lambda: grad) + m = self.get_slot(var, 'm') + return training_ops.resource_sparse_apply_keras_momentum( + var.handle, + m.handle, + coefficients['lr_t'], + tf.gather(grad, indices), + indices, + coefficients['beta_1_t'], + use_locking=self._use_locking, + use_nesterov=False) + + def get_config(self): + config = super(NovoGrad, self).get_config() + config.update({ + 'learning_rate': + self._serialize_hyperparameter('learning_rate'), + 'beta_1': + self._serialize_hyperparameter('beta_1'), + 'beta_2': + self._serialize_hyperparameter('beta_2'), + 'epsilon': + self.epsilon, + 'weight_decay': + self._serialize_hyperparameter('weight_decay'), + 'grad_averaging': + self._serialize_hyperparameter('grad_averaging'), + }) + return config diff --git a/tensorflow_addons/optimizers/novograd_test.py b/tensorflow_addons/optimizers/novograd_test.py new file mode 100644 index 0000000000..7f2009158d --- /dev/null +++ b/tensorflow_addons/optimizers/novograd_test.py @@ -0,0 +1,152 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for NovoGrad Optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow_addons.optimizers import NovoGrad +from tensorflow_addons.utils import test_utils + + +@test_utils.run_all_in_graph_and_eager_modes +class NovoGradTest(tf.test.TestCase): + def run_dense_sample(self, iterations, expected, optimizer): + var_0 = tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32) + var_1 = tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32) + + grad_0 = tf.constant([0.1, 0.2], dtype=tf.dtypes.float32) + grad_1 = tf.constant([0.3, 0.4], dtype=tf.dtypes.float32) + + grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) + + if tf.executing_eagerly(): + for _ in range(iterations): + optimizer.apply_gradients(grads_and_vars) + else: + update = optimizer.apply_gradients(grads_and_vars) + self.evaluate(tf.compat.v1.global_variables_initializer()) + for _ in range(iterations): + self.evaluate(update) + + self.assertAllClose(var_0.read_value(), expected[0], atol=2e-4) + self.assertAllClose(var_1.read_value(), expected[1], atol=2e-4) + + def run_sparse_sample(self, iterations, expected, optimizer): + var_0 = tf.Variable([1.0, 2.0]) + var_1 = tf.Variable([3.0, 4.0]) + + grad_0 = tf.IndexedSlices( + tf.constant([0.1, 0.2]), tf.constant([0, 1]), tf.constant([2])) + grad_1 = tf.IndexedSlices( + tf.constant([0.3, 0.4]), tf.constant([0, 1]), tf.constant([2])) + + grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) + + if tf.executing_eagerly(): + for _ in range(iterations): + optimizer.apply_gradients(grads_and_vars) + else: + update = optimizer.apply_gradients(grads_and_vars) + self.evaluate(tf.compat.v1.global_variables_initializer()) + for _ in range(iterations): + self.evaluate(update) + + self.assertAllClose(var_0.read_value(), expected[0], atol=2e-4) + self.assertAllClose(var_1.read_value(), expected[1], atol=2e-4) + + def test_dense_sample(self): + self.run_dense_sample( + iterations=1, + expected=[[0.9552786425, 1.9105572849], + [2.9400000012, 3.9200000016]], + optimizer=NovoGrad(lr=0.1, epsilon=1e-8), + ) + + def test_sparse_sample(self): + self.run_sparse_sample( + iterations=1, + expected=[[0.9552786425, 1.9105572849], + [2.9400000012, 3.9200000016]], + optimizer=NovoGrad(lr=0.1, epsilon=1e-8), + ) + + def test_dense_sample_with_weight_decay(self): + self.run_dense_sample( + iterations=1, + expected=[[0.945278642, 1.8905572849], + [2.9100000012, 3.8800000016]], + optimizer=NovoGrad(lr=0.1, weight_decay=0.1, epsilon=1e-8), + ) + + def test_sparse_sample_with_weight_decay(self): + self.run_sparse_sample( + iterations=1, + expected=[[0.945278642, 1.8905572849], + [2.9100000012, 3.8800000016]], + optimizer=NovoGrad(lr=0.1, weight_decay=0.1, epsilon=1e-8), + ) + + def test_dense_sample_with_grad_averaging(self): + self.run_dense_sample( + iterations=2, + expected=[[0.9105572849, 1.8211145698], + [2.8800000024, 3.8400000032]], + optimizer=NovoGrad(lr=0.1, grad_averaging=True, epsilon=1e-8), + ) + + def test_sparse_sample_with_grad_averaging(self): + self.run_sparse_sample( + iterations=2, + expected=[[0.9105572849, 1.8211145698], + [2.8800000024, 3.8400000032]], + optimizer=NovoGrad(lr=0.1, grad_averaging=True, epsilon=1e-8), + ) + + def test_fit_simple_linear_model(self): + np.random.seed(0x2020) + tf.random.set_seed(0x2020) + + x = np.random.standard_normal((100000, 3)) + w = 
np.random.standard_normal((3, 1)) + y = np.dot(x, w) + np.random.standard_normal((100000, 1)) * 1e-5 + + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Dense(input_shape=(3,), units=1)) + model.compile(NovoGrad(), loss='mse') + + model.fit(x, y, epochs=10) + + x = np.random.standard_normal((100, 3)) + y = np.dot(x, w) + predicted = model.predict(x) + + max_abs_diff = np.max(np.abs(predicted - y)) + self.assertLess(max_abs_diff, 1e-2) + + def test_get_config(self): + opt = NovoGrad(lr=1e-4, weight_decay=0.0, grad_averaging=False) + config = opt.get_config() + self.assertEqual(config['learning_rate'], 1e-4) + self.assertEqual(config['weight_decay'], 0.0) + self.assertEqual(config['grad_averaging'], False) + + +if __name__ == '__main__': + tf.test.main()
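
For reviewers who want to sanity-check the expected values in `test_dense_sample`, here is a minimal standalone NumPy sketch (not part of the patch) of the first update applied to `var_0`, assuming `lr=0.1`, `epsilon=1e-8`, no weight decay and no grad averaging; the variable names are illustrative only.

```python
import numpy as np

# First (var, grad) pair from run_dense_sample.
var = np.array([1.0, 2.0])
grad = np.array([0.1, 0.2])
lr, epsilon = 0.1, 1e-8

# Layer-wise second moment: on the first step (iterations == 0) v_t is
# simply the squared sum of the gradient entries for this variable.
v_t = np.sum(grad ** 2)                      # 0.05
grad_hat = grad / (np.sqrt(v_t) + epsilon)   # ~[0.44721360, 0.89442719]

# The momentum slot starts at zero, and resource_apply_keras_momentum
# folds the learning rate into the accumulator:
#   m_t = beta_1 * m_{t-1} - lr * grad_hat;  var_t = var_{t-1} + m_t
m_t = 0.0 - lr * grad_hat
var = var + m_t
print(var)  # ~[0.95527864, 1.91055728], cf. expected[0] in test_dense_sample
```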
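
Separately, since the constructor docstring advertises that `learning_rate` also accepts a `tf.keras.optimizers.schedules.LearningRateSchedule`, a hedged usage sketch (not part of the patch; the schedule values and model are placeholders) could look like this:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Placeholder schedule and model, shown only to illustrate the API surface.
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2, decay_steps=1000, decay_rate=0.96)

opt = tfa.optimizers.NovoGrad(
    learning_rate=schedule,  # float or LearningRateSchedule
    beta_1=0.9,
    beta_2=0.999,
    weight_decay=1e-3,       # enables the decoupled weight-decay term
    grad_averaging=True)     # Adam-style moving average of the first moment

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer=opt, loss='mse')
```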