|
24 | 24 |
|
25 | 25 | class OptimizerTestBase(tf.test.TestCase): |
26 | 26 | """Base class for optimizer tests. |
27 | | - |
28 | | - Optimizer tests may inherit from this class and define test functions |
29 | | - using doTest. Usually this should include the functions testSparse, |
30 | | - testBasic, and testBasicCallableParams. See weight_decay_optimizers_test for |
31 | | - an example. |
32 | | - """ |
| 27 | +
|
| 28 | + Optimizer tests may inherit from this class and define test |
| 29 | + functions using doTest. Usually this should include the functions |
| 30 | + testSparse, testBasic, and testBasicCallableParams. See |
| 31 | + weight_decay_optimizers_test for an example. |
| 32 | + """ |
33 | 33 |
|
34 | 34 | def doTest(self, optimizer, update_fn, params, do_sparse=False): |
35 | 35 | """The major test function. |
36 | | - |
37 | | - Args: |
38 | | - optimizer: The tensorflow optimizer class to be tested. |
39 | | - update_fn: The numpy update function of the optimizer, the function |
40 | | - signature must be |
41 | | - update_fn(var: np.array, |
42 | | - grad_t: np.array, |
43 | | - slot_vars: dict, |
44 | | - optimizer_params: dict) -> updated_var, updated_slot_vars |
45 | | - Note that slot_vars will be initialized to an empty dictionary for |
46 | | - each variable, initial values should be handled in the update_fn. |
47 | | - params: A dict, the parameters to pass to the construcor of the |
48 | | - optimizer. Either a constant or a callable. This also passed to the |
49 | | - optimizer_params in the update_fn. |
50 | | - do_sparse: If True, test sparse update. Defaults to False, i.e., dense |
51 | | - update. |
52 | | - """ |
| 36 | +
|
| 37 | + Args: |
| 38 | + optimizer: The tensorflow optimizer class to be tested. |
| 39 | + update_fn: The numpy update function of the optimizer, the function |
| 40 | + signature must be |
| 41 | + update_fn(var: np.array, |
| 42 | + grad_t: np.array, |
| 43 | + slot_vars: dict, |
| 44 | + optimizer_params: dict) -> (updated_var, |
| 45 | + updated_slot_vars) |
| 46 | + Note that slot_vars will be initialized to an empty dictionary |
| 47 | + for each variable, initial values should be handled in the |
| 48 | + update_fn. |
| 49 | + params: A dict, the parameters to pass to the construcor of the |
| 50 | + optimizer. Either a constant or a callable. This also passed to |
| 51 | + the optimizer_params in the update_fn. |
| 52 | + do_sparse: If True, test sparse update. Defaults to False, i.e., |
| 53 | + dense update. |
| 54 | + """ |
53 | 55 | for i, dtype in enumerate([tf.half, tf.float32, tf.float64]): |
54 | 56 | with self.session(graph=tf.Graph()): |
55 | 57 | # Initialize variables for numpy implementation. |
@@ -101,15 +103,15 @@ def doTest(self, optimizer, update_fn, params, do_sparse=False): |
101 | 103 |
|
102 | 104 | def doTestSparseRepeatedIndices(self, optimizer, params): |
103 | 105 | """Test for repeated indices in sparse updates. |
104 | | - |
| 106 | +
|
105 | 107 | This test verifies that an update with repeated indices is the same as |
106 | 108 | an update with two times the gradient. |
107 | 109 |
|
108 | 110 | Args: |
109 | | - optimizer: The tensorflow optimizer class to be tested. |
110 | | - params: A dict, the parameters to pass to the construcor of the |
111 | | - optimizer. Either a constant or a callable. This also passed to the |
112 | | - optimizer_params in the update_fn. |
| 111 | + optimizer: The tensorflow optimizer class to be tested. |
| 112 | + params: A dict, the parameters to pass to the construcor of the |
| 113 | + optimizer. Either a constant or a callable. This also passed to |
| 114 | + the optimizer_params in the update_fn. |
113 | 115 | """ |
114 | 116 | for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]: |
115 | 117 | with self.cached_session(): |
@@ -145,7 +147,3 @@ def doTestSparseRepeatedIndices(self, optimizer, params): |
145 | 147 | self.assertAllClose( |
146 | 148 | self.evaluate(aggregated_update_var), |
147 | 149 | self.evaluate(repeated_index_update_var)) |
148 | | - |
149 | | - |
150 | | -if __name__ == "__main__": |
151 | | - tf.test.main() |
0 commit comments