 
 import torch
 from torch.optim import Adam, Optimizer, SGD
-from torch.optim.optimizer import _RequiredParameter
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -308,31 +307,17 @@ def configure_optimizers(self):
     assert zero_grad.call_count == max_iter
 
 
-required = _RequiredParameter()
-
-
 class OptimizerWithHooks(Optimizer):
 
-    def __init__(self, model, lr=required, u0=required):
-        if lr is not required and lr < 0.0:
-            raise ValueError("Invalid learning rate: {}".format(lr))
-
-        defaults = dict(lr=lr)
-        self.steps = 0
-
-        self.params = []
-
+    def __init__(self, model):
         self._fwd_handles = []
         self._bwd_handles = []
-
-        self.model = model
-
-        for _, mod in model.named_modules():  # iterates over modules of model
+        self.params = []
+        for _, mod in model.named_modules():
             mod_class = mod.__class__.__name__
-            if mod_class not in ['Linear']:  # silently skips other layers
+            if mod_class != 'Linear':
                 continue
 
-            # save the inputs and gradients for the kfac matrix computation
             handle = mod.register_forward_pre_hook(self._save_input)  # save the inputs
             self._fwd_handles.append(handle)  # collect forward-save-input hooks in list
             handle = mod.register_backward_hook(self._save_grad_output)  # save the gradients
@@ -347,21 +332,21 @@ def __init__(self, model, lr=required, u0=required):
             d = {'params': params, 'mod': mod, 'layer_type': mod_class}
             self.params.append(d)
 
-        super(OptimizerWithHooks, self).__init__(self.params, defaults)
+        super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})
 
     def _save_input(self, mod, i):
         """Saves input of layer"""
         if mod.training:
             self.state[mod]['x'] = i[0]
 
-    def _save_grad_output(self, mod, grad_input, grad_output):
+    def _save_grad_output(self, mod, _, grad_output):
         """
         Saves grad on output of layer to
         grad is scaled with batch_size since gradient is spread over samples in mini batch
         """
-        bs = grad_output[0].shape[0]  # batch_size
+        batch_size = grad_output[0].shape[0]
         if mod.training:
-            self.state[mod]['grad'] = grad_output[0] * bs
+            self.state[mod]['grad'] = grad_output[0] * batch_size
 
     def step(self, closure=None):
         closure()
@@ -371,14 +356,11 @@ def step(self, closure=None):
         return True
 
 
-def test_lightning_optimizer_dont_delete_wrapped_optimizer(tmpdir):
+def test_lightning_optimizer_keeps_hooks(tmpdir):
 
     class TestModel(BoringModel):
-
-        def __init__(self):
-            super().__init__()
-            self.count_on_train_batch_start = 0
-            self.count_on_train_batch_end = 0
+        count_on_train_batch_start = 0
+        count_on_train_batch_end = 0
 
         def configure_optimizers(self):
             return OptimizerWithHooks(self)
@@ -390,15 +372,11 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int)
 
         def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
             self.count_on_train_batch_end += 1
-            # delete the lightning_optimizers
-            self.trainer._lightning_optimizers = None
-            gc.collect()
+            del self.trainer._lightning_optimizers
+            gc.collect()  # not necessary, just in case
 
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
     model = TestModel()
-    # Initialize a trainer
-    trainer = Trainer(limit_train_batches=4, limit_val_batches=1, max_epochs=1)
-
-    # Train the model ⚡
     trainer.fit(model)
     assert model.count_on_train_batch_start == 4
     assert model.count_on_train_batch_end == 4
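For reference, the hook pattern that OptimizerWithHooks relies on can be reproduced in plain PyTorch outside of an Optimizer subclass. This is a minimal sketch, not part of the commit; the layer sizes and the `saved` dictionary are illustrative only:

import torch
from torch import nn

saved = {}

def save_input(mod, inputs):
    # forward pre-hook: stash the layer input, analogous to _save_input above
    saved['x'] = inputs[0].detach()

def save_grad_output(mod, grad_input, grad_output):
    # backward hook: stash the gradient w.r.t. the layer output, analogous to _save_grad_output above
    saved['grad'] = grad_output[0].detach()

layer = nn.Linear(4, 2)
layer.register_forward_pre_hook(save_input)
layer.register_backward_hook(save_grad_output)  # newer PyTorch versions prefer register_full_backward_hook

out = layer(torch.randn(8, 4))
out.sum().backward()
assert saved['x'].shape == (8, 4) and saved['grad'].shape == (8, 2)

The test above checks exactly this machinery: the hooks registered by the optimizer must keep firing even after the trainer's LightningOptimizer wrappers are deleted, which is why the wrapped optimizer must not be garbage-collected.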