Commit e278852

Squadrick authored and WindQAQ committed
Implement MovingAverage optimizer (#215)
* Implement MovingAverage optimizer
* Port MovingAverageOptimizer from tf.contrib.opt
* Inherits base Keras optimizer_v2
* `swapping_saver` replaced with `assign_average_vars`
* Update test cases for TF2.X
* Update docs
* Add moving_average_test as a py_test in BUILD file
* Move internal functions under external functions to improve readability
* Refactor code and add test for config
* Use _set_hyper() and _get_hyper() instead of member variables for average_decay, num_updates and sequential_update
* Remove _create_slots() from MovingAverage
* Use _serialize_hyperparameter() in get_config()
* Replace if-else with tf.cond() to work with tensors
* Use absolute import of tensorflow_addons in moving_average_test.py
* Add eager execution support to MovingAverage
* Tests modified for static and eager execution
* num_updates and sequential_update reverted back to instance variables
* Type check of num_updates and sequential_update
* Nit fixes
* Remove six import in moving_average_test
* Wrap zip objects in list to pass tests in python3
* Fix typos
1 parent f705137 commit e278852
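
For orientation before the diff: a minimal usage sketch of the optimizer this commit adds, pieced together from the docstrings below. The model and training data are placeholders, not part of the commit:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Wrap any tf.keras optimizer; gradients are applied by the inner optimizer,
# then an exponential moving average of each variable is updated.
opt = tfa.optimizers.MovingAverage(tf.keras.optimizers.SGD(learning_rate=0.01))

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])  # placeholder model
model.compile(optimizer=opt, loss='mse')
model.fit(x, y, epochs=10)  # `x` and `y` are placeholder training data

# Overwrite the raw weights with their running averages before saving.
opt.assign_average_vars(model.variables)
model.save('model.h5')
```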

File tree

5 files changed, +285 −0 lines


tensorflow_addons/optimizers/BUILD

Lines changed: 14 additions & 0 deletions

```diff
@@ -7,6 +7,7 @@ py_library(
     srcs = [
         "__init__.py",
         "lazy_adam.py",
+        "moving_average.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -26,3 +27,16 @@ py_test(
         ":optimizers",
     ],
 )
+
+py_test(
+    name = "moving_average_test",
+    size = "small",
+    srcs = [
+        "moving_average_test.py",
+    ],
+    main = "moving_average_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":optimizers",
+    ],
+)
```

tensorflow_addons/optimizers/README.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -4,11 +4,13 @@
 | Submodule | Maintainers | Contact Info |
 |:---------- |:------------- |:--------------|
 | lazy_adam | SIG-Addons | [email protected] |
+| moving_average | Dheeraj R. Reddy | [email protected] |
 
 ## Components
 | Submodule | Optimizer | Reference |
 |:----------------------- |:---------------------- |:---------|
 | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
+| moving_average | MovingAverage | |
 
 
 ## Contribution Guidelines
```

tensorflow_addons/optimizers/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -19,3 +19,4 @@
 from __future__ import print_function
 
 from tensorflow_addons.optimizers.lazy_adam import LazyAdam
+from tensorflow_addons.optimizers.moving_average import MovingAverage
```
tensorflow_addons/optimizers/moving_average.py

Lines changed: 134 additions & 0 deletions

````python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
class MovingAverage(tf.keras.optimizers.Optimizer):
    """Optimizer that computes a moving average of the variables.

    Empirically it has been found that using the moving average of the trained
    parameters of a deep network is better than using its trained parameters
    directly. This optimizer allows you to compute this moving average and swap
    the variables at save time so that any code outside of the training loop
    will by default use the averaged values instead of the original ones.

    Example of usage:

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate)
    opt = tfa.optimizers.MovingAverage(opt)
    ```
    """

    def __init__(self,
                 optimizer,
                 average_decay=0.1,
                 num_updates=None,
                 sequential_update=True,
                 name="MovingAverage",
                 **kwargs):
        super(MovingAverage, self).__init__(name, **kwargs)

        if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
            raise TypeError(
                "optimizer is not an object of tf.keras.optimizers.Optimizer")

        if num_updates is not None and not isinstance(num_updates, int):
            raise TypeError("num_updates must be None or of integer type")

        if not isinstance(sequential_update, bool):
            raise TypeError("sequential_update must be of bool type")

        self._optimizer = optimizer

        with tf.name_scope(name):
            self._ema = tf.train.ExponentialMovingAverage(
                average_decay, num_updates=num_updates)

        self._set_hyper("average_decay", average_decay)
        self._num_updates = num_updates
        self._sequential_update = sequential_update
        self._init = True

    def apply_gradients(self, grads_and_vars, name=None):
        var_list = [v for (_, v) in grads_and_vars]

        if tf.executing_eagerly() and self._init:
            # Ensure that var_list is registered with the EMA on the first
            # eager call.
            self._ema.apply(var_list)
            self._init = False

        train_op = self._optimizer.apply_gradients(grads_and_vars, name=name)

        if self._sequential_update:
            with tf.control_dependencies([train_op]):
                ma_op = self._ema.apply(var_list)
        else:
            ma_op = self._ema.apply(var_list)

        return tf.group(train_op, ma_op, name="train_with_avg")

    def get_config(self):
        config = {
            'average_decay': self._serialize_hyperparameter('average_decay'),
            'num_updates': self._num_updates,
            'sequential_update': self._sequential_update
        }
        base_config = self._optimizer.get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def assign_average_vars(self, var_list):
        """Update variables in var_list with the running mean of the variables.

        Example:
        ```python
        model = tf.keras.Sequential([...])
        opt = tfa.optimizers.MovingAverage(
            tf.keras.optimizers.SGD(lr=2.0), 0.5)

        model.compile(opt, ...)
        model.fit(x, y, ...)

        # Update the weights to their mean before saving
        opt.assign_average_vars(model.variables)

        model.save('model.h5')
        ```
        """
        assign = tf.group([v.assign(self._ema.average(v)) for v in var_list])
        return assign

    @property
    def weights(self):
        return self._optimizer.weights

    def _resource_apply_dense(self, grad, var):
        return self._optimizer._resource_apply_dense(grad, var)  # pylint: disable=protected-access

    def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
        return self._optimizer._resource_apply_sparse_duplicate_indices(  # pylint: disable=protected-access
            grad, var, indices)

    def _resource_apply_sparse(self, grad, var, indices):
        return self._optimizer._resource_apply_sparse(grad, var, indices)  # pylint: disable=protected-access
````
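
For intuition about the numbers asserted in the tests below: `tf.train.ExponentialMovingAverage` maintains, for each variable, a shadow copy updated as `shadow = decay * shadow + (1 - decay) * var`. A minimal pure-Python check of that recurrence, assuming the shadow starts at the variable's initial value (an illustration, not part of the commit):

```python
# Reproduce test_run's expected values for the first element of var0 = [1.0, 2.0].
decay, lr, grad = 0.5, 2.0, 0.1
var, shadow = 1.0, 1.0
for _ in range(2):                               # two apply_gradients() calls
    var -= lr * grad                             # SGD step: 1.0 -> 0.8 -> 0.6
    shadow = decay * shadow + (1 - decay) * var  # EMA:      1.0 -> 0.9 -> 0.75
print(var, shadow)  # 0.6 0.75 -- matching the [0.6, ...] and [0.75, ...] asserts
```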
tensorflow_addons/optimizers/moving_average_test.py

Lines changed: 134 additions & 0 deletions

```python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MovingAverage optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow_addons.optimizers import MovingAverage
from tensorflow_addons.utils import test_utils


class MovingAverageTest(tf.test.TestCase):
    @test_utils.run_in_graph_and_eager_modes
    def test_run(self):
        for sequential_update in [True, False]:
            var0 = tf.Variable([1.0, 2.0])
            var1 = tf.Variable([3.0, 4.0])

            grads0 = tf.constant([0.1, 0.1])
            grads1 = tf.constant([0.01, 0.01])

            grads_and_vars = list(zip([grads0, grads1], [var0, var1]))

            opt = MovingAverage(
                tf.keras.optimizers.SGD(lr=2.0),
                average_decay=0.5,
                sequential_update=sequential_update)

            if not tf.executing_eagerly():
                update = opt.apply_gradients(grads_and_vars)
                self.evaluate(tf.compat.v1.global_variables_initializer())
                self.evaluate(update)
                self.evaluate(update)
            else:
                opt.apply_gradients(grads_and_vars)
                opt.apply_gradients(grads_and_vars)

            self.assertAllClose(var0.read_value(), [0.6, 1.6])
            self.assertAllClose(var1.read_value(), [2.96, 3.96])

            ema_var0 = opt._ema.average(var0)  # pylint: disable=protected-access
            ema_var1 = opt._ema.average(var1)  # pylint: disable=protected-access

            if sequential_update:
                self.assertAllClose(ema_var0.read_value(), [0.75, 1.75])
                self.assertAllClose(ema_var1.read_value(), [2.975, 3.975])

            assign = opt.assign_average_vars([var0, var1])
            self.evaluate(assign)

            if sequential_update:
                self.assertAllClose(var0.read_value(), [0.75, 1.75])
                self.assertAllClose(var1.read_value(), [2.975, 3.975])

            perturb = tf.group([
                var0.assign_add([1.0, 1.0]),
                var1.assign_add([2.0, 2.0]),
                ema_var0.assign_add([3.0, 3.0]),
                ema_var1.assign_add([4.0, 4.0])
            ])
            self.evaluate(perturb)

            if sequential_update:
                self.assertAllClose(var0.read_value(), [1.75, 2.75])
                self.assertAllClose(var1.read_value(), [4.975, 5.975])
                self.assertAllClose(ema_var0.read_value(), [3.75, 4.75])
                self.assertAllClose(ema_var1.read_value(), [6.975, 7.975])

    @test_utils.run_in_graph_and_eager_modes
    def test_opt_failure(self):
        base_opt = None
        for sequential_update in [True, False]:
            with self.assertRaises(TypeError):
                MovingAverage(base_opt, 0.5, sequential_update)

    @test_utils.run_in_graph_and_eager_modes
    def test_model_weights_update(self):
        grad = tf.Variable([[0.1]])
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(
                1,
                kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
                use_bias=False)
        ])
        model.build(input_shape=[1, 1])

        opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), 0.5)
        update = opt.apply_gradients(list(zip([grad], model.variables)))

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(update)
        self.assertAllClose(model.variables[0].read_value(), [[0.8]])

        mean_update = opt.assign_average_vars(model.variables)
        self.evaluate(mean_update)
        self.assertAllClose(model.variables[0].read_value(), [[0.9]])

    @test_utils.run_in_graph_and_eager_modes
    def test_config(self):
        sgd_opt = tf.keras.optimizers.SGD(
            lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
        opt = MovingAverage(
            sgd_opt,
            average_decay=0.5,
            num_updates=100,
            sequential_update=False)
        config = opt.get_config()

        self.assertEqual(config['average_decay'], 0.5)
        self.assertEqual(config['decay'], 0.1)
        self.assertEqual(config['learning_rate'], 2.0)
        self.assertEqual(config['momentum'], 0.3)
        self.assertEqual(config['name'], 'SGD')
        self.assertEqual(config['nesterov'], True)
        self.assertEqual(config['num_updates'], 100)
        self.assertEqual(config['sequential_update'], False)


if __name__ == '__main__':
    tf.test.main()
```
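
One practical note: unlike `swapping_saver` in tf.contrib.opt, `assign_average_vars()` overwrites the raw weights in place. A minimal sketch, assuming a model compiled with a MovingAverage optimizer as in the docstring example, of how calling code might evaluate with the averaged weights and then restore the originals to keep training; the helper is hypothetical and not part of this commit:

```python
import tensorflow as tf

# Hypothetical helper (illustration only): evaluate with the averaged
# weights, then put back the raw (non-averaged) weights.
def evaluate_with_averages(model, opt, x_test, y_test):
    backup = [tf.identity(v) for v in model.variables]  # snapshot raw weights
    opt.assign_average_vars(model.variables)            # swap in the averages
    results = model.evaluate(x_test, y_test)
    for var, value in zip(model.variables, backup):     # restore raw weights
        var.assign(value)
    return results
```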
