
Commit d6ccae6

Fix code formatting via patch file.
1 parent f8b6bdf commit d6ccae6

File tree

2 files changed: +102 -97 lines changed

tensorflow_addons/optimizers/weight_decay_optimizers.py

Lines changed: 24 additions & 25 deletions
@@ -125,11 +125,8 @@ def minimize(self,
             ValueError: If some of the variables are not `Variable` objects.
         """
         self._decay_var_list = set(decay_var_list) if decay_var_list else False
-        return super(DecoupledWeightDecayExtension,
-                     self).minimize(loss,
-                                    var_list=var_list,
-                                    grad_loss=grad_loss,
-                                    name=name)
+        return super(DecoupledWeightDecayExtension, self).minimize(
+            loss, var_list=var_list, grad_loss=grad_loss, name=name)
 
     def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
         """Apply gradients to variables.
@@ -152,8 +149,8 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
             ValueError: If none of the variables have gradients.
         """
         self._decay_var_list = set(decay_var_list) if decay_var_list else False
-        return super(DecoupledWeightDecayExtension,
-                     self).apply_gradients(grads_and_vars, name=name)
+        return super(DecoupledWeightDecayExtension, self).apply_gradients(
+            grads_and_vars, name=name)
 
     def _decay_weights_op(self, var):
         if not self._decay_var_list or var in self._decay_var_list:
@@ -164,8 +161,8 @@ def _decay_weights_op(self, var):
 
     def _decay_weights_sparse_op(self, var, indices):
         if not self._decay_var_list or var in self._decay_var_list:
-            update = (-self._get_hyper('weight_decay', var.dtype) *
-                      tf.gather(var, indices))
+            update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather(
+                var, indices))
             return self._resource_scatter_add(var, indices, update)
         return tf.no_op()
 
@@ -260,8 +257,8 @@ class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
 
         def __init__(self, weight_decay, *args, **kwargs):
             # super delegation is necessary here
-            super(OptimizerWithDecoupledWeightDecay,
-                  self).__init__(weight_decay, *args, **kwargs)
+            super(OptimizerWithDecoupledWeightDecay, self).__init__(
+                weight_decay, *args, **kwargs)
 
     return OptimizerWithDecoupledWeightDecay
 
@@ -331,12 +328,13 @@ def __init__(self,
             of learning rate. `lr` is included for backward compatibility,
             recommended to use `learning_rate` instead.
         """
-        super(SGDW, self).__init__(weight_decay,
-                                   learning_rate=learning_rate,
-                                   momentum=momentum,
-                                   nesterov=nesterov,
-                                   name=name,
-                                   **kwargs)
+        super(SGDW, self).__init__(
+            weight_decay,
+            learning_rate=learning_rate,
+            momentum=momentum,
+            nesterov=nesterov,
+            name=name,
+            **kwargs)
 
 
 @keras_utils.register_keras_custom_object
@@ -416,11 +414,12 @@ def __init__(self,
             of learning rate. `lr` is included for backward compatibility,
             recommended to use `learning_rate` instead.
         """
-        super(AdamW, self).__init__(weight_decay,
-                                    learning_rate=learning_rate,
-                                    beta_1=beta_1,
-                                    beta_2=beta_2,
-                                    epsilon=epsilon,
-                                    amsgrad=amsgrad,
-                                    name=name,
-                                    **kwargs)
+        super(AdamW, self).__init__(
+            weight_decay,
+            learning_rate=learning_rate,
+            beta_1=beta_1,
+            beta_2=beta_2,
+            epsilon=epsilon,
+            amsgrad=amsgrad,
+            name=name,
+            **kwargs)
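For context only (not part of this commit): a minimal usage sketch of the decay_var_list plumbing exposed by the minimize/apply_gradients overrides reformatted above. It assumes TensorFlow 2.x with tensorflow_addons installed and importable as tfa; the variables and loss below are illustrative.

import tensorflow as tf
import tensorflow_addons as tfa

# Two toy variables; only var0 will receive decoupled weight decay.
var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([3.0, 4.0])

opt = tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3)

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(var0 * var0) + tf.reduce_sum(var1 * var1)
grads = tape.gradient(loss, [var0, var1])

# decay_var_list restricts the decoupled weight decay to selected variables,
# matching the apply_gradients() signature shown in the diff above.
opt.apply_gradients(zip(grads, [var0, var1]), decay_var_list=[var0])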

tensorflow_addons/optimizers/weight_decay_optimizers_test.py

Lines changed: 78 additions & 72 deletions
@@ -69,13 +69,13 @@ def doTest(self, optimizer, update_fn, do_sparse=False,
             var1 = tf.Variable(var1_np, name="var1_%d" % i)
             if do_sparse:
                 grads0_np_indices = np.array([0, 1], dtype=np.int32)
-                grads0 = tf.IndexedSlices(tf.constant(grads0_np),
-                                          tf.constant(grads0_np_indices),
-                                          tf.constant([2]))
+                grads0 = tf.IndexedSlices(
+                    tf.constant(grads0_np), tf.constant(grads0_np_indices),
+                    tf.constant([2]))
                 grads1_np_indices = np.array([0, 1], dtype=np.int32)
-                grads1 = tf.IndexedSlices(tf.constant(grads1_np),
-                                          tf.constant(grads1_np_indices),
-                                          tf.constant([2]))
+                grads1 = tf.IndexedSlices(
+                    tf.constant(grads1_np), tf.constant(grads1_np_indices),
+                    tf.constant([2]))
             else:
                 grads0 = tf.constant(grads0_np)
                 grads1 = tf.constant(grads1_np)
@@ -94,12 +94,10 @@ def doTest(self, optimizer, update_fn, do_sparse=False,
                     opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
                 else:
                     self.evaluate(update)
-                var0_np, np_slot_vars0 = update_fn(var0_np, grads0_np,
-                                                   np_slot_vars0,
-                                                   **optimizer_kwargs)
-                var1_np, np_slot_vars1 = update_fn(var1_np, grads1_np,
-                                                   np_slot_vars1,
-                                                   **optimizer_kwargs)
+                var0_np, np_slot_vars0 = update_fn(
+                    var0_np, grads0_np, np_slot_vars0, **optimizer_kwargs)
+                var1_np, np_slot_vars1 = update_fn(
+                    var1_np, grads1_np, np_slot_vars1, **optimizer_kwargs)
                 # Validate updated params
                 self.assertAllCloseAccordingToType(var0_np,
                                                    self.evaluate(var0))
@@ -129,16 +127,15 @@ def doTestSparseRepeatedIndices(self, optimizer, **optimizer_kwargs):
                 tf.constant([0.2], shape=[1, 1], dtype=dtype),
                 tf.constant([1]), tf.constant([2, 1]))
             opt_repeated = optimizer(**optimizer_kwargs)
-            repeated_update = opt_repeated.apply_gradients([
-                (grad_repeated_index, repeated_index_update_var)
-            ])
+            repeated_update = opt_repeated.apply_gradients(
+                [(grad_repeated_index, repeated_index_update_var)])
             opt_aggregated = optimizer(**optimizer_kwargs)
-            aggregated_update = opt_aggregated.apply_gradients([
-                (grad_aggregated, aggregated_update_var)
-            ])
+            aggregated_update = opt_aggregated.apply_gradients(
+                [(grad_aggregated, aggregated_update_var)])
             self.evaluate(tf.compat.v1.global_variables_initializer())
-            self.assertAllClose(self.evaluate(aggregated_update_var),
-                                self.evaluate(repeated_index_update_var))
+            self.assertAllClose(
+                self.evaluate(aggregated_update_var),
+                self.evaluate(repeated_index_update_var))
             for _ in range(3):
                 if not tf.executing_eagerly():
                     self.evaluate(repeated_update)
@@ -148,8 +145,9 @@ def doTestSparseRepeatedIndices(self, optimizer, **optimizer_kwargs):
                                                  repeated_index_update_var)])
                 opt_aggregated.apply_gradients([(grad_aggregated,
                                                  aggregated_update_var)])
-            self.assertAllClose(self.evaluate(aggregated_update_var),
-                                self.evaluate(repeated_index_update_var))
+            self.assertAllClose(
+                self.evaluate(aggregated_update_var),
+                self.evaluate(repeated_index_update_var))
 
 
 def adamw_update_numpy(param, grad_t, slot_vars, learning_rate, beta_1, beta_2,
@@ -162,8 +160,8 @@ def adamw_update_numpy(param, grad_t, slot_vars, learning_rate, beta_1, beta_2,
     lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
     slot_vars["m"] = beta1 * slot_vars.get("m", 0) + (1 - beta1) * grad_t
     slot_vars["v"] = beta2 * slot_vars.get("v", 0) + (1 - beta2) * grad_t**2
-    param_t = (param * (1 - wd) - lr_t * slot_vars["m"] /
-               (np.sqrt(slot_vars["v"]) + eps))
+    param_t = (param * (1 - wd) -
+               lr_t * slot_vars["m"] / (np.sqrt(slot_vars["v"]) + eps))
     slot_vars["t"] = t
     return param_t, slot_vars
 
@@ -185,42 +183,46 @@ class AdamWTest(OptimizerTestBase):
 
     @test_utils.run_in_graph_and_eager_modes(reset_test=True)
     def testSparse(self):
-        self.doTest(self.optimizer,
-                    adamw_update_numpy,
-                    do_sparse=True,
-                    learning_rate=0.001,
-                    beta_1=0.9,
-                    beta_2=0.999,
-                    epsilon=1e-8,
-                    weight_decay=WEIGHT_DECAY)
+        self.doTest(
+            self.optimizer,
+            adamw_update_numpy,
+            do_sparse=True,
+            learning_rate=0.001,
+            beta_1=0.9,
+            beta_2=0.999,
+            epsilon=1e-8,
+            weight_decay=WEIGHT_DECAY)
 
     @test_utils.run_in_graph_and_eager_modes(reset_test=True)
     def testSparseRepeatedIndices(self):
-        self.doTestSparseRepeatedIndices(self.optimizer,
-                                         learning_rate=0.001,
-                                         beta_1=0.9,
-                                         beta_2=0.999,
-                                         epsilon=1e-8,
-                                         weight_decay=WEIGHT_DECAY)
+        self.doTestSparseRepeatedIndices(
+            self.optimizer,
+            learning_rate=0.001,
+            beta_1=0.9,
+            beta_2=0.999,
+            epsilon=1e-8,
+            weight_decay=WEIGHT_DECAY)
 
     @test_utils.run_in_graph_and_eager_modes(reset_test=True)
     def testBasic(self):
-        self.doTest(self.optimizer,
-                    adamw_update_numpy,
-                    learning_rate=0.001,
-                    beta_1=0.9,
-                    beta_2=0.999,
-                    epsilon=1e-8,
-                    weight_decay=WEIGHT_DECAY)
+        self.doTest(
+            self.optimizer,
+            adamw_update_numpy,
+            learning_rate=0.001,
+            beta_1=0.9,
+            beta_2=0.999,
+            epsilon=1e-8,
+            weight_decay=WEIGHT_DECAY)
 
     def testBasicCallableParams(self):
-        self.doTest(self.optimizer,
-                    adamw_update_numpy,
-                    learning_rate=lambda: 0.001,
-                    beta_1=lambda: 0.9,
-                    beta_2=lambda: 0.999,
-                    epsilon=lambda: 1e-8,
-                    weight_decay=lambda: WEIGHT_DECAY)
+        self.doTest(
+            self.optimizer,
+            adamw_update_numpy,
+            learning_rate=lambda: 0.001,
+            beta_1=lambda: 0.9,
+            beta_2=lambda: 0.999,
+            epsilon=lambda: 1e-8,
+            weight_decay=lambda: WEIGHT_DECAY)
 
 
 class SGDWTest(OptimizerTestBase):
@@ -229,34 +231,38 @@ class SGDWTest(OptimizerTestBase):
 
     @test_utils.run_in_graph_and_eager_modes(reset_test=True)
     def testSparse(self):
-        self.doTest(self.optimizer,
-                    sgdw_update_numpy,
-                    do_sparse=True,
-                    learning_rate=0.001,
-                    momentum=0.9,
-                    weight_decay=WEIGHT_DECAY)
+        self.doTest(
+            self.optimizer,
+            sgdw_update_numpy,
+            do_sparse=True,
+            learning_rate=0.001,
+            momentum=0.9,
+            weight_decay=WEIGHT_DECAY)
 
     @test_utils.run_in_graph_and_eager_modes(reset_test=True)
     def testSparseRepeatedIndices(self):
-        self.doTestSparseRepeatedIndices(self.optimizer,
-                                         learning_rate=0.001,
-                                         momentum=0.9,
-                                         weight_decay=WEIGHT_DECAY)
+        self.doTestSparseRepeatedIndices(
+            self.optimizer,
+            learning_rate=0.001,
+            momentum=0.9,
+            weight_decay=WEIGHT_DECAY)
 
     @test_utils.run_in_graph_and_eager_modes(reset_test=True)
     def testBasic(self):
-        self.doTest(self.optimizer,
-                    sgdw_update_numpy,
-                    learning_rate=0.001,
-                    momentum=0.9,
-                    weight_decay=WEIGHT_DECAY)
+        self.doTest(
+            self.optimizer,
+            sgdw_update_numpy,
+            learning_rate=0.001,
+            momentum=0.9,
+            weight_decay=WEIGHT_DECAY)
 
     def testBasicCallableParams(self):
-        self.doTest(self.optimizer,
-                    sgdw_update_numpy,
-                    learning_rate=lambda: 0.001,
-                    momentum=lambda: 0.9,
-                    weight_decay=lambda: WEIGHT_DECAY)
+        self.doTest(
+            self.optimizer,
+            sgdw_update_numpy,
+            learning_rate=lambda: 0.001,
+            momentum=lambda: 0.9,
+            weight_decay=lambda: WEIGHT_DECAY)
 
 
 class ExtendWithWeightDecayTest(SGDWTest):
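For reference (not part of this commit): the decoupled weight-decay step that adamw_update_numpy verifies, written as a self-contained sketch. Hyperparameter defaults here are illustrative, and the test module's WEIGHT_DECAY constant is replaced by a hypothetical wd argument.

import numpy as np

def adamw_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999,
               eps=1e-8, wd=0.01):
    # Bias-corrected step size, as in adamw_update_numpy above.
    lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
    # First and second moment estimates.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    # Weight decay is applied directly to the parameter, decoupled from the
    # gradient-based Adam update.
    param = param * (1 - wd) - lr_t * m / (np.sqrt(v) + eps)
    return param, m, v

p, m, v = adamw_step(np.array([1.0, 2.0]), np.array([0.1, -0.2]), m=0.0, v=0.0, t=1)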
