Commit 1f14395

Get compatible with optimizer migration in TF 2.11 (#2766)

* Get compatible with optimizer migration in TF 2.11
* Fix comments
* add adamw to the change
* Add type exception
1 parent 9c87ce1 commit 1f14395
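
This commit tracks the Keras optimizer migration: as of TF 2.11, tf.keras.optimizers.Optimizer points at the new (formerly experimental) implementation, and the previous implementation moved to tf.keras.optimizers.legacy. Addons optimizers subclass and wrap the legacy classes, so each touched file gains a version guard. A minimal sketch of the detection used throughout the diff (mirroring constants.py below; the variable name is illustrative):

    import tensorflow as tf

    # True on TF 2.11+, where the default Optimizer is the new implementation
    # and Addons must fall back to tf.keras.optimizers.legacy.
    default_is_new = (
        hasattr(tf.keras.optimizers, "experimental")
        and tf.keras.optimizers.Optimizer.__module__
        == tf.keras.optimizers.experimental.Optimizer.__module__
    )
    print("default optimizer is the new implementation:", default_is_new)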

File tree

8 files changed: +143 −38 lines

tensorflow_addons/optimizers/adabelief.py

Lines changed: 2 additions & 1 deletion

@@ -201,7 +201,8 @@ def _resource_apply_dense(self, grad, var):
         sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)

         m_t = m.assign(
-            beta_1_t * m + (1.0 - beta_1_t) * grad, use_locking=self._use_locking
+            beta_1_t * m + (1.0 - beta_1_t) * grad,
+            use_locking=self._use_locking,
         )
         m_corr_t = m_t / (1.0 - beta_1_power)

tensorflow_addons/optimizers/average_wrapper.py

Lines changed: 18 additions & 7 deletions

@@ -25,19 +25,29 @@
 class AveragedOptimizerWrapper(KerasLegacyOptimizer, metaclass=abc.ABCMeta):
     @typechecked
     def __init__(
-        self, optimizer: types.Optimizer, name: str = "AverageOptimizer", **kwargs
+        self,
+        optimizer: types.Optimizer,
+        name: str = "AverageOptimizer",
+        **kwargs,
     ):
         super().__init__(name, **kwargs)

         if isinstance(optimizer, str):
-            optimizer = tf.keras.optimizers.get(optimizer)
+            if (
+                hasattr(tf.keras.optimizers, "legacy")
+                and KerasLegacyOptimizer == tf.keras.optimizers.legacy.Optimizer
+            ):
+                optimizer = tf.keras.optimizers.get(
+                    optimizer, use_legacy_optimizer=True
+                )
+            else:
+                optimizer = tf.keras.optimizers.get(optimizer)

-        if not isinstance(
-            optimizer, (tf.keras.optimizers.Optimizer, KerasLegacyOptimizer)
-        ):
+        if not isinstance(optimizer, KerasLegacyOptimizer):
             raise TypeError(
                 "optimizer is not an object of tf.keras.optimizers.Optimizer "
-                "or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.9.0)."
+                "or tf.keras.optimizers.legacy.Optimizer "
+                "(if you have tf version >= 2.11.0)."
             )

         self._optimizer = optimizer
@@ -135,7 +145,8 @@ def assign_average_vars(self, var_list):
             try:
                 assign_ops.append(
                     var.assign(
-                        self.get_slot(var, "average"), use_locking=self._use_locking
+                        self.get_slot(var, "average"),
+                        use_locking=self._use_locking,
                     )
                 )
             except Exception as e:
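
When the wrapped optimizer arrives as a string such as "sgd", tf.keras.optimizers.get on TF 2.11+ would resolve it to the new optimizer class, which the wrapper can no longer accept; the use_legacy_optimizer=True keyword (introduced with the 2.11 migration, hence the guard above) requests the legacy implementation instead. A small usage sketch, assuming TF 2.11+:

    import tensorflow as tf

    # Assuming TF 2.11+, where get() accepts use_legacy_optimizer:
    opt = tf.keras.optimizers.get("sgd", use_legacy_optimizer=True)
    assert isinstance(opt, tf.keras.optimizers.legacy.Optimizer)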

tensorflow_addons/optimizers/constants.py

Lines changed: 7 additions & 2 deletions

@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import importlib
 import tensorflow as tf

-if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
+if (
+    hasattr(tf.keras.optimizers, "experimental")
+    and tf.keras.optimizers.Optimizer.__module__
+    == tf.keras.optimizers.experimental.Optimizer.__module__
+):
+    # If the default optimizer points to the new Keras optimizer, addon
+    # optimizers should use the legacy path.
     KerasLegacyOptimizer = tf.keras.optimizers.legacy.Optimizer
 else:
     KerasLegacyOptimizer = tf.keras.optimizers.Optimizer
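
The old find_spec check only detected that the legacy namespace exists, which has been true since TF 2.9; the new check asks whether the default Optimizer comes from the same module as the experimental one, which is precisely the TF 2.11 condition under which Addons must pin itself to the legacy base. A sketch of how Addons code consumes the alias (MyOptimizer is illustrative, and constants is an internal module):

    from tensorflow_addons.optimizers.constants import KerasLegacyOptimizer

    class MyOptimizer(KerasLegacyOptimizer):
        # Resolves to tf.keras.optimizers.legacy.Optimizer on TF 2.11+
        # and to tf.keras.optimizers.Optimizer on older releases.
        ...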

tensorflow_addons/optimizers/lookahead.py

Lines changed: 15 additions & 7 deletions

@@ -71,13 +71,19 @@ def __init__(
         super().__init__(name, **kwargs)

         if isinstance(optimizer, str):
-            optimizer = tf.keras.optimizers.get(optimizer)
-        if not isinstance(
-            optimizer, (tf.keras.optimizers.Optimizer, KerasLegacyOptimizer)
-        ):
+            if (
+                hasattr(tf.keras.optimizers, "legacy")
+                and KerasLegacyOptimizer == tf.keras.optimizers.legacy.Optimizer
+            ):
+                optimizer = tf.keras.optimizers.get(
+                    optimizer, use_legacy_optimizer=True
+                )
+            else:
+                optimizer = tf.keras.optimizers.get(optimizer)
+        if not isinstance(optimizer, KerasLegacyOptimizer):
             raise TypeError(
                 "optimizer is not an object of tf.keras.optimizers.Optimizer "
-                "or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.9.0)."
+                "or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.11.0)."
             )

         self._optimizer = optimizer
@@ -119,10 +125,12 @@ def _look_ahead_op(self, var):
         )
         with tf.control_dependencies([step_back]):
             slow_update = slow_var.assign(
-                tf.where(sync_cond, step_back, slow_var), use_locking=self._use_locking
+                tf.where(sync_cond, step_back, slow_var),
+                use_locking=self._use_locking,
             )
             var_update = var.assign(
-                tf.where(sync_cond, step_back, var), use_locking=self._use_locking
+                tf.where(sync_cond, step_back, var),
+                use_locking=self._use_locking,
             )
         return tf.group(slow_update, var_update)

tensorflow_addons/optimizers/tests/moving_average_test.py

Lines changed: 65 additions & 14 deletions

@@ -31,7 +31,10 @@ def test_run():

     grads_and_vars = list(zip([grads0, grads1], [var0, var1]))

-    opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
+    if hasattr(tf.keras.optimizers, "legacy"):
+        opt = MovingAverage(tf.keras.optimizers.legacy.SGD(lr=2.0), average_decay=0.5)
+    else:
+        opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)

     opt.apply_gradients(grads_and_vars)
     opt.apply_gradients(grads_and_vars)
@@ -95,7 +98,10 @@ def test_model_weights_update():
     )
     model.build(input_shape=[1, 1])

-    opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
+    if hasattr(tf.keras.optimizers, "legacy"):
+        opt = MovingAverage(tf.keras.optimizers.legacy.SGD(lr=2.0), average_decay=0.5)
+    else:
+        opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
     _ = opt.apply_gradients(list(zip([grad], model.variables)))
     np.testing.assert_allclose(model.variables[0].read_value(), [[0.8]])
     _ = opt.assign_average_vars(model.variables)
@@ -115,8 +121,10 @@ def test_model_dynamic_lr():
         ]
     )
     model.build(input_shape=[1, 1])
-
-    opt = MovingAverage(tf.keras.optimizers.SGD(lr=1e-3), average_decay=0.5)
+    if hasattr(tf.keras.optimizers, "legacy"):
+        opt = MovingAverage(tf.keras.optimizers.legacy.SGD(lr=1e-3), average_decay=0.5)
+    else:
+        opt = MovingAverage(tf.keras.optimizers.SGD(lr=1e-3), average_decay=0.5)
     _ = opt.apply_gradients(list(zip([grad], model.variables)))
     np.testing.assert_allclose(opt.lr.read_value(), 1e-3)
     opt.lr = 1e-4
@@ -129,9 +137,20 @@ def test_optimizer_string():


 def test_config():
-    sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
+    if hasattr(tf.keras.optimizers, "legacy"):
+        sgd_opt = tf.keras.optimizers.legacy.SGD(
+            lr=2.0, nesterov=True, momentum=0.3, decay=0.1
+        )
+    else:
+        sgd_opt = tf.keras.optimizers.SGD(
+            lr=2.0, nesterov=True, momentum=0.3, decay=0.1
+        )
     opt = MovingAverage(
-        sgd_opt, average_decay=0.5, num_updates=None, start_step=5, dynamic_decay=True
+        sgd_opt,
+        average_decay=0.5,
+        num_updates=None,
+        start_step=5,
+        dynamic_decay=True,
     )
     config = opt.get_config()

@@ -177,9 +196,20 @@ def test_fit_simple_linear_model():


 def test_serialization():
-    sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
+    if hasattr(tf.keras.optimizers, "legacy"):
+        sgd_opt = tf.keras.optimizers.legacy.SGD(
+            lr=2.0, nesterov=True, momentum=0.3, decay=0.1
+        )
+    else:
+        sgd_opt = tf.keras.optimizers.SGD(
+            lr=2.0, nesterov=True, momentum=0.3, decay=0.1
+        )
     optimizer = MovingAverage(
-        sgd_opt, average_decay=0.5, num_updates=None, start_step=5, dynamic_decay=True
+        sgd_opt,
+        average_decay=0.5,
+        num_updates=None,
+        start_step=5,
+        dynamic_decay=True,
     )
     config = tf.keras.optimizers.serialize(optimizer)
     new_optimizer = tf.keras.optimizers.deserialize(config)
@@ -215,9 +245,18 @@ def test_dynamic_decay():
     grads0 = tf.constant([0.1, 0.1])
     grads_and_vars = [(grads0, var0)]

-    opt = MovingAverage(
-        tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5, dynamic_decay=True
-    )
+    if hasattr(tf.keras.optimizers, "legacy"):
+        opt = MovingAverage(
+            tf.keras.optimizers.legacy.SGD(lr=2.0),
+            average_decay=0.5,
+            dynamic_decay=True,
+        )
+    else:
+        opt = MovingAverage(
+            tf.keras.optimizers.SGD(lr=2.0),
+            average_decay=0.5,
+            dynamic_decay=True,
+        )

     opt.apply_gradients(grads_and_vars)
     opt.apply_gradients(grads_and_vars)
@@ -235,7 +274,12 @@ def test_swap_weight_no_shadow_copy(device):
     var = tf.Variable([1.0, 2.0])
     grads = tf.constant([0.1, 0.1])

-    opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
+    if hasattr(tf.keras.optimizers, "legacy"):
+        opt = MovingAverage(
+            tf.keras.optimizers.legacy.SGD(lr=2.0), average_decay=0.5
+        )
+    else:
+        opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)

     @tf.function
     def apply_gradients():
@@ -267,7 +311,12 @@ def test_swap_weights(device):
     var = tf.Variable([1.0, 2.0])
     grads = tf.constant([0.1, 0.1])

-    opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
+    if hasattr(tf.keras.optimizers, "legacy"):
+        opt = MovingAverage(
+            tf.keras.optimizers.legacy.SGD(lr=2.0), average_decay=0.5
+        )
+    else:
+        opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)

     @tf.function
     def apply_gradients():
@@ -314,7 +363,9 @@ def test_no_average_slot():
     # They are returned when using model.variables
     # but it's unable to assign average slot to them.
     vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
-        max_tokens=max_features, output_mode="int", output_sequence_length=max_len
+        max_tokens=max_features,
+        output_mode="int",
+        output_sequence_length=max_len,
     )

     vectorize_layer.adapt(["foo", "bar", "baz"])
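
Every test above repeats the same hasattr branch to pick an SGD. A hypothetical module-level helper (not part of this commit) could express the pattern once; the same would apply to the SWA tests below:

    import tensorflow as tf

    def _legacy_sgd(**kwargs):
        # Hypothetical test helper: prefer the legacy SGD when available.
        if hasattr(tf.keras.optimizers, "legacy"):
            return tf.keras.optimizers.legacy.SGD(**kwargs)
        return tf.keras.optimizers.SGD(**kwargs)

    # e.g. opt = MovingAverage(_legacy_sgd(lr=2.0), average_decay=0.5)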

tensorflow_addons/optimizers/tests/stochastic_weight_averaging_test.py

Lines changed: 12 additions & 3 deletions

@@ -28,7 +28,10 @@
 def test_averaging():
     start_averaging = 0
     average_period = 1
-    sgd = tf.keras.optimizers.SGD(lr=1.0)
+    if hasattr(tf.keras.optimizers, "legacy"):
+        sgd = tf.keras.optimizers.legacy.SGD(lr=1.0)
+    else:
+        sgd = tf.keras.optimizers.SGD(lr=1.0)
     optimizer = SWA(sgd, start_averaging, average_period)

     val_0 = [1.0, 1.0]
@@ -81,7 +84,10 @@ def test_assign_batchnorm():
     model.add(tf.keras.layers.BatchNormalization())
     model.add(tf.keras.layers.Dense(1))

-    opt = SWA(tf.keras.optimizers.SGD())
+    if hasattr(tf.keras.optimizers, "legacy"):
+        opt = SWA(tf.keras.optimizers.legacy.SGD())
+    else:
+        opt = SWA(tf.keras.optimizers.SGD())
     model.compile(optimizer=opt, loss="mean_squared_error")
     model.fit(x, y, epochs=1)

@@ -118,7 +124,10 @@ def test_fit_simple_linear_model():
 def test_serialization():
     start_averaging = 0
     average_period = 1
-    sgd = tf.keras.optimizers.SGD(lr=1.0)
+    if hasattr(tf.keras.optimizers, "legacy"):
+        sgd = tf.keras.optimizers.legacy.SGD(lr=1.0)
+    else:
+        sgd = tf.keras.optimizers.SGD(lr=1.0)
     optimizer = SWA(sgd, start_averaging, average_period)
     config = tf.keras.optimizers.serialize(optimizer)
     new_optimizer = tf.keras.optimizers.deserialize(config)

tensorflow_addons/optimizers/weight_decay_optimizers.py

Lines changed: 18 additions & 3 deletions

@@ -127,7 +127,13 @@ def from_config(cls, config, custom_objects=None):
         return cls(**config)

     def minimize(
-        self, loss, var_list, grad_loss=None, name=None, decay_var_list=None, tape=None
+        self,
+        loss,
+        var_list,
+        grad_loss=None,
+        name=None,
+        decay_var_list=None,
+        tape=None,
     ):
         """Minimize `loss` by updating `var_list`.

@@ -354,7 +360,10 @@ class OptimizerWithDecoupledWeightDecay(

     @typechecked
     def __init__(
-        self, weight_decay: Union[FloatTensorLike, Callable], *args, **kwargs
+        self,
+        weight_decay: Union[FloatTensorLike, Callable],
+        *args,
+        **kwargs,
     ):
         # super delegation is necessary here
         super().__init__(weight_decay, *args, **kwargs)
@@ -441,8 +450,14 @@ def __init__(
         )


+if hasattr(tf.keras.optimizers, "legacy"):
+    ADAM_CLASS = tf.keras.optimizers.legacy.Adam
+else:
+    ADAM_CLASS = tf.keras.optimizers.Adam
+
+
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
+class AdamW(DecoupledWeightDecayExtension, ADAM_CLASS):
     """Optimizer that implements the Adam algorithm with weight decay.

     This is an implementation of the AdamW optimizer described in "Decoupled
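
Because ADAM_CLASS is resolved once at import time, AdamW keeps its mixin layout: DecoupledWeightDecayExtension precedes the Adam base in the MRO, so its weight-decay hooks still run first whichever Adam is selected. Caller-facing usage should be unchanged; a quick sanity sketch:

    import tensorflow_addons as tfa

    # weight_decay remains the first, required argument contributed by
    # DecoupledWeightDecayExtension; the Adam base is picked at import time.
    opt = tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3)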

tools/testing/source_code_test.py

Lines changed: 6 additions & 1 deletion

@@ -41,6 +41,7 @@ def test_api_typed():
     # Files within this list will be exempt from verification.
     exception_list = [
         tfa.rnn.PeepholeLSTMCell,
+        tf.keras.optimizers.Optimizer,
     ]
     if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
         exception_list.append(tf.keras.optimizers.legacy.Optimizer)
@@ -50,7 +51,10 @@ def test_api_typed():
         "https://github.com/tensorflow/addons/blob/master/CONTRIBUTING.md#about-type-hints"
     )
     ensure_api_is_typed(
-        modules_list, exception_list, init_only=True, additional_message=help_message
+        modules_list,
+        exception_list,
+        init_only=True,
+        additional_message=help_message,
     )

@@ -151,6 +155,7 @@ def test_no_experimental_api():
     # TODO: remove all elements of the list and remove the allowlist
     # This allowlist should not grow. Do not add elements to this list.
     allowlist = [
+        "tensorflow_addons/optimizers/constants.py",
         "tensorflow_addons/optimizers/weight_decay_optimizers.py",
         "tensorflow_addons/layers/max_unpooling_2d.py",
         "tensorflow_addons/image/dense_image_warp.py",
