@@ -224,31 +224,39 @@ def test_correct_step_and_epoch(tmpdir):
224224 model = BoringModel ()
225225 first_max_epochs = 2
226226 train_batches = 2
227- trainer = Trainer (default_root_dir = tmpdir , max_epochs = first_max_epochs , limit_train_batches = train_batches )
227+ trainer = Trainer (
228+ default_root_dir = tmpdir , max_epochs = first_max_epochs , limit_train_batches = train_batches , limit_val_batches = 0
229+ )
228230 assert trainer .current_epoch == 0
229231 assert trainer .global_step == 0
230232
231233 trainer .fit (model )
232234 assert trainer .current_epoch == first_max_epochs
233235 assert trainer .global_step == first_max_epochs * train_batches
234236
235- ckpt = str (tmpdir / "model.ckpt" )
236- trainer .save_checkpoint (ckpt )
237+ ckpt_path = str (tmpdir / "model.ckpt" )
238+ trainer .save_checkpoint (ckpt_path )
239+
240+ ckpt = torch .load (ckpt_path )
241+ assert ckpt ["epoch" ] == first_max_epochs
237242 # TODO(@carmocca): should not need `+1`
238- assert torch . load ( ckpt ) ["global_step" ] == first_max_epochs * train_batches + 1
243+ assert ckpt ["global_step" ] == first_max_epochs * train_batches + 1
239244
240- max_epochs = 4
241- trainer = Trainer (default_root_dir = tmpdir , max_epochs = max_epochs , limit_train_batches = train_batches )
245+ max_epochs = first_max_epochs + 2
246+ trainer = Trainer (
247+ default_root_dir = tmpdir , max_epochs = max_epochs , limit_train_batches = train_batches , limit_val_batches = 0
248+ )
242249 # the ckpt state is not loaded at this point
243250 assert trainer .current_epoch == 0
244251 assert trainer .global_step == 0
245252
246253 class TestModel (BoringModel ):
247254 def on_pretrain_routine_end (self ) -> None :
255+ assert self .trainer .current_epoch == first_max_epochs
248256 # TODO(@carmocca): should not need `+1`
249257 assert self .trainer .global_step == first_max_epochs * train_batches + 1
250258
251- trainer .fit (TestModel (), ckpt_path = ckpt )
259+ trainer .fit (TestModel (), ckpt_path = ckpt_path )
252260 assert trainer .current_epoch == max_epochs
253261 # TODO(@carmocca): should not need `+1`
254262 assert trainer .global_step == max_epochs * train_batches + 1
0 commit comments