Skip to content

Commit 9ebc02b

Browse files
committed
Fix indentation of comments, remove call to tf.test.main from optimizer_test_base.
1 parent aa3d7b0 commit 9ebc02b

File tree

2 files changed

+248
-237
lines changed

2 files changed

+248
-237
lines changed

tensorflow_addons/optimizers/optimizer_test_base.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,34 @@
2424

2525
class OptimizerTestBase(tf.test.TestCase):
2626
"""Base class for optimizer tests.
27-
28-
Optimizer tests may inherit from this class and define test functions
29-
using doTest. Usually this should include the functions testSparse,
30-
testBasic, and testBasicCallableParams. See weight_decay_optimizers_test for
31-
an example.
32-
"""
27+
28+
Optimizer tests may inherit from this class and define test
29+
functions using doTest. Usually this should include the functions
30+
testSparse, testBasic, and testBasicCallableParams. See
31+
weight_decay_optimizers_test for an example.
32+
"""
3333

3434
def doTest(self, optimizer, update_fn, params, do_sparse=False):
3535
"""The major test function.
36-
37-
Args:
38-
optimizer: The tensorflow optimizer class to be tested.
39-
update_fn: The numpy update function of the optimizer, the function
40-
signature must be
41-
update_fn(var: np.array,
42-
grad_t: np.array,
43-
slot_vars: dict,
44-
optimizer_params: dict) -> updated_var, updated_slot_vars
45-
Note that slot_vars will be initialized to an empty dictionary for
46-
each variable, initial values should be handled in the update_fn.
47-
params: A dict, the parameters to pass to the construcor of the
48-
optimizer. Either a constant or a callable. This also passed to the
49-
optimizer_params in the update_fn.
50-
do_sparse: If True, test sparse update. Defaults to False, i.e., dense
51-
update.
52-
"""
36+
37+
Args:
38+
optimizer: The tensorflow optimizer class to be tested.
39+
update_fn: The numpy update function of the optimizer, the function
40+
signature must be
41+
update_fn(var: np.array,
42+
grad_t: np.array,
43+
slot_vars: dict,
44+
optimizer_params: dict) -> (updated_var,
45+
updated_slot_vars)
46+
Note that slot_vars will be initialized to an empty dictionary
47+
for each variable, initial values should be handled in the
48+
update_fn.
49+
params: A dict, the parameters to pass to the construcor of the
50+
optimizer. Either a constant or a callable. This also passed to
51+
the optimizer_params in the update_fn.
52+
do_sparse: If True, test sparse update. Defaults to False, i.e.,
53+
dense update.
54+
"""
5355
for i, dtype in enumerate([tf.half, tf.float32, tf.float64]):
5456
with self.session(graph=tf.Graph()):
5557
# Initialize variables for numpy implementation.
@@ -101,15 +103,15 @@ def doTest(self, optimizer, update_fn, params, do_sparse=False):
101103

102104
def doTestSparseRepeatedIndices(self, optimizer, params):
103105
"""Test for repeated indices in sparse updates.
104-
106+
105107
This test verifies that an update with repeated indices is the same as
106108
an update with two times the gradient.
107109
108110
Args:
109-
optimizer: The tensorflow optimizer class to be tested.
110-
params: A dict, the parameters to pass to the construcor of the
111-
optimizer. Either a constant or a callable. This also passed to the
112-
optimizer_params in the update_fn.
111+
optimizer: The tensorflow optimizer class to be tested.
112+
params: A dict, the parameters to pass to the construcor of the
113+
optimizer. Either a constant or a callable. This also passed to
114+
the optimizer_params in the update_fn.
113115
"""
114116
for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]:
115117
with self.cached_session():
@@ -145,7 +147,3 @@ def doTestSparseRepeatedIndices(self, optimizer, params):
145147
self.assertAllClose(
146148
self.evaluate(aggregated_update_var),
147149
self.evaluate(repeated_index_update_var))
148-
149-
150-
if __name__ == "__main__":
151-
tf.test.main()

0 commit comments

Comments
 (0)