
Commit 98a42c6

Move optimizer_test_base into weight_decay_test for now. In the optimizer tests, optimizer params are now passed as keywords instead of a dict. Fix code examples in docstrings to support TF 2.0, and fix naming errors and line lengths.
1 parent 9ebc02b commit 98a42c6
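To illustrate the test change described in the commit message ("optimizer params are now keywords instead of a dict"), here is a minimal hedged sketch of the before/after pattern. The helper name `build_optimizer` and its signature are hypothetical and not taken from this commit:

```python
# Hedged sketch (names are hypothetical, not from this commit): in the tests,
# optimizer hyperparameters used to travel as a dict and are now plain keywords.

def build_optimizer(optimizer_cls, **hyperparams):
    """After this commit: hyperparameters are forwarded as keyword arguments."""
    return optimizer_cls(**hyperparams)

# Before, a test might have done:
#   build_optimizer(SGDW, {"learning_rate": 0.1, "weight_decay": 1e-4})
# Now it reads:
#   build_optimizer(SGDW, learning_rate=0.1, weight_decay=1e-4)
```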


4 files changed: 263 additions, 86 deletions

tensorflow_addons/optimizers/BUILD

Lines changed: 0 additions & 2 deletions
@@ -28,13 +28,11 @@ py_test(
     ],
 )
 
-
 py_test(
     name = "weight_decay_optimizers_test",
     size = "small",
     srcs = [
         "weight_decay_optimizers_test.py",
-        "optimizer_test_base.py",
     ],
     main = "weight_decay_optimizers_test.py",
     srcs_version = "PY2AND3",

tensorflow_addons/optimizers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from __future__ import print_function
 
 from tensorflow_addons.optimizers.lazy_adam import LazyAdam
-from tensorflow_addons.optimizers.weight_decay_optimizers import AdamWOptimizer
-from tensorflow_addons.optimizers.weight_decay_optimizers import SGDWOptimizer
+from tensorflow_addons.optimizers.weight_decay_optimizers import AdamW
+from tensorflow_addons.optimizers.weight_decay_optimizers import SGDW
 from tensorflow_addons.optimizers.weight_decay_optimizers import (
     extend_with_decoupled_weight_decay)
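Under the renamed exports above, downstream code now imports the optimizers as `AdamW` and `SGDW`. A minimal usage sketch (assuming `tensorflow_addons` is installed and imported as `tfa`):

```python
import tensorflow_addons as tfa

# The classes lose their "Optimizer" suffix in this commit.
adamw = tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3)
sgdw = tfa.optimizers.SGDW(weight_decay=1e-4, learning_rate=1e-1, momentum=0.9)
```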

tensorflow_addons/optimizers/weight_decay_optimizers.py

Lines changed: 93 additions & 45 deletions
@@ -37,35 +37,35 @@ class DecoupledWeightDecayExtension(object):
     optimizers with decoupled weight decay. We explicitly define the two
     examples used in the above paper (SGDW and AdamW), but in general this
     can extend any OptimizerX by using
-    `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`.
+    `extend_with_decoupled_weight_decay(
+        OptimizerX, weight_decay=weight_decay)`.
     In order for it to work, it must be the first class the Optimizer with
     weight decay inherits from, e.g.
 
     ```python
-    class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
+    class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
       def __init__(self, weight_decay, *args, **kwargs):
-        super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs).
+        super(AdamW, self).__init__(weight_decay, *args, **kwargs).
     ```
 
-    Note that this extension decays weights BEFORE applying the update based
+    Note: this extension decays weights BEFORE applying the update based
     on the gradient, i.e. this extension only has the desired behaviour for
     optimizers which do not depend on the value of 'var' in the update step!
 
     Note: when applying a decay to the learning rate, be sure to manually apply
     the decay to the `weight_decay` as well. For example:
 
     ```python
-    schedule = tf.train.piecewise_constant(
-        tf.train.get_global_step(), [10000, 15000], [1e-0, 1e-1, 1e-2])
-    lr = 1e-1 * schedule()
-    wd = lambda: 1e-4 * schedule()
+    step = tf.Variable(0, trainable=False)
+    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
+        [10000, 15000], [1e-0, 1e-1, 1e-2])
+    # lr and wd can be a function or a tensor
+    lr = 1e-1 * schedule(step)
+    wd = lambda: 1e-4 * schedule(step)
 
-      # ...
+    # ...
 
-    optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr,
-                                                  weight_decay=wd,
-                                                  momentum=0.9,
-                                                  use_nesterov=True)
+    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
     ```
     """
@@ -78,10 +78,10 @@ def __init__(self, weight_decay, **kwargs):
            **kwargs: Optional list or tuple or set of `Variable` objects to
                decay.
        """
+        wd = kwargs.pop('weight_decay', weight_decay)
         super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
         self._decay_var_list = None  # is set in minimize or apply_gradients
-        self._set_hyper('weight_decay', kwargs.get('weight_decay',
-                                                   weight_decay))
+        self._set_hyper('weight_decay', wd)
 
     def get_config(self):
         config = super(DecoupledWeightDecayExtension, self).get_config()
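The new `kwargs.pop('weight_decay', weight_decay)` line reads the hyperparameter and simultaneously removes it from `**kwargs` before those are forwarded to the base optimizer's constructor. A tiny standalone sketch of that pop-with-default pattern (the function name is illustrative, not part of the library):

```python
def split_weight_decay(weight_decay, **kwargs):
    # pop() returns kwargs['weight_decay'] if a caller supplied it there,
    # otherwise the positional default, and removes the key either way so the
    # remaining kwargs can be forwarded to a constructor that does not accept it.
    wd = kwargs.pop("weight_decay", weight_decay)
    return wd, kwargs


print(split_weight_decay(1e-4, learning_rate=1e-3))
# -> (0.0001, {'learning_rate': 0.001})
```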
@@ -188,8 +188,8 @@ def extend_with_decoupled_weight_decay(base_optimizer):
     Returns an optimizer class. An instance of the returned class computes the
     update step of `base_optimizer` and additionally decays the weights.
     E.g., the class returned by
-    `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent
-    to `tf.contrib.opt.AdamWOptimizer`.
+    `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is
+    equivalent to `tfa.optimizers.AdamW`.
 
     The API of the new optimizer class slightly differs from the API of the
     base optimizer:
@@ -201,18 +201,35 @@ def extend_with_decoupled_weight_decay(base_optimizer):
     Usage example:
     ```python
     # MyAdamW is a new class
-    MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
+    MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
     # Create a MyAdamW object
     optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
-    sess.run(optimizer.minimize(loss, decay_variables=[var1, var2]))
+    # update var1, var2 but only decay var1
+    optimizer.minimize(loss, var_list=[var1, var2], decay_variables=[var1])
 
-    Note that this extension decays weights BEFORE applying the update based
+    Note: this extension decays weights BEFORE applying the update based
     on the gradient, i.e. this extension only has the desired behaviour for
     optimizers which do not depend on the value of 'var' in the update step!
+
+    Note: when applying a decay to the learning rate, be sure to manually apply
+    the decay to the `weight_decay` as well. For example:
+
+    ```python
+    step = tf.Variable(0, trainable=False)
+    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
+        [10000, 15000], [1e-0, 1e-1, 1e-2])
+    # lr and wd can be a function or a tensor
+    lr = 1e-1 * schedule(step)
+    wd = lambda: 1e-4 * schedule(step)
+
+    # ...
+
+    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
     ```
 
     Args:
-        base_optimizer: An optimizer class that inherits from tf.train.Optimizer.
+        base_optimizer: An optimizer class that inherits from
+            tf.optimizers.Optimizer.
 
     Returns:
         A new optimizer class that inherits from DecoupledWeightDecayExtension
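Expanding on the docstring example shown in the hunk above, here is a hedged end-to-end sketch of how the factory might be used with the TF 2.x eager API. The variables and loss are made up for illustration, and the `decay_variables` keyword follows the docstring in this diff rather than a verified signature:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Build a weight-decayed variant of a stock Keras optimizer.
MyAdamW = tfa.optimizers.extend_with_decoupled_weight_decay(
    tf.keras.optimizers.Adam)
optimizer = MyAdamW(weight_decay=1e-3, learning_rate=1e-3)

# Toy variables and loss, purely illustrative.
var1 = tf.Variable([1.0, 2.0])
var2 = tf.Variable([3.0, 4.0])
loss = lambda: tf.reduce_sum(var1 ** 2) + tf.reduce_sum(var2 ** 2)

# Update both variables, but only apply weight decay to var1
# (keyword name as shown in the docstring above).
optimizer.minimize(loss, var_list=[var1, var2], decay_variables=[var1])
```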
@@ -238,34 +255,49 @@ class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
 
         def __init__(self, weight_decay, *args, **kwargs):
             # super delegation is necessary here
-            # pylint: disable=useless-super-delegation
             super(OptimizerWithDecoupledWeightDecay, self).__init__(
                 weight_decay, *args, **kwargs)
-            # pylint: enable=useless-super-delegation
 
     return OptimizerWithDecoupledWeightDecay
 
 
 @keras_utils.register_keras_custom_object
-class SGDWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
+class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
     """Optimizer that implements the Momentum algorithm with weight_decay.
 
-    This is an implementation of the SGDW optimizer described in "Fixing
-    Weight Decay Regularization in Adam" by Loshchilov & Hutter
+    This is an implementation of the SGDW optimizer described in "Decoupled
+    Weight Decay Regularization" by Loshchilov & Hutter
     (https://arxiv.org/abs/1711.05101)
     ([pdf])(https://arxiv.org/pdf/1711.05101.pdf).
-    It computes the update step of `train.MomentumOptimizer` and additionally
+    It computes the update step of `tf.keras.optimizers.SGD` and additionally
     decays the variable. Note that this is different from adding
     L2 regularization on the variables to the loss. Decoupling the weight decay
     from other hyperparameters (in particular the learning rate) simplifies
     hyperparameter search.
 
     For further information see the documentation of the SGD Optimizer.
 
-    Note that this optimizer can also be instantiated as
+    This optimizer can also be instantiated as
     ```python
-    extend_with_weight_decay(tf.keras.optimizers.SGD,
-                             weight_decay=weight_decay)
+    extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD,
+                                       weight_decay=weight_decay)
+    ```
+
+    Note: when applying a decay to the learning rate, be sure to manually apply
+    the decay to the `weight_decay` as well. For example:
+
+    ```python
+    step = tf.Variable(0, trainable=False)
+    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
+        [10000, 15000], [1e-0, 1e-1, 1e-2])
+    # lr and wd can be a function or a tensor
+    lr = 1e-1 * schedule(step)
+    wd = lambda: 1e-4 * schedule(step)
+
+    # ...
+
+    optimizer = tfa.optimizers.SGDW(
+        learning_rate=lr, weight_decay=wd, momentum=0.9)
     ```
     """
@@ -287,14 +319,14 @@ def __init__(self,
            nesterov: boolean. Whether to apply Nesterov momentum.
            name: Optional name prefix for the operations created when applying
                gradients. Defaults to 'SGD'.
-            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
-                `lr`, `decay`}. `clipnorm` is clip gradients by norm;
-                `clipvalue` is clip gradients by value, `decay` is included for
-                backward compatibility to allow time inverse decay of learning
-                rate. `lr` is included for backward compatibility, recommended
-                to use `learning_rate` instead.
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
+                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
+                norm; `clipvalue` is clip gradients by value, `decay` is
+                included for backward compatibility to allow time inverse decay
+                of learning rate. `lr` is included for backward compatibility,
+                recommended to use `learning_rate` instead.
        """
-        super(SGDWOptimizer, self).__init__(
+        super(SGDW, self).__init__(
            weight_decay,
            learning_rate=learning_rate,
            momentum=momentum,
@@ -304,26 +336,42 @@ def __init__(self,
 
 
 @keras_utils.register_keras_custom_object
-class AdamWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
+class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
     """Optimizer that implements the Adam algorithm with weight decay.
 
-    This is an implementation of the AdamW optimizer described in "Fixing
-    Weight Decay Regularization in Adam" by Loshchilov & Hutter
+    This is an implementation of the AdamW optimizer described in "Decoupled
+    Weight Decay Regularization" by Loshchilov & Hutter
     (https://arxiv.org/abs/1711.05101)
     ([pdf])(https://arxiv.org/pdf/1711.05101.pdf).
 
-    It computes the update step of `train.AdamOptimizer` and additionally
+    It computes the update step of `tf.keras.optimizers.Adam` and additionally
     decays the variable. Note that this is different from adding L2
     regularization on the variables to the loss: it regularizes variables with
     large gradients more than L2 regularization would, which was shown to yield
     better training loss and generalization error in the paper above.
 
     For further information see the documentation of the Adam Optimizer.
 
-    Note that this optimizer can also be instantiated as
+    This optimizer can also be instantiated as
+    ```python
+    extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam,
+                                       weight_decay=weight_decay)
+    ```
+
+    Note: when applying a decay to the learning rate, be sure to manually apply
+    the decay to the `weight_decay` as well. For example:
+
     ```python
-    extend_with_weight_decay(tf.keras.optimizers.SGD,
-                             weight_decay=weight_decay)
+    step = tf.Variable(0, trainable=False)
+    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
+        [10000, 15000], [1e-0, 1e-1, 1e-2])
+    # lr and wd can be a function or a tensor
+    lr = 1e-1 * schedule(step)
+    wd = lambda: 1e-4 * schedule(step)
+
+    # ...
+
+    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
     ```
     """
@@ -364,7 +412,7 @@ def __init__(self,
                of learning rate. `lr` is included for backward compatibility,
                recommended to use `learning_rate` instead.
        """
-        super(AdamWOptimizer, self).__init__(
+        super(AdamW, self).__init__(
            weight_decay,
            learning_rate=learning_rate,
            beta_1=beta_1,
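As a closing usage note, the renamed optimizers plug into a standard Keras training loop like any other `tf.keras` optimizer. A hedged sketch with made-up toy data and a toy model (none of this is part of the commit):

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# Toy data, purely illustrative.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

# A one-layer model trained with the renamed SGDW optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(
    optimizer=tfa.optimizers.SGDW(
        weight_decay=1e-4, learning_rate=0.1, momentum=0.9),
    loss="mse")
model.fit(x, y, epochs=1, batch_size=16, verbose=0)
```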
