 import logging as log
 import os
 import pickle
+from copy import deepcopy

 import cloudpickle
 import pytest

 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer, LightningModule, Callback
+from pytorch_lightning import Trainer, LightningModule, Callback, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from tests.base import EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST

@@ -51,24 +52,90 @@ def on_train_end(self, trainer, pl_module):
         self._check_properties(trainer, pl_module)


-def test_resume_from_checkpoint(tmpdir):
+def test_model_properties_resume_from_checkpoint(tmpdir):
     """ Test that properties like `current_epoch` and `global_step`
     in model and trainer are always the same. """
     model = EvalModelTemplate()
     checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
     trainer_args = dict(
         default_root_dir=tmpdir,
-        max_epochs=2,
+        max_epochs=1,
         logger=False,
-        checkpoint_callback=checkpoint_callback,
-        callbacks=[ModelTrainerPropertyParity()]  # this performs the assertions
+        callbacks=[checkpoint_callback, ModelTrainerPropertyParity()]  # this performs the assertions
     )
     trainer = Trainer(**trainer_args)
     trainer.fit(model)
+
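+    # resume from the checkpoint written by the first run and train for one more epoch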
69+ trainer_args .update (max_epochs = 2 )
6870 trainer = Trainer (** trainer_args , resume_from_checkpoint = str (tmpdir / "last.ckpt" ))
6971 trainer .fit (model )
7072
7173
74+ class CaptureCallbacksBeforeTraining (Callback ):
75+ callbacks = []
76+
77+ def on_train_start (self , trainer , pl_module ):
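+        # snapshot the callbacks as they are when training starts,
+        # i.e. after any resumed checkpoint state has been restored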
+        self.callbacks = deepcopy(trainer.callbacks)
+
+
+def test_callbacks_state_resume_from_checkpoint(tmpdir):
+    """ Test that resuming from a checkpoint restores callbacks that persist state. """
+    model = EvalModelTemplate()
+    callback_capture = CaptureCallbacksBeforeTraining()
+
+    def get_trainer_args():
+        checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
+        trainer_args = dict(
+            default_root_dir=tmpdir,
+            max_steps=1,
+            logger=False,
+            callbacks=[
+                checkpoint,
+                callback_capture,
+            ]
+        )
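+        # a freshly constructed ModelCheckpoint has not tracked a best model yet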
+        assert checkpoint.best_model_path == ""
+        assert checkpoint.best_model_score == 0
+        return trainer_args
+
+    # initial training
+    trainer = Trainer(**get_trainer_args())
+    trainer.fit(model)
+    callbacks_before_resume = deepcopy(trainer.callbacks)
+
+    # resumed training
+    trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt"))
+    trainer.fit(model)
+
+    assert len(callbacks_before_resume) == len(callback_capture.callbacks)
+
+    for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
+        if isinstance(before, ModelCheckpoint):
+            assert before.best_model_path == after.best_model_path
+            assert before.best_model_score == after.best_model_score
+
+
+def test_callbacks_references_resume_from_checkpoint(tmpdir):
+    """ Test that resuming from a checkpoint sets references as expected. """
+    model = EvalModelTemplate()
+    args = {'default_root_dir': tmpdir, 'max_steps': 1, 'logger': False}
+
+    # initial training
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
+    trainer = Trainer(**args, callbacks=[checkpoint])
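+    # the user-provided callback should become the trainer's checkpoint callback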
+    assert checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback
+    trainer.fit(model)
+
+    # resumed training
+    new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
+    # pass in a new checkpoint object, which should take
+    # precedence over the one in the last.ckpt file
+    trainer = Trainer(**args, callbacks=[new_checkpoint], resume_from_checkpoint=str(tmpdir / "last.ckpt"))
+    assert checkpoint is not new_checkpoint
+    assert new_checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback
+    trainer.fit(model)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_running_test_pretrained_model_distrib_dp(tmpdir):
     """Verify `test()` on pretrained model."""