 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import patch, DEFAULT
+import gc
+from typing import Any
+from unittest.mock import DEFAULT, patch
 
 import torch
 from torch.optim import Adam, Optimizer, SGD
@@ -188,6 +190,7 @@ def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir):
188190 """
189191 Test overriding zero_grad works in automatic_optimization
190192 """
193+
191194 class TestModel (BoringModel ):
192195
193196 def training_step (self , batch , batch_idx , optimizer_idx = None ):
@@ -281,7 +284,9 @@ def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmpdir):
     Test zero_grad is called the same number of times as LBFGS requires
     for reevaluation of the loss in automatic_optimization.
     """
+
     class TestModel(BoringModel):
+
         def configure_optimizers(self):
             return torch.optim.LBFGS(self.parameters())
 
@@ -300,3 +305,78 @@ def configure_optimizers(self):
     lbfgs = model.optimizers()
     max_iter = lbfgs.param_groups[0]["max_iter"]
     assert zero_grad.call_count == max_iter
+
+
+class OptimizerWithHooks(Optimizer):
+
+    def __init__(self, model):
+        self._fwd_handles = []
+        self._bwd_handles = []
+        self.params = []
+        for _, mod in model.named_modules():
+            mod_class = mod.__class__.__name__
+            if mod_class != 'Linear':
+                continue
+
+            handle = mod.register_forward_pre_hook(self._save_input)  # save the inputs
+            self._fwd_handles.append(handle)  # collect forward-save-input hooks in list
+            handle = mod.register_backward_hook(self._save_grad_output)  # save the gradients
+            self._bwd_handles.append(handle)  # collect backward-save-grad hooks in list
+
+            # save the parameters
+            params = [mod.weight]
+            if mod.bias is not None:
+                params.append(mod.bias)
+
+            # save a param_group for each module
+            d = {'params': params, 'mod': mod, 'layer_type': mod_class}
+            self.params.append(d)
+
+        super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})
+
+    def _save_input(self, mod, i):
+        """Saves the input of the layer"""
+        if mod.training:
+            self.state[mod]['x'] = i[0]
+
+    def _save_grad_output(self, mod, _, grad_output):
+        """
+        Saves the gradient on the output of the layer.
+        The grad is scaled by batch_size since the gradient is spread over the samples in the mini-batch.
+        """
+        batch_size = grad_output[0].shape[0]
+        if mod.training:
+            self.state[mod]['grad'] = grad_output[0] * batch_size
+
+    def step(self, closure=None):
+        closure()
+        for group in self.param_groups:
+            # the hooks must have populated the state during the forward/backward of the closure
+            _ = self.state[group['mod']]['x']
+            _ = self.state[group['mod']]['grad']
+        return True
+
+def test_lightning_optimizer_keeps_hooks(tmpdir):
+    """Test that the hooks registered by the user's optimizer survive the LightningOptimizer wrapping."""
+
+    class TestModel(BoringModel):
+        count_on_train_batch_start = 0
+        count_on_train_batch_end = 0
+
+        def configure_optimizers(self):
+            return OptimizerWithHooks(self)
+
+        def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_start += 1
+            optimizer = self.optimizers(use_pl_optimizer=False)
+            assert len(optimizer._fwd_handles) == 1
+
+        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_end += 1
+            del self.trainer._lightning_optimizers
+            gc.collect()  # not necessary, just in case
+
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
+    model = TestModel()
+    trainer.fit(model)
+    assert model.count_on_train_batch_start == 4
+    assert model.count_on_train_batch_end == 4
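
For background on the hook pattern the new OptimizerWithHooks fixture relies on, here is a minimal standalone sketch in plain PyTorch (no Lightning; the nn.Linear layer, the shapes, and the saved dict are illustrative assumptions, not part of the change above). It shows that a register_forward_pre_hook handle keeps firing until it is explicitly removed, which is the behaviour the test above asserts survives the LightningOptimizer wrapping and the gc.collect() call:

import gc

import torch
from torch import nn

# Illustrative toy module and storage (assumptions, not from the diff above).
layer = nn.Linear(4, 2)
saved = {}


def save_input(mod, inputs):
    # forward pre-hooks receive (module, inputs); `inputs` is a tuple of positional args
    saved["x"] = inputs[0]


handle = layer.register_forward_pre_hook(save_input)

layer(torch.randn(3, 4))
assert saved["x"].shape == (3, 4)  # the hook fired and stored the input

# The hook keeps firing as long as it is registered, even after a GC pass.
gc.collect()
layer(torch.randn(5, 4))
assert saved["x"].shape == (5, 4)

handle.remove()  # removing the handle detaches the hook
layer(torch.randn(6, 4))
assert saved["x"].shape == (5, 4)  # unchanged: the hook no longer runs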