
Commit 2c4df23

Merge pull request tensorflow#11 from facaiy/ENH/move_LazyAdamOptimier
Implement LazyAdamOptimizer
2 parents 2c4a615 + 7cf827f commit 2c4df23

File tree: 3 files changed, +453 -0 lines

tensorflow_addons/optimizers/BUILD

Lines changed: 24 additions & 0 deletions

@@ -1,3 +1,27 @@
 licenses(["notice"])  # Apache 2.0

 package(default_visibility = ["//visibility:public"])
+
+py_library(
+    name = "optimizers_py",
+    srcs = [
+        "__init__.py",
+        "python/__init__.py",
+        "python/lazy_adam_optimizer.py",
+    ],
+    srcs_version = "PY2AND3",
+)
+
+
+py_test(
+    name = "lazy_adam_optimizer_test",
+    size = "small",
+    srcs = [
+        "python/lazy_adam_optimizer_test.py"
+    ],
+    main = "python/lazy_adam_optimizer_test.py",
+    deps = [
+        ":optimizers_py",
+    ],
+    srcs_version = "PY2AND3",
+)
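Assuming a standard Bazel setup for this repository, the py_test target added above would typically be run with a command along the lines of: bazel test //tensorflow_addons/optimizers:lazy_adam_optimizer_test. The target label is inferred here from the package path and test name; it is not stated in the diff itself.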
tensorflow_addons/optimizers/python/lazy_adam_optimizer.py

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variant of the Adam optimizer that handles sparse updates more efficiently.

Compared with the original Adam optimizer, the one in this file can
provide a large improvement in model training throughput for some
applications. However, it provides slightly different semantics than the
original Adam algorithm, and may lead to different empirical results.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops


class LazyAdamOptimizer(adam.Adam):
  """Variant of the Adam optimizer that handles sparse updates more efficiently.

  The original Adam algorithm maintains two moving-average accumulators for
  each trainable variable; the accumulators are updated at every step.
  This class provides lazier handling of gradient updates for sparse variables.
  It only updates moving-average accumulators for sparse variable indices that
  appear in the current batch, rather than updating the accumulators for all
  indices. Compared with the original Adam optimizer, it can provide large
  improvements in model training throughput for some applications. However, it
  provides slightly different semantics than the original Adam algorithm, and
  may lead to different empirical results.

  Note, amsgrad is currently not supported and the argument can only be False.
  """

  def _resource_apply_sparse(self, grad, var, indices):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    beta_1_t = self._get_hyper('beta_1', var_dtype)
    beta_2_t = self._get_hyper('beta_2', var_dtype)
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    epsilon_t = self._get_hyper('epsilon', var_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))

    # \\(m := beta1 * m + (1 - beta1) * g_t\\)
    m = self.get_slot(var, "m")
    m_t_slice = beta_1_t * array_ops.gather(
        m, indices) + (1 - beta_1_t) * grad
    m_update_op = resource_variable_ops.resource_scatter_update(
        m.handle, indices, m_t_slice)

    # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
    v = self.get_slot(var, "v")
    v_t_slice = (beta_2_t * array_ops.gather(v, indices) +
                 (1 - beta_2_t) * math_ops.square(grad))
    v_update_op = resource_variable_ops.resource_scatter_update(
        v.handle, indices, v_t_slice)

    # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
    var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
    var_update_op = resource_variable_ops.resource_scatter_sub(
        var.handle, indices, var_slice)

    return control_flow_ops.group(
        *[var_update_op, m_update_op, v_update_op])
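For context, here is a minimal usage sketch that is not part of this commit: since LazyAdamOptimizer subclasses the Keras optimizer_v2 Adam, it can be passed wherever a Keras optimizer instance is accepted, and its _resource_apply_sparse path is exercised when gradients arrive as IndexedSlices, e.g. from an embedding lookup. The import path below assumes the module layout added in this commit; the model and hyperparameters are illustrative only.

# Illustrative sketch only -- not part of the commit. The import path assumes
# the module layout added here; the model and hyperparameters are arbitrary.
import tensorflow as tf

from tensorflow_addons.optimizers.python.lazy_adam_optimizer import (
    LazyAdamOptimizer)

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=10000, output_dim=16),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
# Gradients for the Embedding variable are IndexedSlices, so only the rows
# seen in a given batch have their m/v accumulator slots read and written.
model.compile(optimizer=LazyAdamOptimizer(learning_rate=1e-3),
              loss="binary_crossentropy")

Because only the indexed rows of the m and v slots are gathered and scattered, large embedding tables avoid a full accumulator update on every step; the trade-off, as the docstring notes, is slightly different semantics from the original Adam algorithm.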
