@@ -37,35 +37,35 @@ class DecoupledWeightDecayExtension(object):
3737 optimizers with decoupled weight decay. We explicitly define the two
3838 examples used in the above paper (SGDW and AdamW), but in general this
3939 can extend any OptimizerX by using
40- `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`.
40+ `extend_with_decoupled_weight_decay(
41+ OptimizerX, weight_decay=weight_decay)`.
4142 In order for it to work, it must be the first class the Optimizer with
4243 weight decay inherits from, e.g.
4344
4445 ```python
45- class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
46+ class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
4647 def __init__(self, weight_decay, *args, **kwargs):
47- super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs).
48+ super(AdamW, self).__init__(weight_decay, *args, **kwargs)
4849 ```
4950
50- Note that this extension decays weights BEFORE applying the update based
51+ Note: this extension decays weights BEFORE applying the update based
5152 on the gradient, i.e. this extension only has the desired behaviour for
5253 optimizers which do not depend on the value of 'var' in the update step!
5354
5455 Note: when applying a decay to the learning rate, be sure to manually apply
5556 the decay to the `weight_decay` as well. For example:
5657
5758 ```python
58- schedule = tf.train.piecewise_constant(
59- tf.train.get_global_step(), [10000, 15000], [1e-0, 1e-1, 1e-2])
60- lr = 1e-1 * schedule()
61- wd = lambda: 1e-4 * schedule()
59+ step = tf.Variable(0, trainable=False)
60+ schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
61+ [10000, 15000], [1e-0, 1e-1, 1e-2])
62+ # lr and wd can be a function or a tensor
63+ lr = 1e-1 * schedule(step)
64+ wd = lambda: 1e-4 * schedule(step)
6265
63- # ...
66+ # ...
6467
65- optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr,
66- weight_decay=wd,
67- momentum=0.9,
68- use_nesterov=True)
68+ optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
6969 ```
7070 """
7171
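The subclassing pattern in the docstring above can be exercised directly. A minimal sketch; the `MySGDW` name is hypothetical, and the import path for `DecoupledWeightDecayExtension` (the mixin defined in this module) is an assumption:

```python
import tensorflow as tf
from tensorflow_addons.optimizers import DecoupledWeightDecayExtension  # assumed export path

# Hypothetical subclass following the docstring's pattern: the mixin must come
# first in the MRO so it can consume `weight_decay` before SGD sees it.
class MySGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
    def __init__(self, weight_decay, *args, **kwargs):
        super(MySGDW, self).__init__(weight_decay, *args, **kwargs)

# A constant decay factor; a tensor or zero-argument callable also works,
# as in the schedule example above.
opt = MySGDW(weight_decay=1e-4, learning_rate=0.1, momentum=0.9)
```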
@@ -78,10 +78,10 @@ def __init__(self, weight_decay, **kwargs):
7878 **kwargs: Optional keyword arguments forwarded to the wrapped optimizer;
7979 a 'weight_decay' entry here overrides the positional argument.
8080 """
81+ wd = kwargs.pop('weight_decay', weight_decay)
8182 super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
8283 self._decay_var_list = None  # is set in minimize or apply_gradients
83- self._set_hyper('weight_decay', kwargs.get('weight_decay',
84- weight_decay))
84+ self._set_hyper('weight_decay', wd)
8585
8686 def get_config(self):
8787 config = super(DecoupledWeightDecayExtension, self).get_config()
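Because `weight_decay` is stored with `_set_hyper` and exported by `get_config` (the rest of that method is elided by the diff), an optimizer built on this mixin can be round-tripped through the standard Keras serialization helpers. A sketch, assuming the `AdamW` class defined later in this file is exposed as `tfa.optimizers.AdamW`:

```python
import tensorflow as tf
import tensorflow_addons as tfa

opt = tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3)

# Round-trip through the Keras serialization helpers; passing the class in
# custom_objects works even without the global registration performed by the
# register_keras_custom_object decorator further down in this file.
config = tf.keras.optimizers.serialize(opt)
restored = tf.keras.optimizers.deserialize(
    config, custom_objects={'AdamW': tfa.optimizers.AdamW})
```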
@@ -188,8 +188,8 @@ def extend_with_decoupled_weight_decay(base_optimizer):
188188 Returns an optimizer class. An instance of the returned class computes the
189189 update step of `base_optimizer` and additionally decays the weights.
190190 E.g., the class returned by
191- `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent
192- to `tf.contrib.opt.AdamWOptimizer`.
191+ `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is
192+ equivalent to `tfa.optimizers.AdamW`.
193193
194194 The API of the new optimizer class slightly differs from the API of the
195195 base optimizer:
@@ -201,18 +201,35 @@ def extend_with_decoupled_weight_decay(base_optimizer):
201201 Usage example:
202202 ```python
203203 # MyAdamW is a new class
204- MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
204+ MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
205205 # Create a MyAdamW object
206206 optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
207- sess.run(optimizer.minimize(loss, decay_variables=[var1, var2]))
207+ # update var1, var2 but only decay var1
208+ optimizer.minimize(loss, var_list=[var1, var2], decay_variables=[var1])
209+ ```
208209
209- Note that this extension decays weights BEFORE applying the update based
210+ Note: this extension decays weights BEFORE applying the update based
210211 on the gradient, i.e. this extension only has the desired behaviour for
211212 optimizers which do not depend on the value of 'var' in the update step!
213+
214+ Note: when applying a decay to the learning rate, be sure to manually apply
215+ the decay to the `weight_decay` as well. For example:
216+
217+ ```python
218+ step = tf.Variable(0, trainable=False)
219+ schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
220+ [10000, 15000], [1e-0, 1e-1, 1e-2])
221+ # lr and wd can be a function or a tensor
222+ lr = 1e-1 * schedule(step)
223+ wd = lambda: 1e-4 * schedule(step)
224+
225+ # ...
226+
227+ optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
212228 ```
213229
214230 Args:
215- base_optimizer: An optimizer class that inherits from tf.train.Optimizer.
231+ base_optimizer: An optimizer class that inherits from
232+ tf.optimizers.Optimizer.
216233
217234 Returns:
218235 A new optimizer class that inherits from DecoupledWeightDecayExtension
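As a concrete illustration of the factory described in this hunk, the same mechanism can wrap other Keras optimizers. A sketch using RMSprop, chosen only as an example of an update rule that does not read the variable's value (see the note above); the `RMSpropW` name is hypothetical:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Build a decoupled-weight-decay variant of RMSprop; the returned class
# takes weight_decay as its first constructor argument.
RMSpropW = tfa.optimizers.extend_with_decoupled_weight_decay(
    tf.keras.optimizers.RMSprop)
opt = RMSpropW(weight_decay=1e-4, learning_rate=1e-3)

# The result is a regular Keras optimizer and can be used with compile/fit.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer=opt, loss='mse')
```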
@@ -238,34 +255,49 @@ class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
238255
239256 def __init__(self, weight_decay, *args, **kwargs):
240257 # super delegation is necessary here
241- # pylint: disable=useless-super-delegation
242258 super(OptimizerWithDecoupledWeightDecay, self).__init__(
243259 weight_decay, *args, **kwargs)
244- # pylint: enable=useless-super-delegation
245260
246261 return OptimizerWithDecoupledWeightDecay
247262
248263
249264 @keras_utils.register_keras_custom_object
250- class SGDWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
265+ class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
251266 """Optimizer that implements the Momentum algorithm with weight_decay.
252267
253- This is an implementation of the SGDW optimizer described in "Fixing
254- Weight Decay Regularization in Adam" by Loshchilov & Hutter
268+ This is an implementation of the SGDW optimizer described in "Decoupled
269+ Weight Decay Regularization" by Loshchilov & Hutter
255270 (https://arxiv.org/abs/1711.05101)
256271 ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
257- It computes the update step of `train.MomentumOptimizer` and additionally
272+ It computes the update step of `tf.keras.optimizers.SGD` and additionally
258273 decays the variable. Note that this is different from adding
259274 L2 regularization on the variables to the loss. Decoupling the weight decay
260275 from other hyperparameters (in particular the learning rate) simplifies
261276 hyperparameter search.
262277
263278 For further information see the documentation of the SGD Optimizer.
264279
265- Note that this optimizer can also be instantiated as
280+ This optimizer can also be instantiated as
266281 ```python
267- extend_with_weight_decay(tf.keras.optimizers.SGD,
268- weight_decay=weight_decay)
282+ extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD,
283+ weight_decay=weight_decay)
284+ ```
285+
286+ Note: when applying a decay to the learning rate, be sure to manually apply
287+ the decay to the `weight_decay` as well. For example:
288+
289+ ```python
290+ step = tf.Variable(0, trainable=False)
291+ schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
292+ [10000, 15000], [1e-0, 1e-1, 1e-2])
293+ # lr and wd can be a function or a tensor
294+ lr = 1e-1 * schedule(step)
295+ wd = lambda: 1e-4 * schedule(step)
296+
297+ # ...
298+
299+ optimizer = tfa.optimizers.SGDW(
300+ learning_rate=lr, weight_decay=wd, momentum=0.9)
269301 ```
270302 """
271303
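To make the "decays weights BEFORE applying the update" behaviour concrete, here is a small sketch with SGDW on a single scalar variable. The arithmetic assumes the decay is applied as `weight_decay * var`, independent of the learning rate (which is why the docstring asks you to scale `weight_decay` alongside the learning rate yourself):

```python
import tensorflow as tf
import tensorflow_addons as tfa

var = tf.Variable(1.0)
opt = tfa.optimizers.SGDW(weight_decay=0.1, learning_rate=0.5, momentum=0.0)

with tf.GradientTape() as tape:
    loss = 2.0 * var          # d(loss)/d(var) = 2.0
grads = tape.gradient(loss, [var])
opt.apply_gradients(zip(grads, [var]))

# Under the assumption above, one step gives approximately
#   var <- var - wd * var - lr * grad = 1.0 - 0.1 - 1.0 = -0.1
print(var.numpy())
```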
@@ -287,14 +319,14 @@ def __init__(self,
287319 nesterov: boolean. Whether to apply Nesterov momentum.
288320 name: Optional name prefix for the operations created when applying
289321 gradients. Defaults to 'SGD'.
290- **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
291- `lr`, `decay`}. `clipnorm` is clip gradients by norm;
292- `clipvalue` is clip gradients by value, `decay` is included for
293- backward compatibility to allow time inverse decay of learning
294- rate. `lr` is included for backward compatibility, recommended
295- to use `learning_rate` instead.
322+ **kwargs: keyword arguments. Allowed to be {`clipnorm`,
323+ `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
324+ norm; `clipvalue` is clip gradients by value, `decay` is
325+ included for backward compatibility to allow time inverse decay
326+ of learning rate. `lr` is included for backward compatibility,
327+ recommended to use `learning_rate` instead.
296328 """
297- super(SGDWOptimizer, self).__init__(
329+ super(SGDW, self).__init__(
298330 weight_decay,
299331 learning_rate=learning_rate,
300332 momentum=momentum,
@@ -304,26 +336,42 @@ def __init__(self,
304336
305337
306338 @keras_utils.register_keras_custom_object
307- class AdamWOptimizer(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
339+ class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
308340 """Optimizer that implements the Adam algorithm with weight decay.
309341
310- This is an implementation of the AdamW optimizer described in "Fixing
311- Weight Decay Regularization in Adam" by Loshchilov & Hutter
342+ This is an implementation of the AdamW optimizer described in "Decoupled
343+ Weight Decay Regularization" by Loshchilov & Hutter
312344 (https://arxiv.org/abs/1711.05101)
313345 ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
314346
315- It computes the update step of `train.AdamOptimizer` and additionally
347+ It computes the update step of `tf.keras.optimizers.Adam` and additionally
316348 decays the variable. Note that this is different from adding L2
317349 regularization on the variables to the loss: it regularizes variables with
318350 large gradients more than L2 regularization would, which was shown to yield
319351 better training loss and generalization error in the paper above.
320352
321353 For further information see the documentation of the Adam Optimizer.
322354
323- Note that this optimizer can also be instantiated as
355+ This optimizer can also be instantiated as
356+ ```python
357+ extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam,
358+ weight_decay=weight_decay)
359+ ```
360+
361+ Note: when applying a decay to the learning rate, be sure to manually apply
362+ the decay to the `weight_decay` as well. For example:
363+
324364 ```python
325- extend_with_weight_decay(tf.keras.optimizers.SGD,
326- weight_decay=weight_decay)
365+ step = tf.Variable(0, trainable=False)
366+ schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
367+ [10000, 15000], [1e-0, 1e-1, 1e-2])
368+ # lr and wd can be a function or a tensor
369+ lr = 1e-1 * schedule(step)
370+ wd = lambda: 1e-4 * schedule(step)
371+
372+ # ...
373+
374+ optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
327375 ```
328376 """
329377
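A minimal end-to-end sketch of AdamW driving a small Keras model, using only the constructor arguments documented above; the model shape and data are placeholders:

```python
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1),
])
model.compile(
    optimizer=tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3),
    loss='mse')

# Placeholder data, just to show the optimizer in a training call.
x = tf.random.normal([64, 8])
y = tf.random.normal([64, 1])
model.fit(x, y, epochs=1, batch_size=16, verbose=0)
```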
@@ -364,7 +412,7 @@ def __init__(self,
364412 of learning rate. `lr` is included for backward compatibility,
365413 recommended to use `learning_rate` instead.
366414 """
367- super(AdamWOptimizer, self).__init__(
415+ super(AdamW, self).__init__(
368416 weight_decay,
369417 learning_rate=learning_rate,
370418 beta_1=beta_1,