Commit 20bb03c

Implement MovingAverage optimizer
* Port MovingAverageOptimizer from tf.contrib.opt
* Inherits from the base Keras optimizer_v2
* `swapping_saver` replaced with `assign_average_vars`
* Update test cases for TF 2.x
* Update docs
1 parent 2b8e0af commit 20bb03c
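
The headline API change is the `swapping_saver` bullet: instead of a saver that swaps weights at checkpoint time, the averaged values are copied into the model explicitly via `assign_average_vars`. A minimal sketch of the resulting TF2-style workflow, assuming the wrapper behaves as a drop-in Keras optimizer (mirroring the docstrings in the diff below; the toy data and model are placeholders, not part of the commit):

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# Hypothetical toy data, only to make the sketch runnable.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

# Wrap any Keras optimizer; MovingAverage applies the inner optimizer's
# update, then refreshes an exponential moving average of the variables.
opt = tfa.optimizers.MovingAverage(
    tf.keras.optimizers.SGD(learning_rate=0.1), average_decay=0.5)

model.compile(optimizer=opt, loss="mse")
model.fit(x, y, epochs=2, verbose=0)

# Where tf.contrib.opt users called swapping_saver, the TF2 flow is an
# explicit copy of the averages into the model weights before saving.
opt.assign_average_vars(model.variables)
model.save("model.h5")
```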

File tree

5 files changed: +249 −0 lines changed

tensorflow_addons/optimizers/BUILD
tensorflow_addons/optimizers/README.md
tensorflow_addons/optimizers/__init__.py
tensorflow_addons/optimizers/moving_average.py
tensorflow_addons/optimizers/moving_average_test.py

tensorflow_addons/optimizers/BUILD

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ py_library(
     srcs = [
         "__init__.py",
         "lazy_adam.py",
+        "moving_average.py",
     ],
     srcs_version = "PY2AND3",
     deps = [

tensorflow_addons/optimizers/README.md

Lines changed: 2 additions & 0 deletions
@@ -4,11 +4,13 @@
 | Submodule | Maintainers | Contact Info |
 |:---------- |:------------- |:--------------|
 | lazy_adam | SIG-Addons | [email protected] |
+| moving_average | Dheeraj R. Reddy | [email protected] |

 ## Components
 | Submodule | Optimizer | Reference |
 |:----------------------- |:---------------------- |:---------|
 | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
+| moving_average | MovingAverage | |


 ## Contribution Guidelines

tensorflow_addons/optimizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,3 +19,4 @@
 from __future__ import print_function

 from tensorflow_addons.optimizers.lazy_adam import LazyAdam
+from tensorflow_addons.optimizers.moving_average import MovingAverage
tensorflow_addons/optimizers/moving_average.py

Lines changed: 127 additions & 0 deletions (new file)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
class MovingAverage(tf.keras.optimizers.Optimizer):
    """Optimizer that computes a moving average of the variables.

    Empirically it has been found that using the moving average of the trained
    parameters of a deep network is better than using its trained parameters
    directly. This optimizer allows you to compute this moving average and
    swap the variables at save time so that any code outside of the training
    loop will by default use the averaged values instead of the original ones.

    Example of usage:

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate)
    opt = tfa.optimizers.MovingAverage(opt)
    ```
    """

    def __init__(self,
                 optimizer,
                 average_decay=0.1,
                 num_updates=None,
                 seq_update=True,
                 name="MovingAverage",
                 **kwargs):
        super(MovingAverage, self).__init__(name, **kwargs)

        if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
            raise TypeError(
                "optimizer is not an object of tf.keras.optimizers.Optimizer")

        self._optimizer = optimizer

        with tf.keras.backend.name_scope(self.__class__.__name__):
            self._ema = tf.train.ExponentialMovingAverage(
                average_decay, num_updates=num_updates)

        self._average_decay = average_decay
        self._num_updates = num_updates
        self._seq_update = seq_update

    def _create_slots(self, var_list):
        self._optimizer._create_slots(var_list)  # pylint: disable=protected-access

    def _resource_apply_dense(self, grad, var):
        return self._optimizer._resource_apply_dense(grad, var)  # pylint: disable=protected-access

    def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
        return self._optimizer._resource_apply_sparse_duplicate_indices(  # pylint: disable=protected-access
            grad, var, indices)

    def _resource_apply_sparse(self, grad, var, indices):
        return self._optimizer._resource_apply_sparse(grad, var, indices)  # pylint: disable=protected-access

    def apply_gradients(self, grads_and_vars, name=None):
        train_op = self._optimizer.apply_gradients(grads_and_vars, name=name)
        var_list = [v for (_, v) in grads_and_vars]

        if self._seq_update:
            # Sequence the moving-average update after the training op.
            with tf.control_dependencies([train_op]):
                ma_op = self._ema.apply(var_list)
        else:
            ma_op = self._ema.apply(var_list)

        return tf.group(train_op, ma_op, name="train_with_avg")

    def get_config(self):
        config = {
            'average_decay': self._average_decay,
            'num_updates': self._num_updates,
            'seq_update': self._seq_update
        }
        base_config = self._optimizer.get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def assign_average_vars(self, var_list):
        """Update variables in var_list with the running mean of the variables.

        Example:
        ```python
        model = tf.keras.Sequential([...])
        opt = tfa.optimizers.MovingAverage(
            tf.keras.optimizers.SGD(lr=2.0), 0.5)

        model.compile(opt, ...)
        model.fit(x, y, ...)

        # Update the weights to their mean before saving
        opt.assign_average_vars(model.variables)

        model.save('model.h5')
        ```
        """
        assign = tf.group([v.assign(self._ema.average(v)) for v in var_list])
        return assign

    @property
    def weights(self):
        return self._optimizer.weights
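
For reference when reading the tests below: `tf.train.ExponentialMovingAverage` keeps a shadow copy of each variable (initialized to the variable's value when the slot is created) and, on every `apply`, updates it as `shadow = decay * shadow + (1 - decay) * value`; `seq_update` only controls whether that update is sequenced after the training op. A minimal plain-Python sketch reproducing the first numbers `test_run` asserts (values taken from the test, no TF required):

```python
decay = 0.5            # average_decay in the test
lr = 2.0               # SGD learning rate in the test
var0, grad0 = 1.0, 0.1

# One SGD step: 1.0 - 2.0 * 0.1 = 0.8
var0 = var0 - lr * grad0

# One EMA update; the shadow starts at the variable's initial
# value (1.0): 0.5 * 1.0 + 0.5 * 0.8 = 0.9
shadow0 = decay * 1.0 + (1 - decay) * var0

print(var0, shadow0)  # 0.8 0.9 -- matches the test's assertAllClose values
```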
tensorflow_addons/optimizers/moving_average_test.py

Lines changed: 118 additions & 0 deletions (new file)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MovingAverage optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

import tensorflow as tf

from tensorflow_addons.optimizers import moving_average
from tensorflow_addons.utils import test_utils


class MovingAverageTest(tf.test.TestCase):
    @test_utils.run_deprecated_v1
    def test_run(self):
        for seq_update in [True, False]:
            orig_var0 = [1.0, 2.0]
            orig_var1 = [3.0, 4.0]

            var0 = tf.Variable(orig_var0)
            var1 = tf.Variable(orig_var1)

            grads0 = tf.constant([0.1, 0.1])
            grads1 = tf.constant([0.01, 0.01])

            opt = moving_average.MovingAverage(
                tf.keras.optimizers.SGD(lr=2.0),
                average_decay=0.5,
                seq_update=seq_update)

            update = opt.apply_gradients(
                list(six.moves.zip([grads0, grads1], [var0, var1])))

            ema_var0 = opt._ema.average(var0)  # pylint: disable=protected-access
            ema_var1 = opt._ema.average(var1)  # pylint: disable=protected-access

            self.evaluate(tf.compat.v1.global_variables_initializer())
            self.evaluate(update)

            self.assertAllClose(var0.read_value(), [0.8, 1.8])
            self.assertAllClose(var1.read_value(), [2.98, 3.98])

            if seq_update:
                self.assertAllClose(ema_var0.read_value(), [0.9, 1.9])
                self.assertAllClose(ema_var1.read_value(), [2.99, 3.99])

            assign = opt.assign_average_vars([var0, var1])
            self.evaluate(assign)

            if seq_update:
                self.assertAllClose(self.evaluate(var0), [0.9, 1.9])
                self.assertAllClose(self.evaluate(var1), [2.99, 3.99])

            perturb = tf.group([
                var0.assign_add([1.0, 1.0]),
                var1.assign_add([2.0, 2.0]),
                ema_var0.assign_add([3.0, 3.0]),
                ema_var1.assign_add([4.0, 4.0])
            ])
            self.evaluate(perturb)

            if seq_update:
                self.assertAllClose(self.evaluate(var0), [1.9, 2.9])
                self.assertAllClose(self.evaluate(var1), [4.99, 5.99])
                self.assertAllClose(self.evaluate(ema_var0), [3.9, 4.9])
                self.assertAllClose(self.evaluate(ema_var1), [6.99, 7.99])

    @test_utils.run_in_graph_and_eager_modes
    def test_opt_failure(self):
        base_opt = None
        for seq_update in [True, False]:
            with self.assertRaises(TypeError):
                moving_average.MovingAverage(base_opt, 0.5, seq_update=seq_update)

    @test_utils.run_deprecated_v1
    def test_model_weights_update(self):
        grad = tf.Variable([[0.1]])
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(
                1,
                kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
                use_bias=False)
        ])

        model.build(input_shape=[1, 1])

        opt = moving_average.MovingAverage(
            tf.keras.optimizers.SGD(lr=2.0), 0.5)

        update = opt.apply_gradients(
            list(six.moves.zip([grad], model.variables)))

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(update)
        self.assertAllClose(model.variables[0].read_value(), [[0.8]])

        mean_update = opt.assign_average_vars(model.variables)
        self.evaluate(mean_update)
        self.assertAllClose(model.variables[0].read_value(), [[0.9]])


if __name__ == '__main__':
    tf.test.main()
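
One aspect the tests do not exercise is serialization: `get_config` merges the wrapped optimizer's config with the wrapper's own keys, so a saved config carries both sets of hyperparameters. An illustrative sketch; the exact SGD keys shown in the comment are assumptions, not output from this commit:

```python
import tensorflow as tf
import tensorflow_addons as tfa

opt = tfa.optimizers.MovingAverage(
    tf.keras.optimizers.SGD(learning_rate=2.0), average_decay=0.5)

# Roughly: SGD's config plus the MovingAverage-specific entries, e.g.
# {'name': 'SGD', 'learning_rate': 2.0, ...,
#  'average_decay': 0.5, 'num_updates': None, 'seq_update': True}
print(opt.get_config())
```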
