From 47c9654f44872f22937e6d2c40825a4268f9f6c0 Mon Sep 17 00:00:00 2001
From: Dheeraj Rajaram Reddy <dheeraj98reddy@gmail.com>
Date: Sun, 28 Apr 2019 00:10:58 +0530
Subject: [PATCH 1/6] Implement MovingAverage optimizer

* Port MovingAverageOptimizer from tf.contrib.opt
* Inherits base Keras optimizer_v2
* `swapping_saver` replaced with `assign_average_vars`
* Update test cases for TF2.X
* Update docs
---
 tensorflow_addons/optimizers/BUILD       |   1 +
 tensorflow_addons/optimizers/README.md   |   2 +
 tensorflow_addons/optimizers/__init__.py |   1 +
 .../optimizers/moving_average.py         | 127 ++++++++++++++++++
 .../optimizers/moving_average_test.py    | 118 ++++++++++++++++
 5 files changed, 249 insertions(+)
 create mode 100644 tensorflow_addons/optimizers/moving_average.py
 create mode 100644 tensorflow_addons/optimizers/moving_average_test.py

diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD
index 0fcb8088c5..ea2b959356 100644
--- a/tensorflow_addons/optimizers/BUILD
+++ b/tensorflow_addons/optimizers/BUILD
@@ -7,6 +7,7 @@ py_library(
     srcs = [
         "__init__.py",
         "lazy_adam.py",
+        "moving_average.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md
index 0331e8c55c..8804ebd69f 100644
--- a/tensorflow_addons/optimizers/README.md
+++ b/tensorflow_addons/optimizers/README.md
@@ -4,11 +4,13 @@
 | Submodule | Maintainers | Contact Info |
 |:---------- |:------------- |:--------------|
 | lazy_adam | SIG-Addons | addons@tensorflow.org |
+| moving_average | Dheeraj R. Reddy | dheeraj98reddy@gmail.com |
 
 ## Components
 | Submodule | Optimizer | Reference |
 |:----------------------- |:---------------------- |:---------|
 | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
+| moving_average | MovingAverage | |
 
 ## Contribution Guidelines
diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py
index 543774e8c7..79bbcf04f5 100644
--- a/tensorflow_addons/optimizers/__init__.py
+++ b/tensorflow_addons/optimizers/__init__.py
@@ -19,3 +19,4 @@
 from __future__ import print_function
 
 from tensorflow_addons.optimizers.lazy_adam import LazyAdam
+from tensorflow_addons.optimizers.moving_average import MovingAverage
diff --git a/tensorflow_addons/optimizers/moving_average.py b/tensorflow_addons/optimizers/moving_average.py
new file mode 100644
index 0000000000..36c4cad8e1
--- /dev/null
+++ b/tensorflow_addons/optimizers/moving_average.py
@@ -0,0 +1,127 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+
+
+@keras_utils.register_keras_custom_object
+class MovingAverage(tf.keras.optimizers.Optimizer):
+    """Optimizer that computes a moving average of the variables.
+
+    Empirically it has been found that using the moving average of the trained
+    parameters of a deep network is better than using its trained parameters
+    directly. This optimizer allows you to compute this moving average and swap
+    the variables at save time so that any code outside of the training loop
+    will use by default the average values instead of the original ones.
+
+    Example of usage:
+
+    ```python
+    opt = tf.keras.optimizers.SGD(learning_rate)
+    opt = tfa.optimizers.MovingAverage(opt)
+
+    ```
+
+    """
+
+    def __init__(self,
+                 optimizer,
+                 average_decay=0.1,
+                 num_updates=None,
+                 seq_update=True,
+                 name="MovingAverage",
+                 **kwargs):
+
+        super(MovingAverage, self).__init__(name, **kwargs)
+
+        if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
+            raise TypeError(
+                "optimzer is not an object of tf.keras.optimizers.Optimizer")
+
+        self._optimizer = optimizer
+
+        with tf.keras.backend.name_scope(self.__class__.__name__):
+            self._ema = tf.train.ExponentialMovingAverage(
+                average_decay, num_updates=num_updates)
+
+        self._average_decay = average_decay
+        self._num_updates = num_updates
+        self._seq_update = seq_update
+
+    def _create_slots(self, var_list):
+        self._optimizer._create_slots(var_list)  # pylint: disable=protected-access
+
+    def _resource_apply_dense(self, grad, var):
+        return self._optimizer._resource_apply_dense(grad, var)  # pylint: disable=protected-access
+
+    def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
+        return self._optimizer._resource_apply_sparse_duplicate_indices(  # pylint: disable=protected-access
+            grad, var, indices)
+
+    def _resource_apply_sparse(self, grad, var, indices):
+        return self._optimizer._resource_apply_sparse(grad, var, indices)  # pylint: disable=protected-access
+
+    def apply_gradients(self, grads_and_vars, name=None):
+        # pop = tf.print(grads_and_vars)
+        train_op = self._optimizer.apply_gradients(grads_and_vars, name=name)
+        var_list = [v for (_, v) in grads_and_vars]
+
+        if self._seq_update:
+            with tf.control_dependencies([train_op]):
+                ma_op = self._ema.apply(var_list)
+        else:
+            ma_op = self._ema.apply(var_list)
+
+        return tf.group(train_op, ma_op, name="train_with_avg")
+
+    def get_config(self):
+        config = {
+            'average_decay': self._average_decay,
+            'num_updates': self._num_updates,
+            'seq_update': self._seq_update
+        }
+        base_config = self._optimizer.get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def assign_average_vars(self, var_list):
+        """Update variables in var_list with the running mean of the variables.
+
+        Example:
+        ```python
+        model = tf.keras.Sequential([...])
+        opt = tfa.optimizers.MovingAverage(
+            tf.keras.optimizers.SGD(lr=2.0), 0.5)
+
+        model.compile(opt, ...)
+        model.fit(x, y, ...)
+
+        # Update the weights to their mean before saving
+        opt.assign_average_vars(model.variables)
+
+        model.save('model.h5')
+
+        ```
+        """
+        assign = tf.group([v.assign(self._ema.average(v)) for v in var_list])
+        return assign
+
+    @property
+    def weights(self):
+        return self._optimizer.weights
diff --git a/tensorflow_addons/optimizers/moving_average_test.py b/tensorflow_addons/optimizers/moving_average_test.py
new file mode 100644
index 0000000000..d839b02faf
--- /dev/null
+++ b/tensorflow_addons/optimizers/moving_average_test.py
@@ -0,0 +1,118 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MovingAverage optimizers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+import tensorflow as tf
+
+import moving_average
+from tensorflow_addons.utils import test_utils
+
+
+class MovingAverageTest(tf.test.TestCase):
+    @test_utils.run_deprecated_v1
+    def test_run(self):
+        for seq_update in [True, False]:
+            orig_var0 = [1.0, 2.0]
+            orig_var1 = [3.0, 4.0]
+
+            var0 = tf.Variable(orig_var0)
+            var1 = tf.Variable(orig_var1)
+
+            grads0 = tf.constant([0.1, 0.1])
+            grads1 = tf.constant([0.01, 0.01])
+
+            opt = moving_average.MovingAverage(
+                tf.keras.optimizers.SGD(lr=2.0),
+                average_decay=0.5,
+                seq_update=seq_update)
+
+            update = opt.apply_gradients(
+                list(six.moves.zip([grads0, grads1], [var0, var1])))
+
+            ema_var0 = opt._ema.average(var0)  # pylint: disable=protected-access
+            ema_var1 = opt._ema.average(var1)  # pylint: disable=protected-access
+
+            self.evaluate(tf.compat.v1.global_variables_initializer())
+            self.evaluate(update)
+
+            self.assertAllClose(var0.read_value(), [0.8, 1.8])
+            self.assertAllClose(var1.read_value(), [2.98, 3.98])
+
+            if seq_update:
+                self.assertAllClose(ema_var0.read_value(), [0.9, 1.9])
+                self.assertAllClose(ema_var1.read_value(), [2.99, 3.99])
+
+            assign = opt.assign_average_vars([var0, var1])
+            self.evaluate(assign)
+
+            if seq_update:
+                self.assertAllClose(self.evaluate(var0), [0.9, 1.9])
+                self.assertAllClose(self.evaluate(var1), [2.99, 3.99])
+
+            perturb = tf.group([
+                var0.assign_add([1.0, 1.0]),
+                var1.assign_add([2.0, 2.0]),
+                ema_var0.assign_add([3.0, 3.0]),
+                ema_var1.assign_add([4.0, 4.0])
+            ])
+            self.evaluate(perturb)
+
+            if seq_update:
+                self.assertAllClose(self.evaluate(var0), [1.9, 2.9])
+                self.assertAllClose(self.evaluate(var1), [4.99, 5.99])
+                self.assertAllClose(self.evaluate(ema_var0), [3.9, 4.9])
+                self.assertAllClose(self.evaluate(ema_var1), [6.99, 7.99])
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_opt_failure(self):
+        base_opt = None
+        for seq_update in [True, False]:
+            with self.assertRaises(TypeError):
+                moving_average.MovingAverage(base_opt, 0.5, seq_update)
+
+    @test_utils.run_deprecated_v1
+    def test_model_weights_update(self):
+        grad = tf.Variable([[0.1]])
+        model = tf.keras.Sequential([
+            tf.keras.layers.Dense(
+                1,
+                kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
+                use_bias=False)
+        ])
+
+        model.build(input_shape=[1, 1])
+
+        opt = moving_average.MovingAverage(
+            tf.keras.optimizers.SGD(lr=2.0), 0.5)
+
+        update = opt.apply_gradients(
+            list(six.moves.zip([grad], model.variables)))
+
+        self.evaluate(tf.compat.v1.global_variables_initializer())
+        self.evaluate(update)
+        self.assertAllClose(model.variables[0].read_value(), [[0.8]])
+
+        mean_update = opt.assign_average_vars(model.variables)
+        self.evaluate(mean_update)
+        self.assertAllClose(model.variables[0].read_value(), [[0.9]])
+
+
+if __name__ == '__main__':
+    tf.test.main()
From 25d33f40500ffd357c820a11590500fdd7507450 Mon Sep 17 00:00:00 2001
From: Dheeraj Rajaram Reddy <dheeraj98reddy@gmail.com>
Date: Sun, 28 Apr 2019 00:25:21 +0530
Subject: [PATCH 2/6] Add moving_average_test as a py_test in BUILD file

---
 tensorflow_addons/optimizers/BUILD | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD
index ea2b959356..a2da66c776 100644
--- a/tensorflow_addons/optimizers/BUILD
+++ b/tensorflow_addons/optimizers/BUILD
@@ -27,3 +27,16 @@ py_test(
         ":optimizers",
     ],
 )
+
+py_test(
+    name = "moving_average_test",
+    size = "small",
+    srcs = [
+        "moving_average_test.py",
+    ],
+    main = "moving_average_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":optimizers",
+    ],
+)

From d2420c901cf0203038a2ed9d69fd8e1a3e92a72a Mon Sep 17 00:00:00 2001
From: Dheeraj Rajaram Reddy <dheeraj98reddy@gmail.com>
Date: Mon, 29 Apr 2019 02:01:54 +0530
Subject: [PATCH 3/6] Move internal functions under external functions to
 improve readability

---
 .../optimizers/moving_average.py | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/tensorflow_addons/optimizers/moving_average.py b/tensorflow_addons/optimizers/moving_average.py
index 36c4cad8e1..0d76b69b4b 100644
--- a/tensorflow_addons/optimizers/moving_average.py
+++ b/tensorflow_addons/optimizers/moving_average.py
@@ -65,19 +65,6 @@ def __init__(self,
         self._num_updates = num_updates
         self._seq_update = seq_update
 
-    def _create_slots(self, var_list):
-        self._optimizer._create_slots(var_list)  # pylint: disable=protected-access
-
-    def _resource_apply_dense(self, grad, var):
-        return self._optimizer._resource_apply_dense(grad, var)  # pylint: disable=protected-access
-
-    def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
-        return self._optimizer._resource_apply_sparse_duplicate_indices(  # pylint: disable=protected-access
-            grad, var, indices)
-
-    def _resource_apply_sparse(self, grad, var, indices):
-        return self._optimizer._resource_apply_sparse(grad, var, indices)  # pylint: disable=protected-access
-
     def apply_gradients(self, grads_and_vars, name=None):
         # pop = tf.print(grads_and_vars)
         train_op = self._optimizer.apply_gradients(grads_and_vars, name=name)
@@ -125,3 +112,16 @@ def assign_average_vars(self, var_list):
     @property
     def weights(self):
         return self._optimizer.weights
+
+    def _create_slots(self, var_list):
+        self._optimizer._create_slots(var_list)  # pylint: disable=protected-access
+
+    def _resource_apply_dense(self, grad, var):
+        return self._optimizer._resource_apply_dense(grad, var)  # pylint: disable=protected-access
+
+    def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
+        return self._optimizer._resource_apply_sparse_duplicate_indices(  # pylint: disable=protected-access
+            grad, var, indices)
+
+    def _resource_apply_sparse(self, grad, var, indices):
+        return self._optimizer._resource_apply_sparse(grad, var, indices)  # pylint: disable=protected-access

From b8da9a6c39f698f37fc0d94623a82e08c17eec3d Mon Sep 17 00:00:00 2001
From: Dheeraj Rajaram Reddy <dheeraj98reddy@gmail.com>
Date: Mon, 29 Apr 2019 17:56:28 +0530
Subject: [PATCH 4/6] Refactor code and add test for config

* Use _set_hyper() and _get_hyper() instead of member variables for
  average_decay, num_updates and sequential_update
* Remove _create_slots() from MovingAverage
* Use _serialize_hyperparameter() in get_config()
* Replace if-else with tf.cond() to work with tensors (see the sketch below)
* Use absolute import of tensorflow_addons in moving_average_test.py
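
A note on the tf.cond() change: a hyperparameter stored with _set_hyper()
is read back with _get_hyper() as a tensor, and a plain Python `if` cannot
branch on a graph tensor's runtime value. A minimal standalone sketch of
the pattern (the names `flag` and `counter` are made up for illustration
and are not taken from the patch):

```python
import tensorflow as tf

flag = tf.constant(True)  # a bool *tensor*, like what _get_hyper() returns
counter = tf.Variable(0.0)

def true_fn():
    # branch selected when `flag` evaluates to True at runtime
    return counter.assign_add(1.0)

def false_fn():
    # fallback branch; both branches must return compatible values
    return counter.read_value()

# tf.cond() picks a branch based on the runtime value of `flag`,
# which a Python if-statement cannot do while building a graph.
result = tf.cond(flag, true_fn, false_fn)
```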
---
 .../optimizers/moving_average.py      | 40 ++++++++++--------
 .../optimizers/moving_average_test.py | 41 ++++++++++++++-----
 2 files changed, 52 insertions(+), 29 deletions(-)

diff --git a/tensorflow_addons/optimizers/moving_average.py b/tensorflow_addons/optimizers/moving_average.py
index 0d76b69b4b..08611cad50 100644
--- a/tensorflow_addons/optimizers/moving_average.py
+++ b/tensorflow_addons/optimizers/moving_average.py
@@ -38,14 +38,13 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
     opt = tfa.optimizers.MovingAverage(opt)
 
     ```
-
     """
 
     def __init__(self,
                  optimizer,
                  average_decay=0.1,
                  num_updates=None,
-                 seq_update=True,
+                 sequential_update=True,
                  name="MovingAverage",
                  **kwargs):
 
@@ -57,32 +56,41 @@ def __init__(self,
 
         self._optimizer = optimizer
 
-        with tf.keras.backend.name_scope(self.__class__.__name__):
+        # NoneType cannot be passed to _set_hyper, so we convert it to -1
+        # and vice-versa when creating the object using from_config
+        num_updates = None if num_updates == -1 else num_updates
+        with tf.name_scope(name):
             self._ema = tf.train.ExponentialMovingAverage(
                 average_decay, num_updates=num_updates)
+        num_updates = -1 if num_updates is None else num_updates
 
-        self._average_decay = average_decay
-        self._num_updates = num_updates
-        self._seq_update = seq_update
+        self._set_hyper("average_decay", average_decay)
+        self._set_hyper("num_updates", num_updates)
+        self._set_hyper("sequential_update", sequential_update)
 
     def apply_gradients(self, grads_and_vars, name=None):
-        # pop = tf.print(grads_and_vars)
         train_op = self._optimizer.apply_gradients(grads_and_vars, name=name)
         var_list = [v for (_, v) in grads_and_vars]
+        sequential_update = self._get_hyper("sequential_update", tf.bool)
 
-        if self._seq_update:
+        def true_fn():
             with tf.control_dependencies([train_op]):
-                ma_op = self._ema.apply(var_list)
-        else:
-            ma_op = self._ema.apply(var_list)
+                return self._ema.apply(var_list)
+
+        def false_fn():
+            return self._ema.apply(var_list)
 
+        ma_op = tf.cond(sequential_update, true_fn, false_fn)
         return tf.group(train_op, ma_op, name="train_with_avg")
 
     def get_config(self):
         config = {
-            'average_decay': self._average_decay,
-            'num_updates': self._num_updates,
-            'seq_update': self._seq_update
+            'average_decay':
+                self._serialize_hyperparameter('average_decay'),
+            'num_updates':
+                self._serialize_hyperparameter('num_updates'),
+            'sequential_update':
+                self._serialize_hyperparameter('sequential_update')
         }
         base_config = self._optimizer.get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -103,7 +111,6 @@ def assign_average_vars(self, var_list):
         opt.assign_average_vars(model.variables)
 
         model.save('model.h5')
-
         ```
         """
         assign = tf.group([v.assign(self._ema.average(v)) for v in var_list])
@@ -113,9 +120,6 @@ def assign_average_vars(self, var_list):
     def weights(self):
         return self._optimizer.weights
 
-    def _create_slots(self, var_list):
-        self._optimizer._create_slots(var_list)  # pylint: disable=protected-access
-
     def _resource_apply_dense(self, grad, var):
         return self._optimizer._resource_apply_dense(grad, var)  # pylint: disable=protected-access
 
diff --git a/tensorflow_addons/optimizers/moving_average_test.py b/tensorflow_addons/optimizers/moving_average_test.py
index d839b02faf..32457268f0 100644
--- a/tensorflow_addons/optimizers/moving_average_test.py
+++ b/tensorflow_addons/optimizers/moving_average_test.py
@@ -21,14 +21,14 @@
 
 import tensorflow as tf
 
-import moving_average
+from tensorflow_addons.optimizers import MovingAverage
 from tensorflow_addons.utils import test_utils
 
 
 class MovingAverageTest(tf.test.TestCase):
     @test_utils.run_deprecated_v1
     def test_run(self):
-        for seq_update in [True, False]:
+        for sequential_update in [True, False]:
             orig_var0 = [1.0, 2.0]
             orig_var1 = [3.0, 4.0]
 
@@ -38,10 +38,10 @@ def test_run(self):
             grads0 = tf.constant([0.1, 0.1])
             grads1 = tf.constant([0.01, 0.01])
 
-            opt = moving_average.MovingAverage(
+            opt = MovingAverage(
                 tf.keras.optimizers.SGD(lr=2.0),
                 average_decay=0.5,
-                seq_update=seq_update)
+                sequential_update=sequential_update)
 
             update = opt.apply_gradients(
                 list(six.moves.zip([grads0, grads1], [var0, var1])))
@@ -55,14 +55,14 @@ def test_run(self):
             self.assertAllClose(var0.read_value(), [0.8, 1.8])
             self.assertAllClose(var1.read_value(), [2.98, 3.98])
 
-            if seq_update:
+            if sequential_update:
                 self.assertAllClose(ema_var0.read_value(), [0.9, 1.9])
                 self.assertAllClose(ema_var1.read_value(), [2.99, 3.99])
 
             assign = opt.assign_average_vars([var0, var1])
             self.evaluate(assign)
 
-            if seq_update:
+            if sequential_update:
                 self.assertAllClose(self.evaluate(var0), [0.9, 1.9])
                 self.assertAllClose(self.evaluate(var1), [2.99, 3.99])
 
@@ -74,7 +74,7 @@ def test_run(self):
             ])
             self.evaluate(perturb)
 
-            if seq_update:
+            if sequential_update:
                 self.assertAllClose(self.evaluate(var0), [1.9, 2.9])
                 self.assertAllClose(self.evaluate(var1), [4.99, 5.99])
                 self.assertAllClose(self.evaluate(ema_var0), [3.9, 4.9])
@@ -83,9 +83,9 @@ def test_run(self):
     @test_utils.run_in_graph_and_eager_modes
     def test_opt_failure(self):
         base_opt = None
-        for seq_update in [True, False]:
+        for sequential_update in [True, False]:
             with self.assertRaises(TypeError):
-                moving_average.MovingAverage(base_opt, 0.5, seq_update)
+                MovingAverage(base_opt, 0.5, sequential_update)
 
     @test_utils.run_deprecated_v1
     def test_model_weights_update(self):
@@ -99,8 +99,7 @@ def test_model_weights_update(self):
 
         model.build(input_shape=[1, 1])
 
-        opt = moving_average.MovingAverage(
-            tf.keras.optimizers.SGD(lr=2.0), 0.5)
+        opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), 0.5)
 
         update = opt.apply_gradients(
             list(six.moves.zip([grad], model.variables)))
@@ -113,6 +112,26 @@ def test_model_weights_update(self):
         self.evaluate(mean_update)
         self.assertAllClose(model.variables[0].read_value(), [[0.9]])
 
+    @test_utils.run_in_graph_and_eager_modes
+    def test_config(self):
+        sgd_opt = tf.keras.optimizers.SGD(
+            lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
+        opt = MovingAverage(
+            sgd_opt,
+            average_decay=0.5,
+            num_updates=100,
+            sequential_update=False)
+        config = opt.get_config()
+
+        self.assertEqual(config['average_decay'], 0.5)
+        self.assertEqual(config['decay'], 0.1)
+        self.assertEqual(config['learning_rate'], 2.0)
+        self.assertEqual(config['momentum'], 0.3)
+        self.assertEqual(config['name'], 'SGD')
+        self.assertEqual(config['nesterov'], True)
+        self.assertEqual(config['num_updates'], 100)
+        self.assertEqual(config['sequential_update'], False)
+
 
 if __name__ == '__main__':
     tf.test.main()

From ba51cb5f949b57aaaf2637099492d98a34d3378d Mon Sep 17 00:00:00 2001
From: Dheeraj Rajaram Reddy <dheeraj98reddy@gmail.com>
Date: Mon, 29 Apr 2019 19:53:33 +0530
Subject: [PATCH 5/6] Add eager execution support to MovingAverage

* Tests modified for static and eager execution
* num_updates and sequential_update reverted back to instance variables
* Type check of num_updates and sequential_update
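
Background on the eager path this patch adds to apply_gradients():
tf.train.ExponentialMovingAverage only creates the shadow copy of a
variable the first time apply() sees it, and average() returns None until
then, which is why the first eager call registers var_list up front. A
standalone sketch of that behaviour (TF2 eager mode; the decay value 0.5
is chosen for illustration and is not code from this patch):

```python
import tensorflow as tf

ema = tf.train.ExponentialMovingAverage(decay=0.5)
var = tf.Variable([1.0, 2.0])

print(ema.average(var))  # None: no shadow variable exists yet

ema.apply([var])  # first apply() creates the shadow copy, initialized to var
print(ema.average(var).numpy())  # [1. 2.]

var.assign([0.0, 0.0])
ema.apply([var])  # later applies move the shadow: 0.5 * old + 0.5 * new
print(ema.average(var).numpy())  # [0.5 1. ]
```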
---
 .../optimizers/moving_average.py      | 43 ++++++++-------
 .../optimizers/moving_average_test.py | 53 +++++++++----------
 2 files changed, 49 insertions(+), 47 deletions(-)

diff --git a/tensorflow_addons/optimizers/moving_average.py b/tensorflow_addons/optimizers/moving_average.py
index 08611cad50..23252da8ae 100644
--- a/tensorflow_addons/optimizers/moving_average.py
+++ b/tensorflow_addons/optimizers/moving_average.py
@@ -54,43 +54,46 @@ def __init__(self,
             raise TypeError(
                 "optimzer is not an object of tf.keras.optimizers.Optimizer")
 
+        if num_updates is not None and not isinstance(num_updates, int):
+            raise TypeError("num_updates must be of integer type")
+
+        if not isinstance(sequential_update, bool):
+            raise TypeError("sequential_update must be of bool type")
+
         self._optimizer = optimizer
 
-        # NoneType cannot be passed to _set_hyper, so we convert it to -1
-        # and vice-versa when creating the object using from_config
-        num_updates = None if num_updates == -1 else num_updates
         with tf.name_scope(name):
             self._ema = tf.train.ExponentialMovingAverage(
                 average_decay, num_updates=num_updates)
-        num_updates = -1 if num_updates is None else num_updates
 
         self._set_hyper("average_decay", average_decay)
-        self._set_hyper("num_updates", num_updates)
-        self._set_hyper("sequential_update", sequential_update)
+        self._num_updates = num_updates
+        self._sequential_update = sequential_update
+        self._init = True
 
     def apply_gradients(self, grads_and_vars, name=None):
-        train_op = self._optimizer.apply_gradients(grads_and_vars, name=name)
         var_list = [v for (_, v) in grads_and_vars]
-        sequential_update = self._get_hyper("sequential_update", tf.bool)
 
-        def true_fn():
-            with tf.control_dependencies([train_op]):
-                return self._ema.apply(var_list)
+        if tf.executing_eagerly() and self._init:
+            # this is to ensure that var_list is registered initially
+            self._ema.apply(var_list)
+            self._init = False
 
-        def false_fn():
-            return self._ema.apply(var_list)
+        train_op = self._optimizer.apply_gradients(grads_and_vars, name=name)
+
+        if self._sequential_update:
+            with tf.control_dependencies([train_op]):
+                ma_op = self._ema.apply(var_list)
+        else:
+            ma_op = self._ema.apply(var_list)
 
-        ma_op = tf.cond(sequential_update, true_fn, false_fn)
         return tf.group(train_op, ma_op, name="train_with_avg")
 
     def get_config(self):
         config = {
-            'average_decay':
-                self._serialize_hyperparameter('average_decay'),
-            'num_updates':
-                self._serialize_hyperparameter('num_updates'),
-            'sequential_update':
-                self._serialize_hyperparameter('sequential_update')
+            'average_decay': self._serialize_hyperparameter('average_decay'),
+            'num_updates': self._num_updates,
+            'sequential_update': self._sequential_update
         }
         base_config = self._optimizer.get_config()
         return dict(list(base_config.items()) + list(config.items()))
diff --git a/tensorflow_addons/optimizers/moving_average_test.py b/tensorflow_addons/optimizers/moving_average_test.py
index 32457268f0..6042695734 100644
--- a/tensorflow_addons/optimizers/moving_average_test.py
+++ b/tensorflow_addons/optimizers/moving_average_test.py
@@ -26,45 +26,47 @@ class MovingAverageTest(tf.test.TestCase):
-    @test_utils.run_deprecated_v1
+    @test_utils.run_in_graph_and_eager_modes
     def test_run(self):
         for sequential_update in [True, False]:
-            orig_var0 = [1.0, 2.0]
-            orig_var1 = [3.0, 4.0]
-
-            var0 = tf.Variable(orig_var0)
-            var1 = tf.Variable(orig_var1)
+            var0 = tf.Variable([1.0, 2.0])
+            var1 = tf.Variable([3.0, 4.0])
 
             grads0 = tf.constant([0.1, 0.1])
             grads1 = tf.constant([0.01, 0.01])
 
+            grads_and_vars = zip([grads0, grads1], [var0, var1])
+
             opt = MovingAverage(
                 tf.keras.optimizers.SGD(lr=2.0),
                 average_decay=0.5,
                 sequential_update=sequential_update)
 
-            update = opt.apply_gradients(
-                list(six.moves.zip([grads0, grads1], [var0, var1])))
+            if not tf.executing_eagerly():
+                update = opt.apply_gradients(grads_and_vars)
+                self.evaluate(tf.compat.v1.global_variables_initializer())
+                self.evaluate(update)
+                self.evaluate(update)
+            else:
+                opt.apply_gradients(grads_and_vars)
+                opt.apply_gradients(grads_and_vars)
+
+            self.assertAllClose(var0.read_value(), [0.6, 1.6])
+            self.assertAllClose(var1.read_value(), [2.96, 3.96])
 
             ema_var0 = opt._ema.average(var0)  # pylint: disable=protected-access
             ema_var1 = opt._ema.average(var1)  # pylint: disable=protected-access
 
-            self.evaluate(tf.compat.v1.global_variables_initializer())
-            self.evaluate(update)
-
-            self.assertAllClose(var0.read_value(), [0.8, 1.8])
-            self.assertAllClose(var1.read_value(), [2.98, 3.98])
-
             if sequential_update:
-                self.assertAllClose(ema_var0.read_value(), [0.9, 1.9])
-                self.assertAllClose(ema_var1.read_value(), [2.99, 3.99])
+                self.assertAllClose(ema_var0.read_value(), [0.75, 1.75])
+                self.assertAllClose(ema_var1.read_value(), [2.975, 3.975])
 
             assign = opt.assign_average_vars([var0, var1])
             self.evaluate(assign)
 
             if sequential_update:
-                self.assertAllClose(self.evaluate(var0), [0.9, 1.9])
-                self.assertAllClose(self.evaluate(var1), [2.99, 3.99])
+                self.assertAllClose(var0.read_value(), [0.75, 1.75])
+                self.assertAllClose(var1.read_value(), [2.975, 3.975])
 
             perturb = tf.group([
                 var0.assign_add([1.0, 1.0]),
@@ -75,10 +77,10 @@ def test_run(self):
             self.evaluate(perturb)
 
             if sequential_update:
-                self.assertAllClose(self.evaluate(var0), [1.9, 2.9])
-                self.assertAllClose(self.evaluate(var1), [4.99, 5.99])
-                self.assertAllClose(self.evaluate(ema_var0), [3.9, 4.9])
-                self.assertAllClose(self.evaluate(ema_var1), [6.99, 7.99])
+                self.assertAllClose(var0.read_value(), [1.75, 2.75])
+                self.assertAllClose(var1.read_value(), [4.975, 5.975])
+                self.assertAllClose(ema_var0.read_value(), [3.75, 4.75])
+                self.assertAllClose(ema_var1.read_value(), [6.975, 7.975])
 
     @test_utils.run_in_graph_and_eager_modes
     def test_opt_failure(self):
@@ -87,7 +89,7 @@ def test_opt_failure(self):
             with self.assertRaises(TypeError):
                 MovingAverage(base_opt, 0.5, sequential_update)
 
-    @test_utils.run_deprecated_v1
+    @test_utils.run_in_graph_and_eager_modes
     def test_model_weights_update(self):
         grad = tf.Variable([[0.1]])
         model = tf.keras.Sequential([
@@ -96,13 +98,10 @@ def test_model_weights_update(self):
                 kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
                 use_bias=False)
         ])
-
         model.build(input_shape=[1, 1])
 
         opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), 0.5)
-
-        update = opt.apply_gradients(
-            list(six.moves.zip([grad], model.variables)))
+        update = opt.apply_gradients(zip([grad], model.variables))
 
         self.evaluate(tf.compat.v1.global_variables_initializer())
         self.evaluate(update)

From 30b12ee4796ec730bc95b8b509ba8a10bde58a1e Mon Sep 17 00:00:00 2001
From: Dheeraj Rajaram Reddy <dheeraj98reddy@gmail.com>
Date: Mon, 29 Apr 2019 21:07:26 +0530
Subject: [PATCH 6/6] Nit fixes

* Remove six import in moving_average_test
* Wrap zip objects in list to pass tests in python3
* Fix typos
---
 tensorflow_addons/optimizers/moving_average.py      | 4 ++--
 tensorflow_addons/optimizers/moving_average_test.py | 6 ++----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/tensorflow_addons/optimizers/moving_average.py b/tensorflow_addons/optimizers/moving_average.py
index 23252da8ae..4321f89e75 100644
--- a/tensorflow_addons/optimizers/moving_average.py
+++ b/tensorflow_addons/optimizers/moving_average.py
@@ -52,10 +52,10 @@ def __init__(self,
 
         if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
             raise TypeError(
-                "optimzer is not an object of tf.keras.optimizers.Optimizer")
+                "optimizer is not an object of tf.keras.optimizers.Optimizer")
 
         if num_updates is not None and not isinstance(num_updates, int):
-            raise TypeError("num_updates must be of integer type")
+            raise TypeError("num_updates must be None or of integer type")
 
         if not isinstance(sequential_update, bool):
             raise TypeError("sequential_update must be of bool type")
diff --git a/tensorflow_addons/optimizers/moving_average_test.py b/tensorflow_addons/optimizers/moving_average_test.py
index 6042695734..681703449a 100644
--- a/tensorflow_addons/optimizers/moving_average_test.py
+++ b/tensorflow_addons/optimizers/moving_average_test.py
@@ -17,8 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import six
-
 import tensorflow as tf
 
 from tensorflow_addons.optimizers import MovingAverage
@@ -35,7 +33,7 @@ def test_run(self):
 
             grads0 = tf.constant([0.1, 0.1])
             grads1 = tf.constant([0.01, 0.01])
-            grads_and_vars = zip([grads0, grads1], [var0, var1])
+            grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
 
             opt = MovingAverage(
                 tf.keras.optimizers.SGD(lr=2.0),
@@ -99,7 +99,7 @@ def test_model_weights_update(self):
         model.build(input_shape=[1, 1])
 
         opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), 0.5)
-        update = opt.apply_gradients(zip([grad], model.variables))
+        update = opt.apply_gradients(list(zip([grad], model.variables)))
 
         self.evaluate(tf.compat.v1.global_variables_initializer())
         self.evaluate(update)
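
With the full series applied, a typical end-to-end use of the optimizer
follows the pattern from the assign_average_vars docstring. The sketch
below is a minimal, assumed example (the toy arrays `x` and `y` and the
layer shapes are made up; only the MovingAverage API comes from the
patches):

```python
import numpy as np
import tensorflow as tf
from tensorflow_addons.optimizers import MovingAverage

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

# Wrap any Keras optimizer; an EMA of the weights is tracked alongside it.
opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
model.compile(optimizer=opt, loss='mse')

x = np.random.rand(16, 4).astype('float32')
y = np.random.rand(16, 1).astype('float32')
model.fit(x, y, epochs=1, verbose=0)

# Swap the averaged values into the model weights before saving.
opt.assign_average_vars(model.variables)
model.save('model.h5')
```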