@@ -125,8 +125,11 @@ def minimize(self,
             ValueError: If some of the variables are not `Variable` objects.
         """
         self._decay_var_list = set(decay_var_list) if decay_var_list else False
-        return super(DecoupledWeightDecayExtension, self).minimize(
-            loss, var_list=var_list, grad_loss=grad_loss, name=name)
+        return super(DecoupledWeightDecayExtension,
+                     self).minimize(loss,
+                                    var_list=var_list,
+                                    grad_loss=grad_loss,
+                                    name=name)
 
     def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
         """Apply gradients to variables.
@@ -149,8 +152,8 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
             ValueError: If none of the variables have gradients.
         """
         self._decay_var_list = set(decay_var_list) if decay_var_list else False
-        return super(DecoupledWeightDecayExtension, self).apply_gradients(
-            grads_and_vars, name=name)
+        return super(DecoupledWeightDecayExtension,
+                     self).apply_gradients(grads_and_vars, name=name)
 
     def _decay_weights_op(self, var):
         if not self._decay_var_list or var in self._decay_var_list:
@@ -161,8 +164,8 @@ def _decay_weights_op(self, var):
 
     def _decay_weights_sparse_op(self, var, indices):
         if not self._decay_var_list or var in self._decay_var_list:
-            update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather(
-                var, indices))
+            update = (-self._get_hyper('weight_decay', var.dtype) *
+                      tf.gather(var, indices))
             return self._resource_scatter_add(var, indices, update)
         return tf.no_op()
 
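For context, a minimal usage sketch of the `decay_var_list` argument handled in the hunks above; the model, the `kernels` filter, and the data tensors are illustrative assumptions, not part of this patch:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Illustrative model: decay only the kernel weights, never the biases.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(32,)),
    tf.keras.layers.Dense(10),
])
optimizer = tfa.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
kernels = [v for v in model.trainable_variables if 'kernel' in v.name]

x = tf.random.normal((8, 32))
y = tf.random.uniform((8,), maxval=10, dtype=tf.int32)

with tf.GradientTape() as tape:
    logits = model(x)
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            y, logits, from_logits=True))
grads = tape.gradient(loss, model.trainable_variables)

# Only variables in decay_var_list receive the extra decay term; all other
# variables get the plain base-optimizer update.
optimizer.apply_gradients(zip(grads, model.trainable_variables),
                          decay_var_list=kernels)
```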
@@ -226,17 +229,19 @@ def extend_with_decoupled_weight_decay(base_optimizer):
 
     optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
     ```
+
+    Note: you might want to register your own custom optimizer using
+    `tf.keras.utils.get_custom_objects()`.
 
     Args:
-      base_optimizer: An optimizer class that inherits from
-        tf.optimizers.Optimizer.
+        base_optimizer: An optimizer class that inherits from
+            tf.optimizers.Optimizer.
 
     Returns:
-      A new optimizer class that inherits from DecoupledWeightDecayExtension
-      and base_optimizer.
+        A new optimizer class that inherits from DecoupledWeightDecayExtension
+        and base_optimizer.
     """
 
-    @keras_utils.register_keras_custom_object
     class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
                                             base_optimizer):
         """Base_optimizer with decoupled weight decay.
@@ -255,8 +260,8 @@ class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
 
         def __init__(self, weight_decay, *args, **kwargs):
             # super delegation is necessary here
-            super(OptimizerWithDecoupledWeightDecay, self).__init__(
-                weight_decay, *args, **kwargs)
+            super(OptimizerWithDecoupledWeightDecay,
+                  self).__init__(weight_decay, *args, **kwargs)
 
     return OptimizerWithDecoupledWeightDecay
 
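Because the patch drops the automatic `@keras_utils.register_keras_custom_object` registration of the generated class, the new docstring note points at manual registration. A hedged sketch, assuming the `MyAdamW` name and registry key are your own choice:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Combine decoupled weight decay with a stock Keras optimizer.
MyAdamW = tfa.optimizers.extend_with_decoupled_weight_decay(
    tf.keras.optimizers.Adam)

# The generated class is no longer auto-registered, so register it yourself
# if it has to be deserialized later (e.g. when loading a saved model).
tf.keras.utils.get_custom_objects()['MyAdamW'] = MyAdamW

optimizer = MyAdamW(weight_decay=1e-4, learning_rate=1e-3)
```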
@@ -326,13 +331,12 @@ def __init__(self,
                 of learning rate. `lr` is included for backward compatibility,
                 recommended to use `learning_rate` instead.
         """
-        super(SGDW, self).__init__(
-            weight_decay,
-            learning_rate=learning_rate,
-            momentum=momentum,
-            nesterov=nesterov,
-            name=name,
-            **kwargs)
+        super(SGDW, self).__init__(weight_decay,
+                                   learning_rate=learning_rate,
+                                   momentum=momentum,
+                                   nesterov=nesterov,
+                                   name=name,
+                                   **kwargs)
 
 
 @keras_utils.register_keras_custom_object
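A short usage sketch for the SGDW constructor reflowed above; the hyperparameter values are only examples:

```python
import tensorflow_addons as tfa

# SGD with Nesterov momentum plus decoupled weight decay.
optimizer = tfa.optimizers.SGDW(weight_decay=1e-4,
                                learning_rate=0.01,
                                momentum=0.9,
                                nesterov=True)
```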
@@ -412,12 +416,11 @@ def __init__(self,
                 of learning rate. `lr` is included for backward compatibility,
                 recommended to use `learning_rate` instead.
         """
-        super(AdamW, self).__init__(
-            weight_decay,
-            learning_rate=learning_rate,
-            beta_1=beta_1,
-            beta_2=beta_2,
-            epsilon=epsilon,
-            amsgrad=amsgrad,
-            name=name,
-            **kwargs)
+        super(AdamW, self).__init__(weight_decay,
+                                    learning_rate=learning_rate,
+                                    beta_1=beta_1,
+                                    beta_2=beta_2,
+                                    epsilon=epsilon,
+                                    amsgrad=amsgrad,
+                                    name=name,
+                                    **kwargs)
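Likewise, a hedged sketch of constructing AdamW with the arguments reflowed above and dropping it into a standard Keras workflow; the tiny model and loss are illustrative:

```python
import tensorflow as tf
import tensorflow_addons as tfa

optimizer = tfa.optimizers.AdamW(weight_decay=1e-4,
                                 learning_rate=1e-3,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-7,
                                 amsgrad=False)

# The optimizer plugs into the usual Keras training loop.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=optimizer, loss='mse')
```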