@@ -441,6 +441,43 @@ def test_if_lr_finder_callback_already_configured():
441441 trainer .tune (model )
442442
443443
444+ def test_lr_finder_callback_restarting (tmpdir ):
445+ """Test that `LearningRateFinder` does not set restarting=True when loading checkpoint."""
446+
447+ class MyBoringModel (BoringModel ):
448+ def __init__ (self ):
449+ super ().__init__ ()
450+ self .learning_rate = 0.123
451+
452+ def configure_optimizers (self ):
453+ return torch .optim .SGD (self .parameters (), lr = self .learning_rate )
454+
455+ class CustomLearningRateFinder (LearningRateFinder ):
456+ milestones = (1 ,)
457+
458+ def lr_find (self , trainer , pl_module ) -> None :
459+ super ().lr_find (trainer , pl_module )
460+ assert not trainer .fit_loop .restarting
461+
462+ def on_train_epoch_start (self , trainer , pl_module ):
463+ if trainer .current_epoch in self .milestones or trainer .current_epoch == 0 :
464+ self .lr_find (trainer , pl_module )
465+
466+ model = MyBoringModel ()
467+ trainer = Trainer (
468+ default_root_dir = tmpdir ,
469+ max_epochs = 3 ,
470+ callbacks = [CustomLearningRateFinder (early_stop_threshold = None , update_attr = True )],
471+ limit_train_batches = 10 ,
472+ limit_val_batches = 0 ,
473+ limit_test_batches = 00 ,
474+ num_sanity_val_steps = 0 ,
475+ enable_model_summary = False ,
476+ )
477+
478+ trainer .fit (model )
479+
480+
444481@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
445482@RunIf (standalone = True )
446483def test_lr_finder_with_ddp (tmpdir ):
0 commit comments