@@ -69,13 +69,13 @@ def doTest(self, optimizer, update_fn, do_sparse=False,
       var1 = tf.Variable(var1_np, name="var1_%d" % i)
       if do_sparse:
         grads0_np_indices = np.array([0, 1], dtype=np.int32)
-        grads0 = tf.IndexedSlices(tf.constant(grads0_np),
-                                  tf.constant(grads0_np_indices),
-                                  tf.constant([2]))
+        grads0 = tf.IndexedSlices(
+            tf.constant(grads0_np), tf.constant(grads0_np_indices),
+            tf.constant([2]))
         grads1_np_indices = np.array([0, 1], dtype=np.int32)
-        grads1 = tf.IndexedSlices(tf.constant(grads1_np),
-                                  tf.constant(grads1_np_indices),
-                                  tf.constant([2]))
+        grads1 = tf.IndexedSlices(
+            tf.constant(grads1_np), tf.constant(grads1_np_indices),
+            tf.constant([2]))
       else:
         grads0 = tf.constant(grads0_np)
         grads1 = tf.constant(grads1_np)
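
Note: tf.IndexedSlices(values, indices, dense_shape) is TensorFlow's sparse-gradient container; the calls above wrap two gradient rows for a two-slot variable. A minimal standalone sketch of the same construction (the variable names and values here are illustrative, not taken from the test):

import numpy as np
import tensorflow as tf

values = np.array([0.1, 0.1], dtype=np.float32)   # one gradient row per index
indices = np.array([0, 1], dtype=np.int32)        # rows of the variable to update
sparse_grad = tf.IndexedSlices(
    tf.constant(values), tf.constant(indices), tf.constant([2]))
# Converting back to a dense tensor accumulates any duplicate indices.
dense_grad = tf.convert_to_tensor(sparse_grad)
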
@@ -94,12 +94,10 @@ def doTest(self, optimizer, update_fn, do_sparse=False,
           opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
         else:
           self.evaluate(update)
-        var0_np, np_slot_vars0 = update_fn(var0_np, grads0_np,
-                                           np_slot_vars0,
-                                           **optimizer_kwargs)
-        var1_np, np_slot_vars1 = update_fn(var1_np, grads1_np,
-                                           np_slot_vars1,
-                                           **optimizer_kwargs)
+        var0_np, np_slot_vars0 = update_fn(
+            var0_np, grads0_np, np_slot_vars0, **optimizer_kwargs)
+        var1_np, np_slot_vars1 = update_fn(
+            var1_np, grads1_np, np_slot_vars1, **optimizer_kwargs)
         # Validate updated params
         self.assertAllCloseAccordingToType(var0_np,
                                            self.evaluate(var0))
@@ -129,16 +127,15 @@ def doTestSparseRepeatedIndices(self, optimizer, **optimizer_kwargs):
           tf.constant([0.2], shape=[1, 1], dtype=dtype),
           tf.constant([1]), tf.constant([2, 1]))
       opt_repeated = optimizer(**optimizer_kwargs)
-      repeated_update = opt_repeated.apply_gradients([
-          (grad_repeated_index, repeated_index_update_var)
-      ])
+      repeated_update = opt_repeated.apply_gradients(
+          [(grad_repeated_index, repeated_index_update_var)])
       opt_aggregated = optimizer(**optimizer_kwargs)
-      aggregated_update = opt_aggregated.apply_gradients([
-          (grad_aggregated, aggregated_update_var)
-      ])
+      aggregated_update = opt_aggregated.apply_gradients(
+          [(grad_aggregated, aggregated_update_var)])
       self.evaluate(tf.compat.v1.global_variables_initializer())
-      self.assertAllClose(self.evaluate(aggregated_update_var),
-                          self.evaluate(repeated_index_update_var))
+      self.assertAllClose(
+          self.evaluate(aggregated_update_var),
+          self.evaluate(repeated_index_update_var))
       for _ in range(3):
         if not tf.executing_eagerly():
           self.evaluate(repeated_update)
@@ -148,8 +145,9 @@ def doTestSparseRepeatedIndices(self, optimizer, **optimizer_kwargs):
                                          repeated_index_update_var)])
           opt_aggregated.apply_gradients([(grad_aggregated,
                                            aggregated_update_var)])
-      self.assertAllClose(self.evaluate(aggregated_update_var),
-                          self.evaluate(repeated_index_update_var))
+      self.assertAllClose(
+          self.evaluate(aggregated_update_var),
+          self.evaluate(repeated_index_update_var))


 def adamw_update_numpy(param, grad_t, slot_vars, learning_rate, beta_1, beta_2,
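
The equivalence doTestSparseRepeatedIndices exercises is that a gradient listing the same index twice must behave like the pre-summed gradient. A numpy-only sketch of that property (shapes and values are illustrative):

import numpy as np

def densify(values, indices, shape):
  out = np.zeros(shape)
  np.add.at(out, indices, values)  # unbuffered add: duplicate indices accumulate
  return out

repeated = densify(np.array([[0.1], [0.1]]), np.array([1, 1]), (2, 1))
aggregated = densify(np.array([[0.2]]), np.array([1]), (2, 1))
assert np.allclose(repeated, aggregated)  # both give [[0.0], [0.2]]
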
@@ -162,8 +160,8 @@ def adamw_update_numpy(param, grad_t, slot_vars, learning_rate, beta_1, beta_2,
   lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
   slot_vars["m"] = beta1 * slot_vars.get("m", 0) + (1 - beta1) * grad_t
   slot_vars["v"] = beta2 * slot_vars.get("v", 0) + (1 - beta2) * grad_t**2
-  param_t = (param * (1 - wd) - lr_t * slot_vars["m"] /
-             (np.sqrt(slot_vars["v"]) + eps))
+  param_t = (param * (1 - wd) -
+             lr_t * slot_vars["m"] / (np.sqrt(slot_vars["v"]) + eps))
   slot_vars["t"] = t
   return param_t, slot_vars

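
Worked through once by hand, the AdamW reference update above proceeds as follows (the constants here are illustrative; the tests' WEIGHT_DECAY value may differ):

import numpy as np

param, grad, t = 1.0, 0.5, 1
lr, beta1, beta2, eps, wd = 0.001, 0.9, 0.999, 1e-8, 0.01
m = beta1 * 0 + (1 - beta1) * grad                  # 0.05
v = beta2 * 0 + (1 - beta2) * grad**2               # 0.00025
lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)  # bias-corrected step size
# Decoupled weight decay shrinks the parameter directly, then the Adam step applies.
param_t = param * (1 - wd) - lr_t * m / (np.sqrt(v) + eps)
print(round(param_t, 4))  # 0.989: 0.99 after decay, minus ~0.001 from the Adam step
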
@@ -185,42 +183,46 @@ class AdamWTest(OptimizerTestBase):

   @test_utils.run_in_graph_and_eager_modes(reset_test=True)
   def testSparse(self):
-    self.doTest(self.optimizer,
-                adamw_update_numpy,
-                do_sparse=True,
-                learning_rate=0.001,
-                beta_1=0.9,
-                beta_2=0.999,
-                epsilon=1e-8,
-                weight_decay=WEIGHT_DECAY)
+    self.doTest(
+        self.optimizer,
+        adamw_update_numpy,
+        do_sparse=True,
+        learning_rate=0.001,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-8,
+        weight_decay=WEIGHT_DECAY)

   @test_utils.run_in_graph_and_eager_modes(reset_test=True)
   def testSparseRepeatedIndices(self):
-    self.doTestSparseRepeatedIndices(self.optimizer,
-                                     learning_rate=0.001,
-                                     beta_1=0.9,
-                                     beta_2=0.999,
-                                     epsilon=1e-8,
-                                     weight_decay=WEIGHT_DECAY)
+    self.doTestSparseRepeatedIndices(
+        self.optimizer,
+        learning_rate=0.001,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-8,
+        weight_decay=WEIGHT_DECAY)

   @test_utils.run_in_graph_and_eager_modes(reset_test=True)
   def testBasic(self):
-    self.doTest(self.optimizer,
-                adamw_update_numpy,
-                learning_rate=0.001,
-                beta_1=0.9,
-                beta_2=0.999,
-                epsilon=1e-8,
-                weight_decay=WEIGHT_DECAY)
+    self.doTest(
+        self.optimizer,
+        adamw_update_numpy,
+        learning_rate=0.001,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-8,
+        weight_decay=WEIGHT_DECAY)

   def testBasicCallableParams(self):
-    self.doTest(self.optimizer,
-                adamw_update_numpy,
-                learning_rate=lambda: 0.001,
-                beta_1=lambda: 0.9,
-                beta_2=lambda: 0.999,
-                epsilon=lambda: 1e-8,
-                weight_decay=lambda: WEIGHT_DECAY)
+    self.doTest(
+        self.optimizer,
+        adamw_update_numpy,
+        learning_rate=lambda: 0.001,
+        beta_1=lambda: 0.9,
+        beta_2=lambda: 0.999,
+        epsilon=lambda: 1e-8,
+        weight_decay=lambda: WEIGHT_DECAY)


 class SGDWTest(OptimizerTestBase):
@@ -229,34 +231,38 @@ class SGDWTest(OptimizerTestBase):

   @test_utils.run_in_graph_and_eager_modes(reset_test=True)
   def testSparse(self):
-    self.doTest(self.optimizer,
-                sgdw_update_numpy,
-                do_sparse=True,
-                learning_rate=0.001,
-                momentum=0.9,
-                weight_decay=WEIGHT_DECAY)
+    self.doTest(
+        self.optimizer,
+        sgdw_update_numpy,
+        do_sparse=True,
+        learning_rate=0.001,
+        momentum=0.9,
+        weight_decay=WEIGHT_DECAY)

   @test_utils.run_in_graph_and_eager_modes(reset_test=True)
   def testSparseRepeatedIndices(self):
-    self.doTestSparseRepeatedIndices(self.optimizer,
-                                     learning_rate=0.001,
-                                     momentum=0.9,
-                                     weight_decay=WEIGHT_DECAY)
+    self.doTestSparseRepeatedIndices(
+        self.optimizer,
+        learning_rate=0.001,
+        momentum=0.9,
+        weight_decay=WEIGHT_DECAY)

   @test_utils.run_in_graph_and_eager_modes(reset_test=True)
   def testBasic(self):
-    self.doTest(self.optimizer,
-                sgdw_update_numpy,
-                learning_rate=0.001,
-                momentum=0.9,
-                weight_decay=WEIGHT_DECAY)
+    self.doTest(
+        self.optimizer,
+        sgdw_update_numpy,
+        learning_rate=0.001,
+        momentum=0.9,
+        weight_decay=WEIGHT_DECAY)

   def testBasicCallableParams(self):
-    self.doTest(self.optimizer,
-                sgdw_update_numpy,
-                learning_rate=lambda: 0.001,
-                momentum=lambda: 0.9,
-                weight_decay=lambda: WEIGHT_DECAY)
+    self.doTest(
+        self.optimizer,
+        sgdw_update_numpy,
+        learning_rate=lambda: 0.001,
+        momentum=lambda: 0.9,
+        weight_decay=lambda: WEIGHT_DECAY)


 class ExtendWithWeightDecayTest(SGDWTest):