Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tensorflow_addons/optimizers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ py_library(
srcs = [
"__init__.py",
"lazy_adam.py",
"lookahead.py",
"moving_average.py",
"rectified_adam.py",
"weight_decay_optimizers.py",
],
srcs_version = "PY2AND3",
Expand All @@ -29,6 +31,19 @@ py_test(
],
)

py_test(
name = "lookahead_test",
size = "small",
srcs = [
"lookahead_test.py",
],
main = "lookahead_test.py",
srcs_version = "PY2AND3",
deps = [
":optimizers",
],
)

py_test(
name = "moving_average_test",
size = "small",
Expand All @@ -42,6 +57,19 @@ py_test(
],
)

py_test(
name = "rectified_adam_test",
size = "small",
srcs = [
"rectified_adam_test.py",
],
main = "rectified_adam_test.py",
srcs_version = "PY2AND3",
deps = [
":optimizers",
],
)

py_test(
name = "weight_decay_optimizers_test",
size = "small",
Expand Down
4 changes: 4 additions & 0 deletions tensorflow_addons/optimizers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
| lazy_adam | Saishruthi Swaminathan | [email protected] |
| lookahead | Zhao Hanguang | [email protected] |
| moving_average | Dheeraj R. Reddy | [email protected] |
| rectified_adam | Zhao Hanguang | [email protected] |
| weight_decay_optimizers | Phil Jund | [email protected] |


## Components
| Submodule | Optimizer | Reference |
|:--------- |:---------- |:---------|
| lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
| lookahead | Lookahead | https://arxiv.org/abs/1907.08610v1 |
| moving_average | MovingAverage | |
| rectified_adam | RectifiedAdam | https://arxiv.org/pdf/1908.03265v1.pdf |
| weight_decay_optimizers | SGDW, AdamW, extend_with_decoupled_weight_decay | https://arxiv.org/pdf/1711.05101.pdf |


Expand Down
2 changes: 2 additions & 0 deletions tensorflow_addons/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from __future__ import print_function

from tensorflow_addons.optimizers.lazy_adam import LazyAdam
from tensorflow_addons.optimizers.lookahead import Lookahead
from tensorflow_addons.optimizers.moving_average import MovingAverage
from tensorflow_addons.optimizers.rectified_adam import RectifiedAdam
from tensorflow_addons.optimizers.weight_decay_optimizers import AdamW
from tensorflow_addons.optimizers.weight_decay_optimizers import SGDW
from tensorflow_addons.optimizers.weight_decay_optimizers import (
Expand Down
171 changes: 171 additions & 0 deletions tensorflow_addons/optimizers/lookahead.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
class Lookahead(tf.keras.optimizers.Optimizer):
"""This class allows to extend optimizers with the lookahead mechanism.

The mechanism is proposed by Michael R. Zhang et.al in the paper
[Lookahead Optimizer: k steps forward, 1 step back]
(https://arxiv.org/abs/1907.08610v1). The optimizer iteratively updates two
sets of weights: the search directions for weights are chosen by the inner
optimizer, while the "slow weights" are updated each `k` steps based on the
directions of the "fast weights" and the two sets of weights are
synchronized. This method improves the learning stability and lowers the
variance of its inner optimizer.

Example of usage:

```python
opt = tf.keras.optimizers.SGD(learning_rate)
opt = tfa.optimizers.Lookahead(opt)
```
"""

def __init__(self,
optimizer,
sync_period=6,
slow_step_size=0.5,
name="Lookahead",
**kwargs):
r"""Wrap optimizer with the lookahead mechanism.

Args:
optimizer: The original optimizer that will be used to compute
and apply the gradients.
sync_period: An integer. The synchronization period of lookahead.
Enable lookahead mechanism by setting it with a positive value.
slow_step_size: A floating point value.
The ratio for updating the slow weights.
name: Optional name for the operations created when applying
gradients. Defaults to "Lookahead".
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients
by norm; `clipvalue` is clip gradients by value, `decay` is
included for backward compatibility to allow time inverse
decay of learning rate. `lr` is included for backward
compatibility, recommended to use `learning_rate` instead.
"""
super(Lookahead, self).__init__(name, **kwargs)

if isinstance(optimizer, str):
optimizer = tf.keras.optimizers.get(optimizer)
if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.Optimizer")

self._optimizer = optimizer
self._set_hyper('sync_period', sync_period)
self._set_hyper('slow_step_size', slow_step_size)
self._initialized = False

def _create_slots(self, var_list):
self._optimizer._create_slots(var_list=var_list) # pylint: disable=protected-access
for var in var_list:
self.add_slot(var, 'slow')

def _create_hypers(self):
self._optimizer._create_hypers() # pylint: disable=protected-access

def _prepare(self, var_list):
return self._optimizer._prepare(var_list=var_list) # pylint: disable=protected-access

def apply_gradients(self, grads_and_vars, name=None):
self._optimizer._iterations = self.iterations # pylint: disable=protected-access
return super(Lookahead, self).apply_gradients(grads_and_vars, name)

def _init_op(self, var):
slow_var = self.get_slot(var, 'slow')
return slow_var.assign(
tf.where(
tf.equal(self.iterations,
tf.constant(0, dtype=self.iterations.dtype)),
var,
slow_var,
),
use_locking=self._use_locking)

def _look_ahead_op(self, var):
var_dtype = var.dtype.base_dtype
slow_var = self.get_slot(var, 'slow')
local_step = tf.cast(self.iterations + 1, tf.dtypes.int64)
sync_period = self._get_hyper('sync_period', tf.dtypes.int64)
slow_step_size = self._get_hyper('slow_step_size', var_dtype)
step_back = slow_var + slow_step_size * (var - slow_var)
sync_cond = tf.equal(
tf.math.floordiv(local_step, sync_period) * sync_period,
local_step)
with tf.control_dependencies([step_back]):
slow_update = slow_var.assign(
tf.where(
sync_cond,
step_back,
slow_var,
),
use_locking=self._use_locking)
var_update = var.assign(
tf.where(
sync_cond,
step_back,
var,
),
use_locking=self._use_locking)
return tf.group(slow_update, var_update)

@property
def weights(self):
return self._weights + self._optimizer.weights

def _resource_apply_dense(self, grad, var):
init_op = self._init_op(var)
with tf.control_dependencies([init_op]):
train_op = self._optimizer._resource_apply_dense(grad, var) # pylint: disable=protected-access
with tf.control_dependencies([train_op]):
look_ahead_op = self._look_ahead_op(var)
return tf.group(init_op, train_op, look_ahead_op)

def _resource_apply_sparse(self, grad, var, indices):
init_op = self._init_op(var)
with tf.control_dependencies([init_op]):
train_op = self._optimizer._resource_apply_sparse( # pylint: disable=protected-access
grad, var, indices)
with tf.control_dependencies([train_op]):
look_ahead_op = self._look_ahead_op(var)
return tf.group(init_op, train_op, look_ahead_op)

def get_config(self):
config = {
'optimizer': tf.keras.optimizers.serialize(self._optimizer),
'sync_period': self._serialize_hyperparameter('sync_period'),
'slow_step_size': self._serialize_hyperparameter('slow_step_size'),
}
base_config = super(Lookahead, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config, custom_objects=None):
optimizer = tf.keras.optimizers.deserialize(
config.pop('optimizer'),
custom_objects=custom_objects,
)
return cls(optimizer, **config)
140 changes: 140 additions & 0 deletions tensorflow_addons/optimizers/lookahead_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Lookahead optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from tensorflow_addons.utils import test_utils
from tensorflow_addons.optimizers import Lookahead


@test_utils.run_all_in_graph_and_eager_modes
class LookaheadTest(tf.test.TestCase):
def run_dense_sample(self, iterations, optimizer, seed=0x2019):
np.random.seed(seed)

val_0 = np.random.random((2,))
val_1 = np.random.random((2,))

var_0 = tf.Variable(val_0, dtype=tf.dtypes.float32)
var_1 = tf.Variable(val_1, dtype=tf.dtypes.float32)

grad_0 = tf.constant(
np.random.standard_normal((2,)), dtype=tf.dtypes.float32)
grad_1 = tf.constant(
np.random.standard_normal((2,)), dtype=tf.dtypes.float32)

grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1]))

if tf.executing_eagerly():
for _ in range(iterations):
optimizer.apply_gradients(grads_and_vars)
else:
update = optimizer.apply_gradients(grads_and_vars)
self.evaluate(tf.compat.v1.global_variables_initializer())
for _ in range(iterations):
self.evaluate(update)

return [val_0, val_1], [self.evaluate(var_0), self.evaluate(var_1)]

def run_sparse_sample(self, iterations, optimizer, seed=0x2019):
np.random.seed(seed)

val_0 = np.random.random((2,))
val_1 = np.random.random((2,))

var_0 = tf.Variable(val_0, dtype=tf.dtypes.float32)
var_1 = tf.Variable(val_1, dtype=tf.dtypes.float32)

grad_0 = tf.IndexedSlices(
tf.constant([np.random.standard_normal()]), tf.constant([0]),
tf.constant([2]))
grad_1 = tf.IndexedSlices(
tf.constant([np.random.standard_normal()]), tf.constant([1]),
tf.constant([2]))

grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1]))

if tf.executing_eagerly():
for _ in range(iterations):
optimizer.apply_gradients(grads_and_vars)
else:
update = optimizer.apply_gradients(grads_and_vars)
self.evaluate(tf.compat.v1.global_variables_initializer())
for _ in range(iterations):
self.evaluate(update)

return [val_0, val_1], [self.evaluate(var_0), self.evaluate(var_1)]

def test_dense_exact_ratio(self):
for k in [5, 10, 100]:
for alpha in [0.3, 0.7]:
optimizer = tf.keras.optimizers.get('adam')
vals, quick_vars = self.run_dense_sample(k, optimizer)
optimizer = Lookahead(
'adam', sync_period=k, slow_step_size=alpha)
_, slow_vars = self.run_dense_sample(k, optimizer)
for val, quick, slow in zip(vals, quick_vars, slow_vars):
expected = val + (quick - val) * alpha
self.assertAllClose(expected, slow)

def test_sparse_exact_ratio(self):
for k in [5, 10, 100]:
for alpha in [0.3, 0.7]:
optimizer = tf.keras.optimizers.get('adam')
vals, quick_vars = self.run_sparse_sample(k, optimizer)
optimizer = Lookahead(
'adam', sync_period=k, slow_step_size=alpha)
_, slow_vars = self.run_sparse_sample(k, optimizer)
for val, quick, slow in zip(vals, quick_vars, slow_vars):
expected = val + (quick - val) * alpha
self.assertAllClose(expected, slow)

def test_fit_simple_linear_model(self):
np.random.seed(0x2019)

x = np.random.standard_normal((100000, 3))
w = np.random.standard_normal((3, 1))
y = np.dot(x, w) + np.random.standard_normal((100000, 1)) * 1e-4

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
model.compile(Lookahead('adam'), loss='mse')

model.fit(x, y, epochs=3)

x = np.random.standard_normal((100, 3))
y = np.dot(x, w)
predicted = model.predict(x)

max_abs_diff = np.max(np.abs(predicted - y))
self.assertLess(max_abs_diff, 1e-4)

def test_get_config(self):
opt = Lookahead('adam', sync_period=10, slow_step_size=0.4)
opt = tf.keras.optimizers.deserialize(
tf.keras.optimizers.serialize(opt))
config = opt.get_config()
self.assertEqual(config['sync_period'], 10)
self.assertEqual(config['slow_step_size'], 0.4)


if __name__ == '__main__':
tf.test.main()
Loading