
Commit 077924c

Delete optimizer_test_base.py. Remove keras object registration in the
factory function.
1 parent 98a42c6 commit 077924c
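
In practical terms, this commit stops the factory in `extend_with_decoupled_weight_decay` from auto-registering the generated class as a Keras custom object; the updated docstring instead points users at `tf.keras.utils.get_custom_objects()`. A minimal sketch of what that manual registration could look like, assuming TF 2.x and `tensorflow_addons` installed (the alias `MySGDW` and the hyperparameter values are illustrative, not part of the commit):

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Build a decoupled-weight-decay variant of a stock Keras optimizer.
MySGDW = tfa.optimizers.extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD)

# The factory no longer registers the class for us, so register it explicitly
# if models using it need to be (de)serialized by name.
tf.keras.utils.get_custom_objects()[MySGDW.__name__] = MySGDW

opt = MySGDW(weight_decay=1e-4, learning_rate=0.1, momentum=0.9)
```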


4 files changed, +104 -242 lines changed


tensorflow_addons/optimizers/README.md

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ must:
   `@run_in_graph_and_eager_modes` (for test method)
   or `run_all_in_graph_and_eager_modes` (for TestCase subclass)
   decorator.
-* Consider inheriting from `OptimizerTestBase`.
 * Add a `py_test` to this sub-package's BUILD file.
 
 #### Documentation Requirements

tensorflow_addons/optimizers/optimizer_test_base.py

Lines changed: 0 additions & 149 deletions
This file was deleted.

tensorflow_addons/optimizers/weight_decay_optimizers.py

Lines changed: 32 additions & 29 deletions
@@ -125,8 +125,11 @@ def minimize(self,
       ValueError: If some of the variables are not `Variable` objects.
     """
     self._decay_var_list = set(decay_var_list) if decay_var_list else False
-    return super(DecoupledWeightDecayExtension, self).minimize(
-        loss, var_list=var_list, grad_loss=grad_loss, name=name)
+    return super(DecoupledWeightDecayExtension,
+                 self).minimize(loss,
+                                var_list=var_list,
+                                grad_loss=grad_loss,
+                                name=name)
 
   def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
     """Apply gradients to variables.
@@ -149,8 +152,8 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
       ValueError: If none of the variables have gradients.
     """
     self._decay_var_list = set(decay_var_list) if decay_var_list else False
-    return super(DecoupledWeightDecayExtension, self).apply_gradients(
-        grads_and_vars, name=name)
+    return super(DecoupledWeightDecayExtension,
+                 self).apply_gradients(grads_and_vars, name=name)
 
   def _decay_weights_op(self, var):
     if not self._decay_var_list or var in self._decay_var_list:
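
As the `minimize`/`apply_gradients` hunks above show, `decay_var_list` is stored as a set, and the decay ops fall back to `tf.no_op()` for variables outside it. A short usage sketch, assuming TF 2.x and `tensorflow_addons` (variable names and values are illustrative):

```python
import tensorflow as tf
import tensorflow_addons as tfa

w = tf.Variable(tf.ones([2, 2]), name='kernel')
b = tf.Variable(tf.zeros([2]), name='bias')

opt = tfa.optimizers.SGDW(weight_decay=1e-3, learning_rate=0.1)

def loss():
    # Any differentiable function of the variables works here.
    return tf.reduce_sum(tf.square(w)) + tf.reduce_sum(b)

# Only `w` gets decoupled weight decay; `b` receives a plain SGD update.
opt.minimize(loss, var_list=[w, b], decay_var_list=[w])
```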
@@ -161,8 +164,8 @@ def _decay_weights_op(self, var):
 
   def _decay_weights_sparse_op(self, var, indices):
     if not self._decay_var_list or var in self._decay_var_list:
-      update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather(
-          var, indices))
+      update = (-self._get_hyper('weight_decay', var.dtype) *
+                tf.gather(var, indices))
       return self._resource_scatter_add(var, indices, update)
     return tf.no_op()
 
@@ -226,17 +229,19 @@ def extend_with_decoupled_weight_decay(base_optimizer):
 
   optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
   ```
+
+  Note: you might want to register your own custom optimizer using
+  `tf.keras.utils.get_custom_objects()`.
 
   Args:
-    base_optimizer: An optimizer class that inherits from
-        tf.optimizers.Optimizer.
+      base_optimizer: An optimizer class that inherits from
+          tf.optimizers.Optimizer.
 
   Returns:
-    A new optimizer class that inherits from DecoupledWeightDecayExtension
-    and base_optimizer.
+      A new optimizer class that inherits from DecoupledWeightDecayExtension
+      and base_optimizer.
   """
 
-  @keras_utils.register_keras_custom_object
   class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
                                           base_optimizer):
     """Base_optimizer with decoupled weight decay.
@@ -255,8 +260,8 @@ class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
 
     def __init__(self, weight_decay, *args, **kwargs):
       # super delegation is necessary here
-      super(OptimizerWithDecoupledWeightDecay, self).__init__(
-          weight_decay, *args, **kwargs)
+      super(OptimizerWithDecoupledWeightDecay,
+            self).__init__(weight_decay, *args, **kwargs)
 
   return OptimizerWithDecoupledWeightDecay
 
@@ -326,13 +331,12 @@ def __init__(self,
         of learning rate. `lr` is included for backward compatibility,
         recommended to use `learning_rate` instead.
     """
-    super(SGDW, self).__init__(
-        weight_decay,
-        learning_rate=learning_rate,
-        momentum=momentum,
-        nesterov=nesterov,
-        name=name,
-        **kwargs)
+    super(SGDW, self).__init__(weight_decay,
+                               learning_rate=learning_rate,
+                               momentum=momentum,
+                               nesterov=nesterov,
+                               name=name,
+                               **kwargs)
 
 
 @keras_utils.register_keras_custom_object
@@ -412,12 +416,11 @@ def __init__(self,
         of learning rate. `lr` is included for backward compatibility,
         recommended to use `learning_rate` instead.
     """
-    super(AdamW, self).__init__(
-        weight_decay,
-        learning_rate=learning_rate,
-        beta_1=beta_1,
-        beta_2=beta_2,
-        epsilon=epsilon,
-        amsgrad=amsgrad,
-        name=name,
-        **kwargs)
+    super(AdamW, self).__init__(weight_decay,
+                                learning_rate=learning_rate,
+                                beta_1=beta_1,
+                                beta_2=beta_2,
+                                epsilon=epsilon,
+                                amsgrad=amsgrad,
+                                name=name,
+                                **kwargs)
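
The SGDW/AdamW constructor hunks only reflow the `super().__init__` calls, but the surrounding docstring context repeats the naming convention: `learning_rate` is the recommended argument and `lr` is kept only for backward compatibility. A small illustration, assuming `tensorflow_addons` is installed (hyperparameter values are made up):

```python
import tensorflow_addons as tfa

adamw = tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3)   # recommended spelling
adamw_compat = tfa.optimizers.AdamW(weight_decay=1e-4, lr=1e-3)       # legacy `lr`, still accepted
```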
