diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD
index 0fcb8088c5..a2da66c776 100644
--- a/tensorflow_addons/optimizers/BUILD
+++ b/tensorflow_addons/optimizers/BUILD
@@ -7,6 +7,7 @@ py_library(
     srcs = [
         "__init__.py",
         "lazy_adam.py",
+        "moving_average.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -26,3 +27,16 @@ py_test(
         ":optimizers",
     ],
 )
+
+py_test(
+    name = "moving_average_test",
+    size = "small",
+    srcs = [
+        "moving_average_test.py",
+    ],
+    main = "moving_average_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":optimizers",
+    ],
+)
diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md
index 0331e8c55c..8804ebd69f 100644
--- a/tensorflow_addons/optimizers/README.md
+++ b/tensorflow_addons/optimizers/README.md
@@ -4,11 +4,13 @@
 | Submodule | Maintainers | Contact Info |
 |:---------- |:------------- |:--------------|
 | lazy_adam | SIG-Addons | addons@tensorflow.org |
+| moving_average | Dheeraj R. Reddy | dheeraj98reddy@gmail.com |
 
 ## Components
 | Submodule | Optimizer | Reference |
 |:----------------------- |:---------------------- |:---------|
 | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
+| moving_average | MovingAverage | |
 
 ## Contribution Guidelines
diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py
index 543774e8c7..79bbcf04f5 100644
--- a/tensorflow_addons/optimizers/__init__.py
+++ b/tensorflow_addons/optimizers/__init__.py
@@ -19,3 +19,4 @@ from __future__ import print_function
 
 from tensorflow_addons.optimizers.lazy_adam import LazyAdam
+from tensorflow_addons.optimizers.moving_average import MovingAverage
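A note before the new file: `MovingAverage` builds on `tf.train.ExponentialMovingAverage`, whose update rule is `shadow = decay * shadow + (1 - decay) * value`; when `num_updates` is given, TensorFlow additionally caps the decay at `min(decay, (1 + num_updates) / (10 + num_updates))` early in training. Here is a minimal pure-Python sketch of that rule, with illustrative names only:

```python
# Sketch of the update rule applied by tf.train.ExponentialMovingAverage,
# which the MovingAverage wrapper below delegates to. The names here are
# illustrative; `decay` corresponds to the wrapper's `average_decay` argument.
def ema_update(shadow, value, decay, num_updates=None):
    if num_updates is not None:
        # With num_updates set, TF caps the decay early in training.
        decay = min(decay, (1.0 + num_updates) / (10.0 + num_updates))
    return decay * shadow + (1.0 - decay) * value
```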
diff --git a/tensorflow_addons/optimizers/moving_average.py b/tensorflow_addons/optimizers/moving_average.py
new file mode 100644
index 0000000000..4321f89e75
--- /dev/null
+++ b/tensorflow_addons/optimizers/moving_average.py
@@ -0,0 +1,134 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+
+
+@keras_utils.register_keras_custom_object
+class MovingAverage(tf.keras.optimizers.Optimizer):
+    """Optimizer that computes a moving average of the variables.
+
+    Empirically it has been found that using the moving average of the trained
+    parameters of a deep network is better than using its trained parameters
+    directly. This optimizer allows you to compute this moving average and swap
+    the variables at save time so that any code outside of the training loop
+    will use the averaged values by default instead of the original ones.
+
+    Example of usage:
+
+    ```python
+    opt = tf.keras.optimizers.SGD(learning_rate)
+    opt = tfa.optimizers.MovingAverage(opt)
+    ```
+    """
+
+    def __init__(self,
+                 optimizer,
+                 average_decay=0.1,
+                 num_updates=None,
+                 sequential_update=True,
+                 name="MovingAverage",
+                 **kwargs):
+
+        super(MovingAverage, self).__init__(name, **kwargs)
+
+        if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
+            raise TypeError(
+                "optimizer is not an instance of tf.keras.optimizers.Optimizer")
+
+        if num_updates is not None and not isinstance(num_updates, int):
+            raise TypeError("num_updates must be None or of integer type")
+
+        if not isinstance(sequential_update, bool):
+            raise TypeError("sequential_update must be of bool type")
+
+        self._optimizer = optimizer
+
+        with tf.name_scope(name):
+            self._ema = tf.train.ExponentialMovingAverage(
+                average_decay, num_updates=num_updates)
+
+        self._set_hyper("average_decay", average_decay)
+        self._num_updates = num_updates
+        self._sequential_update = sequential_update
+        self._init = True
+
+    def apply_gradients(self, grads_and_vars, name=None):
+        var_list = [v for (_, v) in grads_and_vars]
+
+        if tf.executing_eagerly() and self._init:
+            # This is to ensure that var_list is registered initially.
+            self._ema.apply(var_list)
+            self._init = False
+
+        train_op = self._optimizer.apply_gradients(grads_and_vars, name=name)
+
+        if self._sequential_update:
+            with tf.control_dependencies([train_op]):
+                ma_op = self._ema.apply(var_list)
+        else:
+            ma_op = self._ema.apply(var_list)
+
+        return tf.group(train_op, ma_op, name="train_with_avg")
+
+    def get_config(self):
+        config = {
+            'average_decay': self._serialize_hyperparameter('average_decay'),
+            'num_updates': self._num_updates,
+            'sequential_update': self._sequential_update
+        }
+        base_config = self._optimizer.get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def assign_average_vars(self, var_list):
+        """Update variables in var_list with the running mean of the variables.
+
+        Example:
+        ```python
+        model = tf.keras.Sequential([...])
+        opt = tfa.optimizers.MovingAverage(
+            tf.keras.optimizers.SGD(lr=2.0), 0.5)
+
+        model.compile(opt, ...)
+        model.fit(x, y, ...)
+
+        # Update the weights to their mean before saving
+        opt.assign_average_vars(model.variables)
+
+        model.save('model.h5')
+        ```
+        """
+        assign = tf.group([v.assign(self._ema.average(v)) for v in var_list])
+        return assign
+
+    @property
+    def weights(self):
+        return self._optimizer.weights
+
+    def _resource_apply_dense(self, grad, var):
+        return self._optimizer._resource_apply_dense(grad, var)  # pylint: disable=protected-access
+
+    def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
+        return self._optimizer._resource_apply_sparse_duplicate_indices(  # pylint: disable=protected-access
+            grad, var, indices)
+
+    def _resource_apply_sparse(self, grad, var, indices):
+        return self._optimizer._resource_apply_sparse(grad, var, indices)  # pylint: disable=protected-access
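Putting the pieces together, here is a minimal end-to-end sketch of how the new optimizer is meant to be used, expanding the docstring examples above. It assumes TF2 eager execution; the toy data, model, and hyperparameters are illustrative only — only `MovingAverage` and `assign_average_vars` come from this diff:

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# Toy regression data; purely illustrative.
x = np.random.rand(100, 1).astype('float32')
y = 3.0 * x + 1.0

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])

# Wrap a plain SGD optimizer so a moving average of the weights is kept
# alongside the raw weights during training.
opt = tfa.optimizers.MovingAverage(
    tf.keras.optimizers.SGD(learning_rate=0.1), average_decay=0.5)

model.compile(optimizer=opt, loss='mse')
model.fit(x, y, epochs=5, verbose=0)

# Swap the raw weights for their moving averages before saving, so the
# saved model uses the (typically better-generalizing) averaged values.
opt.assign_average_vars(model.variables)
model.save('model.h5')
```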
diff --git a/tensorflow_addons/optimizers/moving_average_test.py b/tensorflow_addons/optimizers/moving_average_test.py
new file mode 100644
index 0000000000..681703449a
--- /dev/null
+++ b/tensorflow_addons/optimizers/moving_average_test.py
@@ -0,0 +1,134 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MovingAverage optimizers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow_addons.optimizers import MovingAverage
+from tensorflow_addons.utils import test_utils
+
+
+class MovingAverageTest(tf.test.TestCase):
+    @test_utils.run_in_graph_and_eager_modes
+    def test_run(self):
+        for sequential_update in [True, False]:
+            var0 = tf.Variable([1.0, 2.0])
+            var1 = tf.Variable([3.0, 4.0])
+
+            grads0 = tf.constant([0.1, 0.1])
+            grads1 = tf.constant([0.01, 0.01])
+
+            grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
+
+            opt = MovingAverage(
+                tf.keras.optimizers.SGD(lr=2.0),
+                average_decay=0.5,
+                sequential_update=sequential_update)
+
+            if not tf.executing_eagerly():
+                update = opt.apply_gradients(grads_and_vars)
+                self.evaluate(tf.compat.v1.global_variables_initializer())
+                self.evaluate(update)
+                self.evaluate(update)
+            else:
+                opt.apply_gradients(grads_and_vars)
+                opt.apply_gradients(grads_and_vars)
+
+            self.assertAllClose(var0.read_value(), [0.6, 1.6])
+            self.assertAllClose(var1.read_value(), [2.96, 3.96])
+
+            ema_var0 = opt._ema.average(var0)  # pylint: disable=protected-access
+            ema_var1 = opt._ema.average(var1)  # pylint: disable=protected-access
+
+            if sequential_update:
+                self.assertAllClose(ema_var0.read_value(), [0.75, 1.75])
+                self.assertAllClose(ema_var1.read_value(), [2.975, 3.975])
+
+            assign = opt.assign_average_vars([var0, var1])
+            self.evaluate(assign)
+
+            if sequential_update:
+                self.assertAllClose(var0.read_value(), [0.75, 1.75])
+                self.assertAllClose(var1.read_value(), [2.975, 3.975])
+
+            perturb = tf.group([
+                var0.assign_add([1.0, 1.0]),
+                var1.assign_add([2.0, 2.0]),
+                ema_var0.assign_add([3.0, 3.0]),
+                ema_var1.assign_add([4.0, 4.0])
+            ])
+            self.evaluate(perturb)
+
+            if sequential_update:
+                self.assertAllClose(var0.read_value(), [1.75, 2.75])
+                self.assertAllClose(var1.read_value(), [4.975, 5.975])
+                self.assertAllClose(ema_var0.read_value(), [3.75, 4.75])
+                self.assertAllClose(ema_var1.read_value(), [6.975, 7.975])
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_opt_failure(self):
+        base_opt = None
+        for sequential_update in [True, False]:
+            with self.assertRaises(TypeError):
+                MovingAverage(base_opt, 0.5, sequential_update)
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_model_weights_update(self):
+        grad = tf.Variable([[0.1]])
+        model = tf.keras.Sequential([
+            tf.keras.layers.Dense(
+                1,
+                kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
+                use_bias=False)
+        ])
+        model.build(input_shape=[1, 1])
+
+        opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), 0.5)
+        update = opt.apply_gradients(list(zip([grad], model.variables)))
+
+        self.evaluate(tf.compat.v1.global_variables_initializer())
+        self.evaluate(update)
+        self.assertAllClose(model.variables[0].read_value(), [[0.8]])
+
+        mean_update = opt.assign_average_vars(model.variables)
+        self.evaluate(mean_update)
+        self.assertAllClose(model.variables[0].read_value(), [[0.9]])
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_config(self):
+        sgd_opt = tf.keras.optimizers.SGD(
+            lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
+        opt = MovingAverage(
+            sgd_opt,
+            average_decay=0.5,
+            num_updates=100,
+            sequential_update=False)
+        config = opt.get_config()
+
+        self.assertEqual(config['average_decay'], 0.5)
+        self.assertEqual(config['decay'], 0.1)
+        self.assertEqual(config['learning_rate'], 2.0)
+        self.assertEqual(config['momentum'], 0.3)
+        self.assertEqual(config['name'], 'SGD')
+        self.assertEqual(config['nesterov'], True)
+        self.assertEqual(config['num_updates'], 100)
+        self.assertEqual(config['sequential_update'], False)
+
+
+if __name__ == '__main__':
+    tf.test.main()
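As a cross-check of the constants asserted in `test_run`, the same numbers fall out of plain Python arithmetic: an SGD step is `var -= lr * grad`, followed by the EMA update described earlier; the shadow variable starts at the variable's initial value, matching how the eager branch of `apply_gradients` registers `var_list` up front. A minimal sketch for the first component of `var0`:

```python
# Reproduces the expected values asserted in test_run for var0[0].
lr, decay = 2.0, 0.5
var, grad = 1.0, 0.1
shadow = var            # EMA shadow starts at the variable's initial value

for _ in range(2):      # two apply_gradients calls
    var -= lr * grad                              # SGD: 1.0 -> 0.8 -> 0.6
    shadow = decay * shadow + (1 - decay) * var   # EMA: 1.0 -> 0.9 -> 0.75

print(var)     # 0.6  (matches assertAllClose(var0, [0.6, 1.6]))
print(shadow)  # 0.75 (matches assertAllClose(ema_var0, [0.75, 1.75]))
```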