diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD index a7a5948428..00a31f327e 100644 --- a/tensorflow_addons/optimizers/BUILD +++ b/tensorflow_addons/optimizers/BUILD @@ -15,6 +15,7 @@ py_library( "rectified_adam.py", "stochastic_weight_averaging.py", "weight_decay_optimizers.py", + "yogi.py", ], deps = [ "//tensorflow_addons/utils", @@ -33,6 +34,18 @@ py_test( ], ) +py_test( + name = "yogi_test", + size = "small", + srcs = [ + "yogi_test.py", + ], + main = "yogi_test.py", + deps = [ + ":optimizers", + ], +) + py_test( name = "conditional_gradient_test", size = "small", diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md index 807a7163ce..c73e49b2ed 100644 --- a/tensorflow_addons/optimizers/README.md +++ b/tensorflow_addons/optimizers/README.md @@ -12,6 +12,7 @@ | rectified_adam | Zhao Hanguang | cyberzhg@gmail.com | | stochastic_weight_averaging | Shreyash Patodia | patodiashreyash32@gmail.com | | weight_decay_optimizers | Phil Jund | ijund.phil@googlemail.com | +| yogi | Manzil Zaheer | manzilz@google.com | @@ -27,6 +28,7 @@ | rectified_adam | RectifiedAdam | https://arxiv.org/pdf/1908.03265v1.pdf | | stochastic_weight_averaging | SWA | https://arxiv.org/abs/1803.05407.pdf | | weight_decay_optimizers | SGDW, AdamW, extend_with_decoupled_weight_decay | https://arxiv.org/pdf/1711.05101.pdf | +| yogi | Yogi | https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf | diff --git a/tensorflow_addons/optimizers/yogi.py b/tensorflow_addons/optimizers/yogi.py new file mode 100644 index 0000000000..92abd035cd --- /dev/null +++ b/tensorflow_addons/optimizers/yogi.py @@ -0,0 +1,341 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Yogi: Extension of yogi adaptive nonconvex optimizer in Keras.
+
+Implementation of Additive Averaging.
+m_t+1 = beta1*m_t + (1-beta1)*g_t
+v_t+1 = v_t + (1-beta2)*sign(g_t^2-v_t)*(g_t^2)
+Experiments show better performance across NLP and Vision tasks.
+Paper:
+https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def _solve(a, b, c):
+    """Return solution of a quadratic minimization.
+
+    The optimization equation is:
+        f(a, b, c) = argmin_w{1/2 * a * w^2 + b * w + c * |w|}
+    we get optimal solution w*:
+        w* = -(b - sign(b)*c)/a if |b| > c else w* = 0
+    REQUIRES: Dimensionality of a and b must be same
+    Args:
+      a: A Tensor
+      b: A Tensor
+      c: A Tensor, a scalar or broadcastable to the shape of b.
+    Returns:
+      A Tensor w, which is solution for the equation
+    """
+    w = (c * tf.sign(b) - b) / a
+    w = tf.cast(tf.abs(b) > c, dtype=b.dtype) * w
+    return w
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class Yogi(tf.keras.optimizers.Optimizer):
+    """Optimizer that implements the Yogi algorithm in Keras.
+
+    See Algorithm 2 of
+    https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf.
+    """
+
+    def __init__(self,
+                 learning_rate=0.01,
+                 beta1=0.9,
+                 beta2=0.999,
+                 epsilon=1e-3,
+                 l1_regularization_strength=0.0,
+                 l2_regularization_strength=0.0,
+                 initial_accumulator_value=1.0,
+                 activation='sign',
+                 name='Yogi',
+                 **kwargs):
+        """Construct a new Yogi optimizer.
+
+        Args:
+          learning_rate: A Tensor or a floating point value.
+            The learning rate.
+          beta1: A float value or a constant float tensor.
+            The exponential decay rate for the 1st moment estimates.
+          beta2: A float value or a constant float tensor.
+            The exponential decay rate for the 2nd moment estimates.
+          epsilon: A constant trading off adaptivity and noise.
+          l1_regularization_strength: A float value, must be greater than or
+            equal to zero.
+          l2_regularization_strength: A float value, must be greater than or
+            equal to zero.
+          initial_accumulator_value: The starting value for accumulators.
+            Only positive values are allowed.
+          activation: Use hard 'sign' or soft 'tanh' to determine sign.
+          name: Optional name for the operations created when applying
+            gradients. Defaults to "Yogi".
+          **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+            `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue`
+            is clip gradients by value, `decay` is included for backward
+            compatibility to allow time inverse decay of learning rate. `lr`
+            is included for backward compatibility, recommended to use
+            `learning_rate` instead.
+        """
+        super(Yogi, self).__init__(name, **kwargs)
+        self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
+        self._set_hyper('decay', self._initial_decay)
+        self._set_hyper('beta_1', beta1)
+        self._set_hyper('beta_2', beta2)
+        self._set_hyper('epsilon', epsilon)
+        self._set_hyper('l1_regularization_strength',
+                        l1_regularization_strength)
+        self._set_hyper('l2_regularization_strength',
+                        l2_regularization_strength)
+
+        self._beta1 = beta1
+        self._activation = activation
+        self._initial_accumulator_value = initial_accumulator_value
+        self._l1_regularization_strength = l1_regularization_strength
+        self._l2_regularization_strength = l2_regularization_strength
+
+    def _create_slots(self, var_list):
+        """Create the 'v' slot (and 'm' when beta1 > 0) for each variable."""
+        # Slots: second moment 'v' always, first moment 'm' only with momentum.
+        for var in var_list:
+            init = tf.constant_initializer(self._initial_accumulator_value)
+            self.add_slot(var, 'v', init)
+            if self._beta1 > 0.0:
+                self.add_slot(var, 'm')
+
+    def _resource_apply_dense(self, grad, var):
+        """Apply a dense gradient `grad` to `var` and update its slots."""
+        var_dtype = var.dtype.base_dtype
+        lr_t = self._decayed_lr(var_dtype)
+        beta1_t = self._get_hyper('beta_1', var_dtype)
+        beta2_t = self._get_hyper('beta_2', var_dtype)
+        epsilon_t = self._get_hyper('epsilon', var_dtype)
+        l1_t = self._get_hyper('l1_regularization_strength', var_dtype)
+        l2_t = self._get_hyper('l2_regularization_strength', var_dtype)
+        local_step = tf.cast(self.iterations + 1, var_dtype)
+        beta1_power = tf.pow(beta1_t, local_step)
+        beta2_power = tf.pow(beta2_t, local_step)
+
+        lr = (lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power))
+
+        update_vs = []
+        if self._beta1 == 0.0:
+            # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
+            v = self.get_slot(var, 'v')
+            grad2 = grad * grad
+            if self._activation == 'sign':
+                sign = tf.sign(grad2 - v)
+            elif self._activation == 'tanh':
+                sign = tf.tanh(10 * (grad2 - v))
+            else:
+                raise NotImplementedError(
+                    'Activation function can be sign or tanh')
+            v_t = v.assign_add(
+                (1 - beta2_t) * sign * grad2, use_locking=self._use_locking)
+            v_sqrt = tf.sqrt(v_t)
+
+            # Yogi effective LR
+            per_coord_lr = lr / (v_sqrt + epsilon_t)
+
+            # Variable update
+            # Step 1: Gradient descent
+            new_var = var - per_coord_lr * grad
+            # Step 2: Prox operator
+            if self._l1_regularization_strength > 0:
+                new_var = _solve(1 + l2_t * per_coord_lr, -new_var,
+                                 l1_t * per_coord_lr)
+            elif self._l2_regularization_strength > 0:
+                new_var = new_var / (1 + l2_t * per_coord_lr)
+            # Step 3: Update
+            var_update = var.assign(new_var, use_locking=self._use_locking)
+
+            update_vs.append(var_update)
+            update_vs.append(v_t)
+
+        else:
+            # m_t = beta1 * m + (1 - beta1) * g_t
+            m = self.get_slot(var, 'm')
+            m_t = m.assign(
+                m * beta1_t + grad * (1 - beta1_t),
+                use_locking=self._use_locking)
+
+            # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
+            v = self.get_slot(var, 'v')
+            grad2 = grad * grad
+            if self._activation == 'sign':
+                sign = tf.sign(grad2 - v)
+            elif self._activation == 'tanh':
+                sign = tf.tanh(10 * (grad2 - v))
+            else:
+                raise NotImplementedError(
+                    'Activation function can be sign or tanh')
+            v_t = v.assign_add(
+                (1 - beta2_t) * sign * grad2, use_locking=self._use_locking)
+            v_sqrt = tf.sqrt(v_t)
+
+            # Yogi effective LR
+            per_coord_lr = lr / (v_sqrt + epsilon_t)
+
+            # Variable update
+            # Step 1: Gradient descent
+            new_var = var - per_coord_lr * m_t
+            # Step 2: Prox operator
+            if self._l1_regularization_strength > 0:
+                new_var = _solve(1 + l2_t * per_coord_lr, -new_var,
+                                 l1_t * per_coord_lr)
+            elif self._l2_regularization_strength > 0:
+                new_var = new_var / (1 + l2_t * per_coord_lr)
+            # Step 3: Update
+            var_update = var.assign(new_var, use_locking=self._use_locking)
+            update_vs.append(var_update)
+            update_vs.append(m_t)
+            update_vs.append(v_t)
+
+        # Create an op that groups all the above operations
+        return tf.group(*update_vs)
+
+    def _resource_apply_sparse(self, grad, var, indices):
+        """Applies sparse gradients to a variable.
+
+        Args:
+          grad: A tensor for the `values` of `tf.IndexedSlices`.
+          var: A `tf.Variable` object.
+          indices: A tensor for the `indices` of `tf.IndexedSlices`.
+        Returns:
+          An op which updates `var` with `grad` and `indices`.
+        """
+
+        var_dtype = var.dtype.base_dtype
+        lr_t = self._decayed_lr(var_dtype)
+        beta1_t = self._get_hyper('beta_1', var_dtype)
+        beta2_t = self._get_hyper('beta_2', var_dtype)
+        epsilon_t = self._get_hyper('epsilon', var_dtype)
+        l1_t = self._get_hyper('l1_regularization_strength', var_dtype)
+        l2_t = self._get_hyper('l2_regularization_strength', var_dtype)
+        local_step = tf.cast(self.iterations + 1, var_dtype)
+        beta1_power = tf.pow(beta1_t, local_step)
+        beta2_power = tf.pow(beta2_t, local_step)
+
+        lr = (lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power))
+
+        update_vs = []
+        if self._beta1 == 0.0:
+            # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
+            v = self.get_slot(var, 'v')
+            grad2 = grad * grad
+            v_slice = tf.gather(v, indices)
+            if self._activation == 'sign':
+                sign = tf.sign(grad2 - v_slice)
+            elif self._activation == 'tanh':
+                sign = tf.tanh(10 * (grad2 - v_slice))
+            else:
+                raise NotImplementedError(
+                    'Activation function can be sign or tanh')
+            v_scaled_g_values = v_slice + (1 - beta2_t) * sign * grad2
+            v_t = self._resource_scatter_update(v, indices, v_scaled_g_values)
+            v_sqrt = tf.sqrt(v_scaled_g_values)
+
+            # Yogi effective LR
+            per_coord_lr = lr / (v_sqrt + epsilon_t)
+
+            # Variable update
+            # Step 1: Gradient descent
+            var_slice = tf.gather(var, indices)
+            new_var = var_slice - per_coord_lr * grad
+            # Step 2: Prox operator
+            if self._l1_regularization_strength > 0:
+                new_var = _solve(1 + l2_t * per_coord_lr, -new_var,
+                                 l1_t * per_coord_lr)
+            elif self._l2_regularization_strength > 0:
+                new_var = new_var / (1 + l2_t * per_coord_lr)
+            # Step 3: Update
+            var_update = self._resource_scatter_update(var, indices, new_var)
+            update_vs.append(var_update)
+            update_vs.append(v_t)
+
+        else:
+            # m_t = beta1 * m (all rows) + (1 - beta1) * g_t (rows in `indices`)
+            m = self.get_slot(var, 'm')
+            m_scaled_g_values = grad * (1 - beta1_t)
+            m_t = m.assign(m * beta1_t, use_locking=self._use_locking)
+            with tf.control_dependencies([m_t]):
+                m_slice = tf.gather(m, indices) + m_scaled_g_values
+                m_t = self._resource_scatter_update(m, indices, m_slice)
+
+            # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
+            v = self.get_slot(var, 'v')
+            grad2 = grad * grad
+            v_slice = tf.gather(v, indices)
+            if self._activation == 'sign':
+                sign = tf.sign(grad2 - v_slice)
+            elif self._activation == 'tanh':
+                sign = tf.tanh(10 * (grad2 - v_slice))
+            else:
+                raise NotImplementedError(
+                    'Activation function can be sign or tanh')
+            v_scaled_g_values = v_slice + (1 - beta2_t) * sign * grad2
+            v_t = self._resource_scatter_update(v, indices, v_scaled_g_values)
+            v_sqrt = tf.sqrt(v_scaled_g_values)
+
+            # Yogi effective LR
+            per_coord_lr = lr / (v_sqrt + epsilon_t)
+
+            # Variable update
+            # Step 1: Gradient descent
+            var_slice = tf.gather(var, indices)
+            new_var = var_slice - per_coord_lr * m_slice
+            # Step 2: Prox operator
+            if self._l1_regularization_strength > 0:
+                new_var = _solve(1 + l2_t * per_coord_lr, -new_var,
+                                 l1_t * per_coord_lr)
+            elif self._l2_regularization_strength > 0:
+                new_var = new_var / (1 + l2_t * per_coord_lr)
+            # Step 3: Update
+            var_update = self._resource_scatter_update(var, indices, new_var)
+            update_vs.append(var_update)
+            update_vs.append(m_t)
+            update_vs.append(v_t)
+
+        # Create an op that groups all the above operations
+        return tf.group(*update_vs)
+
+    def get_config(self):  # Keys must mirror __init__ argument names.
+        config = super(Yogi, self).get_config()
+        config.update({
+            'learning_rate':
+            self._serialize_hyperparameter('learning_rate'),
+            'decay':
+            self._serialize_hyperparameter('decay'),
+            'beta1':
+            self._serialize_hyperparameter('beta_1'),
+            'beta2':
+            self._serialize_hyperparameter('beta_2'),
+            'epsilon':
+            self._serialize_hyperparameter('epsilon'),
+            'l1_regularization_strength':
+            self._serialize_hyperparameter('l1_regularization_strength'),
+            'l2_regularization_strength':
+            self._serialize_hyperparameter('l2_regularization_strength'),
+            'activation':
+            self._activation,
+            'initial_accumulator_value':
+            self._initial_accumulator_value,
+        })
+        return config
diff --git a/tensorflow_addons/optimizers/yogi_test.py b/tensorflow_addons/optimizers/yogi_test.py
new file mode 100644
index 0000000000..3050ffb575
--- /dev/null
+++ b/tensorflow_addons/optimizers/yogi_test.py
@@ -0,0 +1,412 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Yogi optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow_addons.optimizers import yogi
+from tensorflow_addons.utils import test_utils
+
+
+def yogi_update_numpy(param,
+                      g_t,
+                      t,
+                      m,
+                      v,
+                      alpha=0.01,
+                      beta1=0.9,
+                      beta2=0.999,
+                      epsilon=1e-3,
+                      l1reg=0.0,
+                      l2reg=0.0):
+    """Performs Yogi parameter update using numpy.
+
+    Args:
+      param: A numpy ndarray of the current parameter.
+      g_t: A numpy ndarray of the current gradients.
+      t: The current 1-based time step.
+      m: A numpy ndarray (or scalar) of the 1st moment estimates.
+      v: A numpy ndarray (or scalar) of the 2nd moment estimates.
+      alpha: A float value of the learning rate.
+      beta1: A float value of the exponential decay rate for the 1st moment
+        estimates.
+      beta2: A float value of the exponential decay rate for the 2nd moment
+        estimates.
+      epsilon: A float of a small constant for numerical stability.
+      l1reg: A float value of L1 regularization
+      l2reg: A float value of L2 regularization
+    Returns:
+      A tuple of numpy ndarrays (param_t, m_t, v_t) representing the
+      updated parameters for `param`, `m`, and `v` respectively.
+    """
+    beta1 = np.array(beta1, dtype=param.dtype)
+    beta2 = np.array(beta2, dtype=param.dtype)
+
+    alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+
+    m_t = beta1 * m + (1 - beta1) * g_t
+    g2_t = g_t * g_t
+    v_t = v - (1 - beta2) * np.sign(v - g2_t) * g2_t
+
+    per_coord_lr = alpha_t / (np.sqrt(v_t) + epsilon)
+    param_t = param - per_coord_lr * m_t
+
+    if l1reg > 0:
+        param_t = (param_t - l1reg * per_coord_lr * np.sign(param_t)) / (
+            1 + l2reg * per_coord_lr)
+        # Zero entries whose shrunk magnitude is below the L1 threshold.
+        param_t[np.abs(param_t) < l1reg * per_coord_lr] = 0.0
+    elif l2reg > 0:
+        param_t = param_t / (1 + l2reg * per_coord_lr)
+    return param_t, m_t, v_t
+
+
+def get_beta_accumulators(opt, dtype):
+    local_step = tf.cast(opt.iterations + 1, dtype)
+    beta_1_t = tf.cast(opt._get_hyper("beta_1"), dtype)
+    beta_1_power = tf.math.pow(beta_1_t, local_step)
+    beta_2_t = tf.cast(opt._get_hyper("beta_2"), dtype)
+    beta_2_power = tf.math.pow(beta_2_t, local_step)
+    return (beta_1_power, beta_2_power)
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class YogiOptimizerTest(tf.test.TestCase):
+    def _DtypesToTest(self, use_gpu):
+        if use_gpu:
+            return [tf.dtypes.float32, tf.dtypes.float64]
+        else:
+            return [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]
+
+    def doTestSparse(self, beta1=0.0, l1reg=0.0, l2reg=0.0):
+        for dtype in self._DtypesToTest(use_gpu=tf.test.is_gpu_available()):
+            # Initialize variables for numpy implementation.
+            m0, v0, m1, v1 = 0.0, 1.0, 0.0, 1.0
+            var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+            grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+            var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+            grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+            var0 = tf.Variable(var0_np)
+            var1 = tf.Variable(var1_np)
+            grads0_np_indices = np.array([0, 1], dtype=np.int32)
+            grads0 = tf.IndexedSlices(
+                tf.constant(grads0_np), tf.constant(grads0_np_indices),
+                tf.constant([2]))
+            grads1_np_indices = np.array([0, 1], dtype=np.int32)
+            grads1 = tf.IndexedSlices(
+                tf.constant(grads1_np), tf.constant(grads1_np_indices),
+                tf.constant([2]))
+            opt = yogi.Yogi(
+                beta1=beta1,
+                l1_regularization_strength=l1reg,
+                l2_regularization_strength=l2reg)
+            if not tf.executing_eagerly():
+                update = opt.apply_gradients(
+                    zip([grads0, grads1], [var0, var1]))
+                self.evaluate(tf.compat.v1.global_variables_initializer())
+
+            # Fetch params to validate initial values.
+            self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+            self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+            # Run 3 steps of Yogi.
+            for t in range(1, 4):
+                beta1_power, beta2_power = get_beta_accumulators(opt, dtype)
+                self.assertAllCloseAccordingToType(beta1**t,
+                                                   self.evaluate(beta1_power))
+                self.assertAllCloseAccordingToType(0.999**t,
+                                                   self.evaluate(beta2_power))
+                if not tf.executing_eagerly():
+                    self.evaluate(update)
+                else:
+                    opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+                var0_np, m0, v0 = yogi_update_numpy(
+                    var0_np,
+                    grads0_np,
+                    t,
+                    m0,
+                    v0,
+                    beta1=beta1,
+                    l1reg=l1reg,
+                    l2reg=l2reg)
+                var1_np, m1, v1 = yogi_update_numpy(
+                    var1_np,
+                    grads1_np,
+                    t,
+                    m1,
+                    v1,
+                    beta1=beta1,
+                    l1reg=l1reg,
+                    l2reg=l2reg)
+
+                # Validate updated params.
+                self.assertAllCloseAccordingToType(
+                    var0_np,
+                    self.evaluate(var0),
+                    msg="Updated params 0 do not match in NP and TF")
+                self.assertAllCloseAccordingToType(
+                    var1_np,
+                    self.evaluate(var1),
+                    msg="Updated params 1 do not match in NP and TF")
+
+    def testSparse(self):
+        self.doTestSparse()
+
+    def testSparseRegularization(self):
+        self.doTestSparse(l1reg=0.1, l2reg=0.2)
+
+    def testSparseMomentum(self):
+        self.doTestSparse(beta1=0.9)
+
+    def testSparseMomentumRegularization(self):
+        self.doTestSparse(beta1=0.9, l1reg=0.1, l2reg=0.2)
+
+    def testSparseRepeatedIndices(self):
+        for dtype in self._DtypesToTest(use_gpu=tf.test.is_gpu_available()):
+            repeated_index_update_var = tf.Variable([[1.0], [2.0]],
+                                                    dtype=dtype)
+            aggregated_update_var = tf.Variable([[1.0], [2.0]], dtype=dtype)
+            grad_repeated_index = tf.IndexedSlices(
+                tf.constant([0.1, 0.1], shape=[2, 1], dtype=dtype),
+                tf.constant([1, 1]), tf.constant([2, 1]))
+            grad_aggregated = tf.IndexedSlices(
+                tf.constant([0.2], shape=[1, 1], dtype=dtype),
+                tf.constant([1]), tf.constant([2, 1]))
+            opt1 = yogi.Yogi()
+            opt2 = yogi.Yogi()
+
+            if not tf.executing_eagerly():
+                repeated_update = opt1.apply_gradients(
+                    [(grad_repeated_index, repeated_index_update_var)])
+                aggregated_update = opt2.apply_gradients(
+                    [(grad_aggregated, aggregated_update_var)])
+                self.evaluate(tf.compat.v1.global_variables_initializer())
+
+            self.assertAllClose(
+                self.evaluate(aggregated_update_var),
+                self.evaluate(repeated_index_update_var))
+
+            for _ in range(3):
+                if not tf.executing_eagerly():
+                    self.evaluate(repeated_update)
+                    self.evaluate(aggregated_update)
+                else:
+                    opt1.apply_gradients([(grad_repeated_index,
+                                           repeated_index_update_var)])
+                    opt2.apply_gradients([(grad_aggregated,
+                                           aggregated_update_var)])
+
+                self.assertAllClose(
+                    self.evaluate(aggregated_update_var),
+                    self.evaluate(repeated_index_update_var))
+
+    def doTestBasic(self, beta1=0.0, l1reg=0.0, l2reg=0.0):
+        for dtype in self._DtypesToTest(use_gpu=tf.test.is_gpu_available()):
+            # Initialize variables for numpy implementation.
+            m0, v0, m1, v1 = 0.0, 1.0, 0.0, 1.0
+            var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+            grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+            var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+            grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+            var0 = tf.Variable(var0_np)
+            var1 = tf.Variable(var1_np)
+            grads0 = tf.constant(grads0_np)
+            grads1 = tf.constant(grads1_np)
+
+            opt = yogi.Yogi(
+                beta1=beta1,
+                l1_regularization_strength=l1reg,
+                l2_regularization_strength=l2reg)
+
+            if not tf.executing_eagerly():
+                update = opt.apply_gradients(
+                    zip([grads0, grads1], [var0, var1]))
+                self.evaluate(tf.compat.v1.global_variables_initializer())
+
+            # Fetch params to validate initial values.
+            self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+            self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+            # Run 3 steps of Yogi.
+            for t in range(1, 4):
+                beta1_power, beta2_power = get_beta_accumulators(opt, dtype)
+                self.assertAllCloseAccordingToType(beta1**t,
+                                                   self.evaluate(beta1_power))
+                self.assertAllCloseAccordingToType(0.999**t,
+                                                   self.evaluate(beta2_power))
+
+                if not tf.executing_eagerly():
+                    self.evaluate(update)
+                else:
+                    opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+                var0_np, m0, v0 = yogi_update_numpy(
+                    var0_np,
+                    grads0_np,
+                    t,
+                    m0,
+                    v0,
+                    beta1=beta1,
+                    l1reg=l1reg,
+                    l2reg=l2reg)
+                var1_np, m1, v1 = yogi_update_numpy(
+                    var1_np,
+                    grads1_np,
+                    t,
+                    m1,
+                    v1,
+                    beta1=beta1,
+                    l1reg=l1reg,
+                    l2reg=l2reg)
+
+                # Validate updated params.
+                self.assertAllCloseAccordingToType(var0_np,
+                                                   self.evaluate(var0))
+                self.assertAllCloseAccordingToType(var1_np,
+                                                   self.evaluate(var1))
+
+    def testBasic(self):
+        self.doTestBasic()
+
+    def testBasicRegularization(self):
+        self.doTestBasic(l1reg=0.1, l2reg=0.2)
+
+    def testBasicMomentum(self):
+        self.doTestBasic(beta1=0.9)
+
+    def testBasicMomentumRegularization(self):
+        self.doTestBasic(beta1=0.9, l1reg=0.1, l2reg=0.2)
+
+    def testTensorLearningRate(self):
+        for dtype in self._DtypesToTest(use_gpu=tf.test.is_gpu_available()):
+            # Initialize variables for numpy implementation.
+            m0, v0, m1, v1 = 0.0, 1.0, 0.0, 1.0
+            var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+            grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+            var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+            grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+            var0 = tf.Variable(var0_np)
+            var1 = tf.Variable(var1_np)
+            grads0 = tf.constant(grads0_np)
+            grads1 = tf.constant(grads1_np)
+            opt = yogi.Yogi(tf.constant(0.01))
+
+            if not tf.executing_eagerly():
+                update = opt.apply_gradients(
+                    zip([grads0, grads1], [var0, var1]))
+                self.evaluate(tf.compat.v1.global_variables_initializer())
+
+            # Fetch params to validate initial values.
+            self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+            self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+            # Run 3 steps of Yogi.
+            for t in range(1, 4):
+                beta1_power, beta2_power = get_beta_accumulators(opt, dtype)
+                self.assertAllCloseAccordingToType(0.9**t,
+                                                   self.evaluate(beta1_power))
+                self.assertAllCloseAccordingToType(0.999**t,
+                                                   self.evaluate(beta2_power))
+
+                if not tf.executing_eagerly():
+                    self.evaluate(update)
+                else:
+                    opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+                var0_np, m0, v0 = yogi_update_numpy(var0_np, grads0_np, t, m0,
+                                                    v0)
+                var1_np, m1, v1 = yogi_update_numpy(var1_np, grads1_np, t, m1,
+                                                    v1)
+
+                # Validate updated params.
+                self.assertAllCloseAccordingToType(var0_np,
+                                                   self.evaluate(var0))
+                self.assertAllCloseAccordingToType(var1_np,
+                                                   self.evaluate(var1))
+
+    def testSharing(self):
+        for dtype in self._DtypesToTest(use_gpu=tf.test.is_gpu_available()):
+            # Initialize variables for numpy implementation.
+            m0, v0, m1, v1 = 0.0, 1.0, 0.0, 1.0
+            var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+            grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+            var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+            grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+            var0 = tf.Variable(var0_np)
+            var1 = tf.Variable(var1_np)
+            grads0 = tf.constant(grads0_np)
+            grads1 = tf.constant(grads1_np)
+            opt = yogi.Yogi()
+
+            if not tf.executing_eagerly():
+                update1 = opt.apply_gradients(
+                    zip([grads0, grads1], [var0, var1]))
+                update2 = opt.apply_gradients(
+                    zip([grads0, grads1], [var0, var1]))
+                self.evaluate(tf.compat.v1.global_variables_initializer())
+
+            # Fetch params to validate initial values.
+            self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+            self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+            # Run 3 steps of intertwined Yogi1 and Yogi2.
+            for t in range(1, 4):
+                beta1_power, beta2_power = get_beta_accumulators(opt, dtype)
+                self.assertAllCloseAccordingToType(0.9**t,
+                                                   self.evaluate(beta1_power))
+                self.assertAllCloseAccordingToType(0.999**t,
+                                                   self.evaluate(beta2_power))
+                if not tf.executing_eagerly():
+                    if t % 2 == 0:
+                        self.evaluate(update1)
+                    else:
+                        self.evaluate(update2)
+                else:
+                    opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+                var0_np, m0, v0 = yogi_update_numpy(var0_np, grads0_np, t, m0,
+                                                    v0)
+                var1_np, m1, v1 = yogi_update_numpy(var1_np, grads1_np, t, m1,
+                                                    v1)
+
+                # Validate updated params.
+                self.assertAllCloseAccordingToType(var0_np,
+                                                   self.evaluate(var0))
+                self.assertAllCloseAccordingToType(var1_np,
+                                                   self.evaluate(var1))
+
+    def test_get_config(self):
+        opt = yogi.Yogi(1e-4)
+        config = opt.get_config()
+        self.assertEqual(config["learning_rate"], 1e-4)
+        self.assertEqual(config["l1_regularization_strength"], 0.0)
+        self.assertEqual(config["l2_regularization_strength"], 0.0)
+        # Every emitted key must be a constructor argument for this to work.
+        restored = yogi.Yogi.from_config(config)
+        self.assertEqual(restored.get_config(), config)
+
+
+if __name__ == "__main__":
+    tf.test.main()