From 8c8791a19c3dd0fb7103bc7adec10a30a0a70782 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Mon, 14 Jan 2019 10:23:10 +0800 Subject: [PATCH 1/4] ENH: copy LazyAdamOptimizer from tf.contrib.opt --- tensorflow_addons/optimizers/BUILD | 23 ++ .../optimizers/python/lazy_adam_optimizer.py | 113 ++++++ .../python/lazy_adam_optimizer_test.py | 369 ++++++++++++++++++ 3 files changed, 505 insertions(+) create mode 100644 tensorflow_addons/optimizers/python/lazy_adam_optimizer.py create mode 100644 tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD index 3ad427fd87..99314e24a6 100644 --- a/tensorflow_addons/optimizers/BUILD +++ b/tensorflow_addons/optimizers/BUILD @@ -1,3 +1,26 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) + +py_library( + name = "opt_py", + srcs = [ + "__init__.py", + "python/__init__.py", + "python/lazy_adam_optimizer.py", + ], + srcs_version = "PY2AND3", +) + + +py_test( + name = "lazy_adam_optimizer_test", + srcs = [ + "python/lazy_adam_optimizer_test.py" + ], + main = "python/lazy_adam_optimizer_test.py", + deps = [ + ":opt_py", + ], + srcs_version = "PY2AND3", +) diff --git a/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py b/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py new file mode 100644 index 0000000000..acf09bc935 --- /dev/null +++ b/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py @@ -0,0 +1,113 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Variant of the Adam optimizer that handles sparse updates more efficiently. + +Compared with the original Adam optimizer, the one in this file can provide a +large improvement in model training throughput for some applications. However, +it provides slightly different semantics than the original Adam algorithm, and +may lead to different empirical results. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import adam + + +class LazyAdamOptimizer(adam.AdamOptimizer): + """Variant of the Adam optimizer that handles sparse updates more efficiently. + + The original Adam algorithm maintains two moving-average accumulators for + each trainable variable; the accumulators are updated at every step. + This class provides lazier handling of gradient updates for sparse variables. 
+ It only updates moving-average accumulators for sparse variable indices that + appear in the current batch, rather than updating the accumulators for all + indices. Compared with the original Adam optimizer, it can provide large + improvements in model training throughput for some applications. However, it + provides slightly different semantics than the original Adam algorithm, and + may lead to different empirical results. + """ + + def _apply_sparse(self, grad, var): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t = state_ops.scatter_update(m, grad.indices, + beta1_t * array_ops.gather(m, grad.indices) + + (1 - beta1_t) * grad.values, + use_locking=self._use_locking) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t = state_ops.scatter_update(v, grad.indices, + beta2_t * array_ops.gather(v, grad.indices) + + (1 - beta2_t) * math_ops.square(grad.values), + use_locking=self._use_locking) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + m_t_slice = array_ops.gather(m_t, grad.indices) + v_t_slice = array_ops.gather(v_t, grad.indices) + denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t + var_update = state_ops.scatter_sub(var, grad.indices, + lr * m_t_slice / denominator_slice, + use_locking=self._use_locking) + return control_flow_ops.group(var_update, m_t, v_t) + + def _resource_apply_sparse(self, grad, var, indices): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad + m_update_op = resource_variable_ops.resource_scatter_update(m.handle, + indices, + m_t_slice) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = (beta2_t * array_ops.gather(v, indices) + + (1 - beta2_t) * math_ops.square(grad)) + v_update_op = resource_variable_ops.resource_scatter_update(v.handle, + indices, + v_t_slice) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) + var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, + indices, + var_slice) + + return control_flow_ops.group(var_update_op, m_update_op, v_update_op) diff --git a/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py b/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py new file mode 100644 index 0000000000..12356f9b4b --- /dev/null +++ 
b/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py @@ -0,0 +1,369 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for LazyAdamOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow_addons.optimizers.python import lazy_adam_optimizer + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamOptimizerTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters([False, True]) + @test_util.run_deprecated_v1 + def testSparse(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = lazy_adam_optimizer.LazyAdamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.evaluate(update) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters([False, True]) + @test_util.run_deprecated_v1 + def testSparseDevicePlacement(self, use_resource): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). 
+ if use_resource: + var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) + else: + var = variables.Variable([[1.0], [2.0]]) + + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0) + minimize_op = optimizer.minimize(gathered_sum) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(minimize_op) + + @parameterized.parameters([False, True]) + @test_util.run_deprecated_v1 + def testSparseRepeatedIndices(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + if use_resource: + repeated_index_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + else: + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update_opt = lazy_adam_optimizer.LazyAdamOptimizer() + repeated_update = repeated_update_opt.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) + aggregated_update_opt = lazy_adam_optimizer.LazyAdamOptimizer() + aggregated_update = aggregated_update_opt.apply_gradients( + [(grad_aggregated, aggregated_update_var)]) + self.evaluate(variables.global_variables_initializer()) + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertIsNotNone(beta1_power) + self.assertIsNotNone(beta2_power is not None) + self.assertIn(beta1_power, opt_variables) + self.assertIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + @test_util.run_deprecated_v1 + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @test_util.run_deprecated_v1 + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_optimizer.LazyAdamOptimizer() + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. 
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTwoSessions(self): + optimizer = lazy_adam_optimizer.LazyAdamOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with self.session(graph=g): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with self.session(graph=gg): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = lazy_adam_optimizer.LazyAdamOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two non-slot variables, and two unique slot variables + # for v1 and v2 respectively. + self.assertEqual(6, len(set(opt.variables()))) + + +if __name__ == "__main__": + test.main() From 1922bbc95a032f7dc1c3b583d43d4864c3b9135a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Mon, 14 Jan 2019 19:22:41 +0800 Subject: [PATCH 2/4] ENH: inherit from keras.optimizer --- .../optimizers/python/lazy_adam_optimizer.py | 64 ++--- .../python/lazy_adam_optimizer_test.py | 228 +++++++----------- 2 files changed, 105 insertions(+), 187 deletions(-) diff --git a/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py b/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py index acf09bc935..6fd8e116e2 100644 --- a/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py +++ b/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py @@ -24,15 +24,15 @@ from __future__ import division from __future__ import print_function +from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops -from tensorflow.python.training import adam -class LazyAdamOptimizer(adam.AdamOptimizer): +class LazyAdamOptimizer(adam.Adam): """Variant of the Adam optimizer that handles sparse updates more efficiently. The original Adam algorithm maintains two moving-average accumulators for @@ -44,62 +44,32 @@ class LazyAdamOptimizer(adam.AdamOptimizer): improvements in model training throughput for some applications. However, it provides slightly different semantics than the original Adam algorithm, and may lead to different empirical results. 
- """ - - def _apply_sparse(self, grad, var): - beta1_power, beta2_power = self._get_beta_accumulators() - beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) - beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) - lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) - beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) - beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) - epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) - lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) - - # \\(m := beta1 * m + (1 - beta1) * g_t\\) - m = self.get_slot(var, "m") - m_t = state_ops.scatter_update(m, grad.indices, - beta1_t * array_ops.gather(m, grad.indices) + - (1 - beta1_t) * grad.values, - use_locking=self._use_locking) - # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) - v = self.get_slot(var, "v") - v_t = state_ops.scatter_update(v, grad.indices, - beta2_t * array_ops.gather(v, grad.indices) + - (1 - beta2_t) * math_ops.square(grad.values), - use_locking=self._use_locking) - - # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) - m_t_slice = array_ops.gather(m_t, grad.indices) - v_t_slice = array_ops.gather(v_t, grad.indices) - denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t - var_update = state_ops.scatter_sub(var, grad.indices, - lr * m_t_slice / denominator_slice, - use_locking=self._use_locking) - return control_flow_ops.group(var_update, m_t, v_t) + Note, amsgrad is currently not supported and the argument can only be False. + """ def _resource_apply_sparse(self, grad, var, indices): - beta1_power, beta2_power = self._get_beta_accumulators() - beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) - beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) - lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) - beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) - beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) - epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) - lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + var_dtype = var.dtype.base_dtype + lr_t = self._decayed_lr(var_dtype) + beta_1_t = self._get_hyper('beta_1', var_dtype) + beta_2_t = self._get_hyper('beta_2', var_dtype) + local_step = math_ops.cast(self.iterations + 1, var_dtype) + beta_1_power = math_ops.pow(beta_1_t, local_step) + beta_2_power = math_ops.pow(beta_2_t, local_step) + epsilon_t = self._get_hyper('epsilon', var_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)) # \\(m := beta1 * m + (1 - beta1) * g_t\\) m = self.get_slot(var, "m") - m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad + m_t_slice = beta_1_t * array_ops.gather(m, indices) + (1 - beta_1_t) * grad m_update_op = resource_variable_ops.resource_scatter_update(m.handle, indices, m_t_slice) # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) v = self.get_slot(var, "v") - v_t_slice = (beta2_t * array_ops.gather(v, indices) + - (1 - beta2_t) * math_ops.square(grad)) + v_t_slice = (beta_2_t * array_ops.gather(v, indices) + + (1 - beta_2_t) * math_ops.square(grad)) v_update_op = resource_variable_ops.resource_scatter_update(v.handle, indices, v_t_slice) @@ -110,4 +80,4 @@ def _resource_apply_sparse(self, grad, var, indices): indices, var_slice) - return control_flow_ops.group(var_update_op, m_update_op, v_update_op) + return control_flow_ops.group(*[var_update_op, m_update_op, v_update_op]) diff --git 
a/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py b/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py index 12356f9b4b..e83f9fa2dc 100644 --- a/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py +++ b/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function -from absl.testing import parameterized import numpy as np from tensorflow.python.eager import context @@ -39,106 +38,97 @@ def adam_update_numpy(param, t, m, v, - alpha=0.001, + lr=0.001, beta1=0.9, beta2=0.999, - epsilon=1e-8): - alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + epsilon=1e-7): + lr_t = lr * np.sqrt(1 - beta2**(t + 1)) / (1 - beta1**(t + 1)) m_t = beta1 * m + (1 - beta1) * g_t v_t = beta2 * v + (1 - beta2) * g_t * g_t - param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + param_t = param - lr_t * m_t / (np.sqrt(v_t) + epsilon) return param_t, m_t, v_t -class AdamOptimizerTest(test.TestCase, parameterized.TestCase): +def get_beta_accumulators(opt, dtype): + local_step = math_ops.cast(opt.iterations + 1, dtype) + beta_1_t = math_ops.cast(opt._get_hyper("beta_1"), dtype) + beta_1_power = math_ops.pow(beta_1_t, local_step) + beta_2_t = math_ops.cast(opt._get_hyper("beta_2"), dtype) + beta_2_power = math_ops.pow(beta_2_t, local_step) + return (beta_1_power, beta_2_power) + + +class AdamOptimizerTest(test.TestCase): - @parameterized.parameters([False, True]) @test_util.run_deprecated_v1 - def testSparse(self, use_resource): + def testSparse(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - - if use_resource: - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - else: - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) - - grads0_np_indices = np.array([0, 1], dtype=np.int32) + var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0_np_indices = np.array([0, 2], dtype=np.int32) grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np), - constant_op.constant(grads0_np_indices), constant_op.constant([2])) - grads1_np_indices = np.array([0, 1], dtype=np.int32) + constant_op.constant(grads0_np[grads0_np_indices]), + constant_op.constant(grads0_np_indices), constant_op.constant([3])) + grads1_np_indices = np.array([0, 2], dtype=np.int32) grads1 = ops.IndexedSlices( - constant_op.constant(grads1_np), - constant_op.constant(grads1_np_indices), constant_op.constant([2])) + constant_op.constant(grads1_np[grads1_np_indices]), + constant_op.constant(grads1_np_indices), constant_op.constant([3])) opt = lazy_adam_optimizer.LazyAdamOptimizer() update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) self.evaluate(variables.global_variables_initializer()) # Fetch params to validate 
initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllClose([1.0, 1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 3.0, 4.0], self.evaluate(var1)) + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) # Run 3 steps of Adam - for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + for t in range(3): + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta_1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta_2_power)) self.evaluate(update) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - @parameterized.parameters([False, True]) @test_util.run_deprecated_v1 - def testSparseDevicePlacement(self, use_resource): + def testSparseDevicePlacement(self): for index_dtype in [dtypes.int32, dtypes.int64]: with self.cached_session(force_gpu=test.is_gpu_available()): # If a GPU is available, tests that all optimizer ops can be placed on # it (i.e. they have GPU kernels). - if use_resource: - var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) - else: - var = variables.Variable([[1.0], [2.0]]) - + var = variables.Variable([[1.0], [2.0]]) indices = constant_op.constant([0, 1], dtype=index_dtype) - gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + g_sum = lambda: math_ops.reduce_sum(array_ops.gather(var, indices)) # pylint: disable=cell-var-from-loop optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0) - minimize_op = optimizer.minimize(gathered_sum) + minimize_op = optimizer.minimize(g_sum, var_list=[var]) self.evaluate(variables.global_variables_initializer()) self.evaluate(minimize_op) - @parameterized.parameters([False, True]) @test_util.run_deprecated_v1 - def testSparseRepeatedIndices(self, use_resource): + def testSparseRepeatedIndices(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): - if use_resource: - repeated_index_update_var = resource_variable_ops.ResourceVariable( - [[1.0], [2.0]], dtype=dtype) - aggregated_update_var = resource_variable_ops.ResourceVariable( - [[1.0], [2.0]], dtype=dtype) - else: - repeated_index_update_var = variables.Variable( - [[1.0], [2.0]], dtype=dtype) - aggregated_update_var = variables.Variable( - [[1.0], [2.0]], dtype=dtype) - + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) grad_repeated_index = ops.IndexedSlices( constant_op.constant( [0.1, 0.1], shape=[2, 1], dtype=dtype), @@ -164,7 +154,7 @@ def testSparseRepeatedIndices(self, use_resource): self.assertAllClose(aggregated_update_var.eval(), repeated_index_update_var.eval()) - def doTestBasic(self, use_resource=False, use_callable_params=False): + def doTestBasic(self, use_callable_params=False): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): with self.session(graph=ops.Graph()): # Initialize variables for numpy implementation. 
@@ -174,14 +164,10 @@ def doTestBasic(self, use_resource=False, use_callable_params=False): var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - if use_resource: - var0 = resource_variable_ops.ResourceVariable( - var0_np, name="var0_%d" % i) - var1 = resource_variable_ops.ResourceVariable( - var1_np, name="var1_%d" % i) - else: - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) grads0 = constant_op.constant(grads0_np) grads1 = constant_op.constant(grads1_np) @@ -196,58 +182,41 @@ def doTestBasic(self, use_resource=False, use_callable_params=False): epsilon = epsilon() opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - opt_variables = opt.variables() - beta1_power, beta2_power = opt._get_beta_accumulators() - self.assertIsNotNone(beta1_power) - self.assertIsNotNone(beta2_power is not None) - self.assertIn(beta1_power, opt_variables) - self.assertIn(beta2_power, opt_variables) - if not context.executing_eagerly(): - with ops.Graph().as_default(): - # Shouldn't return non-slot variables from other graphs. - self.assertEqual(0, len(opt.variables())) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values self.assertAllClose([1.0, 2.0], self.evaluate(var0)) self.assertAllClose([3.0, 4.0], self.evaluate(var1)) - beta1_power, beta2_power = opt._get_beta_accumulators() - # Run 3 steps of Adam - for t in range(1, 4): + for t in range(3): + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta_1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta_2_power)) if not context.executing_eagerly(): self.evaluate(update) - elif t > 1: + else: opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - self.assertAllCloseAccordingToType(0.9**(t + 1), - self.evaluate(beta1_power)) - self.assertAllCloseAccordingToType(0.999**(t + 1), - self.evaluate(beta2_power)) - var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - if use_resource: - self.assertEqual("var0_%d/Adam:0" % (i,), - opt.get_slot(var=var0, name="m").name) - - def testBasic(self): - with self.cached_session(): - self.doTestBasic(use_resource=False) + self.assertEqual("var0_%d/m:0" % (i,), + opt.get_slot(var0, "m").name) @test_util.run_in_graph_and_eager_modes(reset_test=True) def testResourceBasic(self): - self.doTestBasic(use_resource=True) + self.doTestBasic() def testBasicCallableParams(self): with context.eager_mode(): - self.doTestBasic(use_resource=True, use_callable_params=True) + self.doTestBasic(use_callable_params=True) @test_util.run_deprecated_v1 def testTensorLearningRate(self): @@ -272,20 +241,21 @@ def testTensorLearningRate(self): self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([3.0, 4.0], var1.eval()) - beta1_power, beta2_power = opt._get_beta_accumulators() - + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) # 
Run 3 steps of Adam - for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) - update.run() + for t in range(3): + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta_1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta_2_power)) + self.evaluate(update) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) @test_util.run_deprecated_v1 def testSharing(self): @@ -307,16 +277,18 @@ def testSharing(self): update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) self.evaluate(variables.global_variables_initializer()) - beta1_power, beta2_power = opt._get_beta_accumulators() + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of intertwined Adam1 and Adam2. - for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + for t in range(3): + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta_1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta_2_power)) if t % 2 == 0: update1.run() else: @@ -326,43 +298,19 @@ def testSharing(self): var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) - - def testTwoSessions(self): - optimizer = lazy_adam_optimizer.LazyAdamOptimizer() - - with context.eager_mode(): - var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") - grads0 = constant_op.constant(np.array([0.1, 0.1])) - optimizer.apply_gradients([(grads0, var0)]) - - g = ops.Graph() - with g.as_default(): - with self.session(graph=g): - var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") - grads0 = constant_op.constant(np.array([0.1, 0.1])) - optimizer.apply_gradients([(grads0, var0)]) - - gg = ops.Graph() - with gg.as_default(): - with self.session(graph=gg): - var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") - grads0 = constant_op.constant(np.array([0.1, 0.1])) - - # If the optimizer saves any state not keyed by graph the following line - # fails. - optimizer.apply_gradients([(grads0, var0)]) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testSlotsUniqueEager(self): with context.eager_mode(): v1 = resource_variable_ops.ResourceVariable(1.) v2 = resource_variable_ops.ResourceVariable(1.) opt = lazy_adam_optimizer.LazyAdamOptimizer(1.) - opt.minimize(lambda: v1 + v2) - # There should be two non-slot variables, and two unique slot variables - # for v1 and v2 respectively. 
- self.assertEqual(6, len(set(opt.variables()))) + opt.minimize(lambda: v1 + v2, var_list=[v1, v2]) + # There should be iteration, and two unique slot variables for v1 and v2. + self.assertEqual(5, len(set(opt.variables()))) + self.assertEqual( + self.evaluate(opt.variables()[0]), self.evaluate(opt.iterations)) if __name__ == "__main__": From 7837d310c6ac8512a1b74b90a8a06f892833af18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Tue, 15 Jan 2019 09:46:23 +0800 Subject: [PATCH 3/4] CLN: use PEP8 code style --- .../optimizers/python/lazy_adam_optimizer.py | 90 +-- .../python/lazy_adam_optimizer_test.py | 551 +++++++++--------- 2 files changed, 334 insertions(+), 307 deletions(-) diff --git a/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py b/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py index 6fd8e116e2..337effa48c 100644 --- a/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py +++ b/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py @@ -14,10 +14,10 @@ # ============================================================================== """Variant of the Adam optimizer that handles sparse updates more efficiently. -Compared with the original Adam optimizer, the one in this file can provide a -large improvement in model training throughput for some applications. However, -it provides slightly different semantics than the original Adam algorithm, and -may lead to different empirical results. +Compared with the original Adam optimizer, the one in this file can +provide a large improvement in model training throughput for some +applications. However, it provides slightly different semantics than the +original Adam algorithm, and may lead to different empirical results. """ from __future__ import absolute_import @@ -33,51 +33,51 @@ class LazyAdamOptimizer(adam.Adam): - """Variant of the Adam optimizer that handles sparse updates more efficiently. + """Variant of the Adam optimizer that handles sparse updates more + efficiently. - The original Adam algorithm maintains two moving-average accumulators for - each trainable variable; the accumulators are updated at every step. - This class provides lazier handling of gradient updates for sparse variables. - It only updates moving-average accumulators for sparse variable indices that - appear in the current batch, rather than updating the accumulators for all - indices. Compared with the original Adam optimizer, it can provide large - improvements in model training throughput for some applications. However, it - provides slightly different semantics than the original Adam algorithm, and - may lead to different empirical results. + The original Adam algorithm maintains two moving-average accumulators for + each trainable variable; the accumulators are updated at every step. + This class provides lazier handling of gradient updates for sparse variables. + It only updates moving-average accumulators for sparse variable indices that + appear in the current batch, rather than updating the accumulators for all + indices. Compared with the original Adam optimizer, it can provide large + improvements in model training throughput for some applications. However, it + provides slightly different semantics than the original Adam algorithm, and + may lead to different empirical results. - Note, amsgrad is currently not supported and the argument can only be False. - """ + Note, amsgrad is currently not supported and the argument can only be False. 
+ """ - def _resource_apply_sparse(self, grad, var, indices): - var_dtype = var.dtype.base_dtype - lr_t = self._decayed_lr(var_dtype) - beta_1_t = self._get_hyper('beta_1', var_dtype) - beta_2_t = self._get_hyper('beta_2', var_dtype) - local_step = math_ops.cast(self.iterations + 1, var_dtype) - beta_1_power = math_ops.pow(beta_1_t, local_step) - beta_2_power = math_ops.pow(beta_2_t, local_step) - epsilon_t = self._get_hyper('epsilon', var_dtype) - lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)) + def _resource_apply_sparse(self, grad, var, indices): + var_dtype = var.dtype.base_dtype + lr_t = self._decayed_lr(var_dtype) + beta_1_t = self._get_hyper('beta_1', var_dtype) + beta_2_t = self._get_hyper('beta_2', var_dtype) + local_step = math_ops.cast(self.iterations + 1, var_dtype) + beta_1_power = math_ops.pow(beta_1_t, local_step) + beta_2_power = math_ops.pow(beta_2_t, local_step) + epsilon_t = self._get_hyper('epsilon', var_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)) - # \\(m := beta1 * m + (1 - beta1) * g_t\\) - m = self.get_slot(var, "m") - m_t_slice = beta_1_t * array_ops.gather(m, indices) + (1 - beta_1_t) * grad - m_update_op = resource_variable_ops.resource_scatter_update(m.handle, - indices, - m_t_slice) + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta_1_t * array_ops.gather( + m, indices) + (1 - beta_1_t) * grad + m_update_op = resource_variable_ops.resource_scatter_update( + m.handle, indices, m_t_slice) - # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) - v = self.get_slot(var, "v") - v_t_slice = (beta_2_t * array_ops.gather(v, indices) + - (1 - beta_2_t) * math_ops.square(grad)) - v_update_op = resource_variable_ops.resource_scatter_update(v.handle, - indices, - v_t_slice) + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = (beta_2_t * array_ops.gather(v, indices) + + (1 - beta_2_t) * math_ops.square(grad)) + v_update_op = resource_variable_ops.resource_scatter_update( + v.handle, indices, v_t_slice) - # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) - var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) - var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, - indices, - var_slice) + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) + var_update_op = resource_variable_ops.resource_scatter_sub( + var.handle, indices, var_slice) - return control_flow_ops.group(*[var_update_op, m_update_op, v_update_op]) + return control_flow_ops.group( + *[var_update_op, m_update_op, v_update_op]) diff --git a/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py b/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py index e83f9fa2dc..8cbdae1794 100644 --- a/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py +++ b/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py @@ -42,276 +42,303 @@ def adam_update_numpy(param, beta1=0.9, beta2=0.999, epsilon=1e-7): - lr_t = lr * np.sqrt(1 - beta2**(t + 1)) / (1 - beta1**(t + 1)) + lr_t = lr * np.sqrt(1 - beta2**(t + 1)) / (1 - beta1**(t + 1)) - m_t = beta1 * m + (1 - beta1) * g_t - v_t = beta2 * v + (1 - beta2) * g_t * g_t + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t - param_t = param - lr_t * m_t / (np.sqrt(v_t) + epsilon) - return param_t, m_t, v_t + param_t = param - lr_t * m_t / (np.sqrt(v_t) + 
epsilon) + return param_t, m_t, v_t def get_beta_accumulators(opt, dtype): - local_step = math_ops.cast(opt.iterations + 1, dtype) - beta_1_t = math_ops.cast(opt._get_hyper("beta_1"), dtype) - beta_1_power = math_ops.pow(beta_1_t, local_step) - beta_2_t = math_ops.cast(opt._get_hyper("beta_2"), dtype) - beta_2_power = math_ops.pow(beta_2_t, local_step) - return (beta_1_power, beta_2_power) + local_step = math_ops.cast(opt.iterations + 1, dtype) + beta_1_t = math_ops.cast(opt._get_hyper("beta_1"), dtype) + beta_1_power = math_ops.pow(beta_1_t, local_step) + beta_2_t = math_ops.cast(opt._get_hyper("beta_2"), dtype) + beta_2_power = math_ops.pow(beta_2_t, local_step) + return (beta_1_power, beta_2_power) class AdamOptimizerTest(test.TestCase): - - @test_util.run_deprecated_v1 - def testSparse(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - # Initialize variables for numpy implementation. - m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype) - - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - grads0_np_indices = np.array([0, 2], dtype=np.int32) - grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np[grads0_np_indices]), - constant_op.constant(grads0_np_indices), constant_op.constant([3])) - grads1_np_indices = np.array([0, 2], dtype=np.int32) - grads1 = ops.IndexedSlices( - constant_op.constant(grads1_np[grads1_np_indices]), - constant_op.constant(grads1_np_indices), constant_op.constant([3])) - opt = lazy_adam_optimizer.LazyAdamOptimizer() - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - self.evaluate(variables.global_variables_initializer()) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 1.0, 2.0], self.evaluate(var0)) - self.assertAllClose([3.0, 3.0, 4.0], self.evaluate(var1)) - - beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) - # Run 3 steps of Adam - for t in range(3): - self.assertAllCloseAccordingToType(0.9**(t + 1), - self.evaluate(beta_1_power)) - self.assertAllCloseAccordingToType(0.999**(t + 1), - self.evaluate(beta_2_power)) - self.evaluate(update) - - var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) - - # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - - @test_util.run_deprecated_v1 - def testSparseDevicePlacement(self): - for index_dtype in [dtypes.int32, dtypes.int64]: - with self.cached_session(force_gpu=test.is_gpu_available()): - # If a GPU is available, tests that all optimizer ops can be placed on - # it (i.e. they have GPU kernels). 
- var = variables.Variable([[1.0], [2.0]]) - indices = constant_op.constant([0, 1], dtype=index_dtype) - g_sum = lambda: math_ops.reduce_sum(array_ops.gather(var, indices)) # pylint: disable=cell-var-from-loop - optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0) - minimize_op = optimizer.minimize(g_sum, var_list=[var]) - self.evaluate(variables.global_variables_initializer()) - self.evaluate(minimize_op) - - @test_util.run_deprecated_v1 - def testSparseRepeatedIndices(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - repeated_index_update_var = variables.Variable( - [[1.0], [2.0]], dtype=dtype) - aggregated_update_var = variables.Variable( - [[1.0], [2.0]], dtype=dtype) - grad_repeated_index = ops.IndexedSlices( - constant_op.constant( - [0.1, 0.1], shape=[2, 1], dtype=dtype), - constant_op.constant([1, 1]), - constant_op.constant([2, 1])) - grad_aggregated = ops.IndexedSlices( - constant_op.constant( - [0.2], shape=[1, 1], dtype=dtype), - constant_op.constant([1]), - constant_op.constant([2, 1])) - repeated_update_opt = lazy_adam_optimizer.LazyAdamOptimizer() - repeated_update = repeated_update_opt.apply_gradients( - [(grad_repeated_index, repeated_index_update_var)]) - aggregated_update_opt = lazy_adam_optimizer.LazyAdamOptimizer() - aggregated_update = aggregated_update_opt.apply_gradients( - [(grad_aggregated, aggregated_update_var)]) - self.evaluate(variables.global_variables_initializer()) - self.assertAllClose(aggregated_update_var.eval(), - repeated_index_update_var.eval()) - for _ in range(3): - repeated_update.run() - aggregated_update.run() - self.assertAllClose(aggregated_update_var.eval(), - repeated_index_update_var.eval()) - - def doTestBasic(self, use_callable_params=False): - for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - with self.session(graph=ops.Graph()): - # Initialize variables for numpy implementation. 
- m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - - var0 = resource_variable_ops.ResourceVariable( - var0_np, name="var0_%d" % i) - var1 = resource_variable_ops.ResourceVariable( - var1_np, name="var1_%d" % i) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - - learning_rate = lambda: 0.001 - beta1 = lambda: 0.9 - beta2 = lambda: 0.999 - epsilon = lambda: 1e-8 - if not use_callable_params: - learning_rate = learning_rate() - beta1 = beta1() - beta2 = beta2() - epsilon = epsilon() - - opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate) - if not context.executing_eagerly(): - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - self.evaluate(variables.global_variables_initializer()) - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], self.evaluate(var0)) - self.assertAllClose([3.0, 4.0], self.evaluate(var1)) - - # Run 3 steps of Adam - for t in range(3): - beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) - self.assertAllCloseAccordingToType(0.9**(t + 1), - self.evaluate(beta_1_power)) - self.assertAllCloseAccordingToType(0.999**(t + 1), - self.evaluate(beta_2_power)) - if not context.executing_eagerly(): - self.evaluate(update) - else: - opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - - var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) - - # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - self.assertEqual("var0_%d/m:0" % (i,), - opt.get_slot(var0, "m").name) - - @test_util.run_in_graph_and_eager_modes(reset_test=True) - def testResourceBasic(self): - self.doTestBasic() - - def testBasicCallableParams(self): - with context.eager_mode(): - self.doTestBasic(use_callable_params=True) - - @test_util.run_deprecated_v1 - def testTensorLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - # Initialize variables for numpy implementation. 
- m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001)) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - self.evaluate(variables.global_variables_initializer()) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) - # Run 3 steps of Adam - for t in range(3): - self.assertAllCloseAccordingToType(0.9**(t + 1), - self.evaluate(beta_1_power)) - self.assertAllCloseAccordingToType(0.999**(t + 1), - self.evaluate(beta_2_power)) - self.evaluate(update) - - var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) - - # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - - @test_util.run_deprecated_v1 - def testSharing(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - # Initialize variables for numpy implementation. - m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - opt = lazy_adam_optimizer.LazyAdamOptimizer() - update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - self.evaluate(variables.global_variables_initializer()) - - beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], self.evaluate(var0)) - self.assertAllClose([3.0, 4.0], self.evaluate(var1)) - - # Run 3 steps of intertwined Adam1 and Adam2. - for t in range(3): - self.assertAllCloseAccordingToType(0.9**(t + 1), - self.evaluate(beta_1_power)) - self.assertAllCloseAccordingToType(0.999**(t + 1), - self.evaluate(beta_2_power)) - if t % 2 == 0: - update1.run() - else: - update2.run() - - var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) - - # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - - def testSlotsUniqueEager(self): - with context.eager_mode(): - v1 = resource_variable_ops.ResourceVariable(1.) - v2 = resource_variable_ops.ResourceVariable(1.) - opt = lazy_adam_optimizer.LazyAdamOptimizer(1.) - opt.minimize(lambda: v1 + v2, var_list=[v1, v2]) - # There should be iteration, and two unique slot variables for v1 and v2. 
-      self.assertEqual(5, len(set(opt.variables())))
-      self.assertEqual(
-          self.evaluate(opt.variables()[0]), self.evaluate(opt.iterations))
+    @test_util.run_deprecated_v1
+    def testSparse(self):
+        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+            with self.cached_session():
+                # Initialize variables for numpy implementation.
+                m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+                var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype)
+                grads0_np = np.array([0.1, 0.0, 0.1],
+                                     dtype=dtype.as_numpy_dtype)
+                var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype)
+                grads1_np = np.array([0.01, 0.0, 0.01],
+                                     dtype=dtype.as_numpy_dtype)
+
+                var0 = resource_variable_ops.ResourceVariable(var0_np)
+                var1 = resource_variable_ops.ResourceVariable(var1_np)
+                grads0_np_indices = np.array([0, 2], dtype=np.int32)
+                grads0 = ops.IndexedSlices(
+                    constant_op.constant(grads0_np[grads0_np_indices]),
+                    constant_op.constant(grads0_np_indices),
+                    constant_op.constant([3]))
+                grads1_np_indices = np.array([0, 2], dtype=np.int32)
+                grads1 = ops.IndexedSlices(
+                    constant_op.constant(grads1_np[grads1_np_indices]),
+                    constant_op.constant(grads1_np_indices),
+                    constant_op.constant([3]))
+                opt = lazy_adam_optimizer.LazyAdamOptimizer()
+                update = opt.apply_gradients(
+                    zip([grads0, grads1], [var0, var1]))
+                self.evaluate(variables.global_variables_initializer())
+
+                # Fetch params to validate initial values
+                self.assertAllClose([1.0, 1.0, 2.0], self.evaluate(var0))
+                self.assertAllClose([3.0, 3.0, 4.0], self.evaluate(var1))
+
+                beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
+                # Run 3 steps of Adam
+                for t in range(3):
+                    self.assertAllCloseAccordingToType(
+                        0.9**(t + 1), self.evaluate(beta_1_power))
+                    self.assertAllCloseAccordingToType(
+                        0.999**(t + 1), self.evaluate(beta_2_power))
+                    self.evaluate(update)
+
+                    var0_np, m0, v0 = adam_update_numpy(
+                        var0_np, grads0_np, t, m0, v0)
+                    var1_np, m1, v1 = adam_update_numpy(
+                        var1_np, grads1_np, t, m1, v1)
+
+                    # Validate updated params
+                    self.assertAllCloseAccordingToType(var0_np,
+                                                       self.evaluate(var0))
+                    self.assertAllCloseAccordingToType(var1_np,
+                                                       self.evaluate(var1))
+
+    @test_util.run_deprecated_v1
+    def testSparseDevicePlacement(self):
+        for index_dtype in [dtypes.int32, dtypes.int64]:
+            with self.cached_session(force_gpu=test.is_gpu_available()):
+                # If a GPU is available, tests that all optimizer ops can be placed on
+                # it (i.e. they have GPU kernels).
+                var = variables.Variable([[1.0], [2.0]])
+                indices = constant_op.constant([0, 1], dtype=index_dtype)
+                g_sum = lambda: math_ops.reduce_sum(array_ops.gather(var, indices))  # pylint: disable=cell-var-from-loop
+                optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0)
+                minimize_op = optimizer.minimize(g_sum, var_list=[var])
+                self.evaluate(variables.global_variables_initializer())
+                self.evaluate(minimize_op)
+
+    @test_util.run_deprecated_v1
+    def testSparseRepeatedIndices(self):
+        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+            with self.cached_session():
+                repeated_index_update_var = variables.Variable([[1.0], [2.0]],
+                                                               dtype=dtype)
+                aggregated_update_var = variables.Variable([[1.0], [2.0]],
+                                                           dtype=dtype)
+                grad_repeated_index = ops.IndexedSlices(
+                    constant_op.constant([0.1, 0.1], shape=[2, 1],
+                                         dtype=dtype),
+                    constant_op.constant([1, 1]), constant_op.constant([2, 1]))
+                grad_aggregated = ops.IndexedSlices(
+                    constant_op.constant([0.2], shape=[1, 1], dtype=dtype),
+                    constant_op.constant([1]), constant_op.constant([2, 1]))
+                repeated_update_opt = lazy_adam_optimizer.LazyAdamOptimizer()
+                repeated_update = repeated_update_opt.apply_gradients(
+                    [(grad_repeated_index, repeated_index_update_var)])
+                aggregated_update_opt = lazy_adam_optimizer.LazyAdamOptimizer()
+                aggregated_update = aggregated_update_opt.apply_gradients(
+                    [(grad_aggregated, aggregated_update_var)])
+                self.evaluate(variables.global_variables_initializer())
+                self.assertAllClose(aggregated_update_var.eval(),
+                                    repeated_index_update_var.eval())
+                for _ in range(3):
+                    repeated_update.run()
+                    aggregated_update.run()
+                    self.assertAllClose(aggregated_update_var.eval(),
+                                        repeated_index_update_var.eval())
+
+    def doTestBasic(self, use_callable_params=False):
+        for i, dtype in enumerate(
+                [dtypes.half, dtypes.float32, dtypes.float64]):
+            with self.session(graph=ops.Graph()):
+                # Initialize variables for numpy implementation.
+                m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+                var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+                grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+                var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+                grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+                var0 = resource_variable_ops.ResourceVariable(
+                    var0_np, name="var0_%d" % i)
+                var1 = resource_variable_ops.ResourceVariable(
+                    var1_np, name="var1_%d" % i)
+                grads0 = constant_op.constant(grads0_np)
+                grads1 = constant_op.constant(grads1_np)
+
+                learning_rate = lambda: 0.001
+                beta1 = lambda: 0.9
+                beta2 = lambda: 0.999
+                epsilon = lambda: 1e-8
+                if not use_callable_params:
+                    learning_rate = learning_rate()
+                    beta1 = beta1()
+                    beta2 = beta2()
+                    epsilon = epsilon()
+
+                opt = lazy_adam_optimizer.LazyAdamOptimizer(
+                    learning_rate=learning_rate)
+                if not context.executing_eagerly():
+                    update = opt.apply_gradients(
+                        zip([grads0, grads1], [var0, var1]))
+                    self.evaluate(variables.global_variables_initializer())
+                    # Fetch params to validate initial values
+                    self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+                    self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+                # Run 3 steps of Adam
+                for t in range(3):
+                    beta_1_power, beta_2_power = get_beta_accumulators(
+                        opt, dtype)
+                    self.assertAllCloseAccordingToType(
+                        0.9**(t + 1), self.evaluate(beta_1_power))
+                    self.assertAllCloseAccordingToType(
+                        0.999**(t + 1), self.evaluate(beta_2_power))
+                    if not context.executing_eagerly():
+                        self.evaluate(update)
+                    else:
+                        opt.apply_gradients(
+                            zip([grads0, grads1], [var0, var1]))
+
+                    var0_np, m0, v0 = adam_update_numpy(
+                        var0_np, grads0_np, t, m0, v0)
+                    var1_np, m1, v1 = adam_update_numpy(
+                        var1_np, grads1_np, t, m1, v1)
+
+                    # Validate updated params
+                    self.assertAllCloseAccordingToType(var0_np,
+                                                       self.evaluate(var0))
+                    self.assertAllCloseAccordingToType(var1_np,
+                                                       self.evaluate(var1))
+                    self.assertEqual("var0_%d/m:0" % (i, ),
+                                     opt.get_slot(var0, "m").name)
+
+    @test_util.run_in_graph_and_eager_modes(reset_test=True)
+    def testResourceBasic(self):
+        self.doTestBasic()
+
+    def testBasicCallableParams(self):
+        with context.eager_mode():
+            self.doTestBasic(use_callable_params=True)
+
+    @test_util.run_deprecated_v1
+    def testTensorLearningRate(self):
+        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+            with self.cached_session():
+                # Initialize variables for numpy implementation.
+                m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+                var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+                grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+                var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+                grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+                var0 = variables.Variable(var0_np)
+                var1 = variables.Variable(var1_np)
+                grads0 = constant_op.constant(grads0_np)
+                grads1 = constant_op.constant(grads1_np)
+                opt = lazy_adam_optimizer.LazyAdamOptimizer(
+                    constant_op.constant(0.001))
+                update = opt.apply_gradients(
+                    zip([grads0, grads1], [var0, var1]))
+                self.evaluate(variables.global_variables_initializer())
+
+                # Fetch params to validate initial values
+                self.assertAllClose([1.0, 2.0], var0.eval())
+                self.assertAllClose([3.0, 4.0], var1.eval())
+
+                beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
+                # Run 3 steps of Adam
+                for t in range(3):
+                    self.assertAllCloseAccordingToType(
+                        0.9**(t + 1), self.evaluate(beta_1_power))
+                    self.assertAllCloseAccordingToType(
+                        0.999**(t + 1), self.evaluate(beta_2_power))
+                    self.evaluate(update)
+
+                    var0_np, m0, v0 = adam_update_numpy(
+                        var0_np, grads0_np, t, m0, v0)
+                    var1_np, m1, v1 = adam_update_numpy(
+                        var1_np, grads1_np, t, m1, v1)
+
+                    # Validate updated params
+                    self.assertAllCloseAccordingToType(var0_np,
+                                                       self.evaluate(var0))
+                    self.assertAllCloseAccordingToType(var1_np,
+                                                       self.evaluate(var1))
+
+    @test_util.run_deprecated_v1
+    def testSharing(self):
+        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+            with self.cached_session():
+                # Initialize variables for numpy implementation.
+                m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+                var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+                grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+                var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+                grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+                var0 = variables.Variable(var0_np)
+                var1 = variables.Variable(var1_np)
+                grads0 = constant_op.constant(grads0_np)
+                grads1 = constant_op.constant(grads1_np)
+                opt = lazy_adam_optimizer.LazyAdamOptimizer()
+                update1 = opt.apply_gradients(
+                    zip([grads0, grads1], [var0, var1]))
+                update2 = opt.apply_gradients(
+                    zip([grads0, grads1], [var0, var1]))
+                self.evaluate(variables.global_variables_initializer())
+
+                beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
+
+                # Fetch params to validate initial values
+                self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+                self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+                # Run 3 steps of intertwined Adam1 and Adam2.
+                for t in range(3):
+                    self.assertAllCloseAccordingToType(
+                        0.9**(t + 1), self.evaluate(beta_1_power))
+                    self.assertAllCloseAccordingToType(
+                        0.999**(t + 1), self.evaluate(beta_2_power))
+                    if t % 2 == 0:
+                        update1.run()
+                    else:
+                        update2.run()
+
+                    var0_np, m0, v0 = adam_update_numpy(
+                        var0_np, grads0_np, t, m0, v0)
+                    var1_np, m1, v1 = adam_update_numpy(
+                        var1_np, grads1_np, t, m1, v1)
+
+                    # Validate updated params
+                    self.assertAllCloseAccordingToType(var0_np,
+                                                       self.evaluate(var0))
+                    self.assertAllCloseAccordingToType(var1_np,
+                                                       self.evaluate(var1))
+
+    def testSlotsUniqueEager(self):
+        with context.eager_mode():
+            v1 = resource_variable_ops.ResourceVariable(1.)
+            v2 = resource_variable_ops.ResourceVariable(1.)
+            opt = lazy_adam_optimizer.LazyAdamOptimizer(1.)
+            opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
+            # There should be iteration, and two unique slot variables for v1 and v2.
+            self.assertEqual(5, len(set(opt.variables())))
+            self.assertEqual(
+                self.evaluate(opt.variables()[0]),
+                self.evaluate(opt.iterations))
 
 
 if __name__ == "__main__":
-  test.main()
+    test.main()

From 76843b298efee5a5c617abd76d1dc781850757e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?=
Date: Tue, 15 Jan 2019 10:17:12 +0800
Subject: [PATCH 4/4] CLN: minor fix

---
 tensorflow_addons/optimizers/BUILD                 |  5 +++--
 .../optimizers/python/lazy_adam_optimizer.py       |  4 +---
 .../optimizers/python/lazy_adam_optimizer_test.py  | 10 +++++++---
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD
index 99314e24a6..dff0f34c88 100644
--- a/tensorflow_addons/optimizers/BUILD
+++ b/tensorflow_addons/optimizers/BUILD
@@ -3,7 +3,7 @@ licenses(["notice"])  # Apache 2.0
 package(default_visibility = ["//visibility:public"])
 
 py_library(
-    name = "opt_py",
+    name = "optimizers_py",
     srcs = [
         "__init__.py",
         "python/__init__.py",
@@ -15,12 +15,13 @@ py_library(
 
 py_test(
     name = "lazy_adam_optimizer_test",
+    size = "small",
     srcs = [
         "python/lazy_adam_optimizer_test.py"
    ],
     main = "python/lazy_adam_optimizer_test.py",
     deps = [
-        ":opt_py",
+        ":optimizers_py",
     ],
     srcs_version = "PY2AND3",
 )
diff --git a/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py b/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py
index 337effa48c..91e48085f3 100644
--- a/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py
+++ b/tensorflow_addons/optimizers/python/lazy_adam_optimizer.py
@@ -29,12 +29,10 @@
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
 
 
 class LazyAdamOptimizer(adam.Adam):
-    """Variant of the Adam optimizer that handles sparse updates more
-    efficiently.
+    """Variant of the Adam optimizer that handles sparse updates more efficiently.
 
     The original Adam algorithm maintains two moving-average accumulators for
     each trainable variable; the accumulators are updated at every step.
diff --git a/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py b/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py
index 8cbdae1794..6b7e034045 100644
--- a/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py
+++ b/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py
@@ -60,7 +60,9 @@ def get_beta_accumulators(opt, dtype):
     return (beta_1_power, beta_2_power)
 
 
-class AdamOptimizerTest(test.TestCase):
+class LazyAdamOptimizerTest(test.TestCase):
+
+    # TODO: remove v1 tests (keep pace with adam_test.py in keras).
     @test_util.run_deprecated_v1
     def testSparse(self):
         for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
@@ -138,9 +140,11 @@ def testSparseRepeatedIndices(self):
                 aggregated_update_var = variables.Variable([[1.0], [2.0]],
                                                            dtype=dtype)
                 grad_repeated_index = ops.IndexedSlices(
-                    constant_op.constant([0.1, 0.1], shape=[2, 1],
+                    constant_op.constant([0.1, 0.1],
+                                         shape=[2, 1],
                                          dtype=dtype),
-                    constant_op.constant([1, 1]), constant_op.constant([2, 1]))
+                    constant_op.constant([1, 1]),
+                    constant_op.constant([2, 1]))
                 grad_aggregated = ops.IndexedSlices(
                     constant_op.constant([0.2], shape=[1, 1], dtype=dtype),
                     constant_op.constant([1]), constant_op.constant([2, 1]))