Skip to content

Commit aadd2a9

Browse files
Sean Narenpre-commit-ci[bot]carmoccakaushikb11
authored
Load ckpt path when model provided in validate/test/predict (#8352)
* Change trainer loading behaviour for validate/test/predict * Fix * Fix/add tests * remove * Cleanups * Space * cleanups * Add CHANGELOG.md * Move after setup * Cleanups on logic * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remve * fix test * feedback * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update pytorch_lightning/trainer/properties.py Co-authored-by: Carlos Mocholí <[email protected]> * Feedback * Same fix * Same fix * Add test for behaviour, modify based on feedback * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Wording * Apply suggestions from code review Co-authored-by: Kaushik B <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]> * Cleanup docs * Update pytorch_lightning/trainer/trainer.py Co-authored-by: Kaushik B <[email protected]> * feedback * Fixes to test API * Add carlos description * Move logic further * Move checkpoint connector logic Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Kaushik B <[email protected]>
1 parent b256d6a commit aadd2a9

File tree

11 files changed

+93
-54
lines changed

11 files changed

+93
-54
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2828
- Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477))
2929

3030

31-
-
31+
- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352)))
32+
3233

3334

3435
-
@@ -164,6 +165,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
164165
- Added private `prevent_trainer_and_dataloaders_deepcopy` context manager on the `LightningModule` ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))
165166
- Added support for providing callables to the Lightning CLI instead of types ([#8400](https://github.com/PyTorchLightning/pytorch-lightning/pull/8400))
166167

168+
167169
### Changed
168170

169171
- Decoupled device parsing logic from Accelerator connector to Trainer ([#8180](https://github.com/PyTorchLightning/pytorch-lightning/pull/8180))

docs/source/common/test_set.rst

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,12 @@ To run the test set after training completes, use this method.
2020
trainer.fit(model)
2121
2222
# (1) load the best checkpoint automatically (lightning tracks this for you)
23-
trainer.test()
23+
trainer.test(ckpt_path='best')
2424
25-
# (2) don't load a checkpoint, instead use the model with the latest weights
26-
trainer.test(ckpt_path=None)
27-
28-
# (3) test using a specific checkpoint
25+
# (2) test using a specific checkpoint
2926
trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt')
3027
31-
# (4) test with an explicit model (will use this model and not load a checkpoint)
28+
# (3) test with an explicit model (will use this model and not load a checkpoint)
3229
trainer.test(model)
3330
3431
----------

pytorch_lightning/trainer/properties.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ class TrainerProperties(ABC):
6969
logger: LightningLoggerBase
7070
logger_connector: LoggerConnector
7171
state: TrainerState
72+
73+
# .validate() and .test() set this when they load a checkpoint
74+
validated_ckpt_path: Optional[str] = None
75+
tested_ckpt_path: Optional[str] = None
76+
predicted_ckpt_path: Optional[str] = None
7277
"""
7378
Accelerator properties
7479
"""
@@ -614,6 +619,15 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
614619
if self.predicting:
615620
return self.predict_loop
616621

622+
@property
623+
def _ckpt_path(self) -> Optional[str]:
624+
if self.state.fn == TrainerFn.VALIDATING:
625+
return self.validated_ckpt_path
626+
if self.state.fn == TrainerFn.TESTING:
627+
return self.tested_ckpt_path
628+
if self.state.fn == TrainerFn.PREDICTING:
629+
return self.predicted_ckpt_path
630+
617631
"""
618632
Logging properties
619633
"""

pytorch_lightning/trainer/trainer.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -489,15 +489,10 @@ def _setup_on_init(self, num_sanity_val_steps: int) -> None:
489489
self.test_dataloaders = None
490490
self.val_dataloaders = None
491491

492-
# .validate() and .test() set this when they load a checkpoint
493-
self.validated_ckpt_path = None
494-
self.tested_ckpt_path = None
495-
496492
# when true, print evaluation results in .validate() and .test()
497493
self.verbose_evaluate = True
498494

499495
self.num_predict_batches = []
500-
self.predicted_ckpt_path = None
501496

502497
def fit(
503498
self,
@@ -559,7 +554,7 @@ def validate(
559554
self,
560555
model: Optional["pl.LightningModule"] = None,
561556
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
562-
ckpt_path: Optional[str] = "best",
557+
ckpt_path: Optional[str] = None,
563558
verbose: bool = True,
564559
datamodule: Optional[LightningDataModule] = None,
565560
val_dataloaders=None, # noqa TODO: remove with 1.6
@@ -574,8 +569,8 @@ def validate(
574569
or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples.
575570
576571
ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
577-
If ``None``, use the current weights of the model.
578-
When the model is given as argument, this parameter will not apply.
572+
If ``None`` and the model instance was passed, use the current weights.
573+
Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
579574
580575
verbose: If True, prints the validation results.
581576
@@ -621,8 +616,9 @@ def validate(
621616
# links data to the trainer
622617
self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
623618

624-
if not model_provided:
625-
self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path)
619+
self.validated_ckpt_path = self.__set_ckpt_path(
620+
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
621+
)
626622

627623
# run validate
628624
results = self._run(model)
@@ -636,7 +632,7 @@ def test(
636632
self,
637633
model: Optional["pl.LightningModule"] = None,
638634
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
639-
ckpt_path: Optional[str] = "best",
635+
ckpt_path: Optional[str] = None,
640636
verbose: bool = True,
641637
datamodule: Optional[LightningDataModule] = None,
642638
test_dataloaders=None, # noqa TODO: remove with 1.6
@@ -652,8 +648,8 @@ def test(
652648
or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples.
653649
654650
ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
655-
If ``None``, use the current weights of the model.
656-
When the model is given as argument, this parameter will not apply.
651+
If ``None`` and the model instance was passed, use the current weights.
652+
Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
657653
658654
verbose: If True, prints the test results.
659655
@@ -699,8 +695,9 @@ def test(
699695
# links data to the trainer
700696
self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
701697

702-
if not model_provided:
703-
self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path)
698+
self.tested_ckpt_path = self.__set_ckpt_path(
699+
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
700+
)
704701

705702
# run test
706703
results = self._run(model)
@@ -716,7 +713,7 @@ def predict(
716713
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
717714
datamodule: Optional[LightningDataModule] = None,
718715
return_predictions: Optional[bool] = None,
719-
ckpt_path: Optional[str] = "best",
716+
ckpt_path: Optional[str] = None,
720717
) -> Optional[_PREDICT_OUTPUT]:
721718
r"""
722719
@@ -734,9 +731,9 @@ def predict(
734731
return_predictions: Whether to return predictions.
735732
``True`` by default except when an accelerator that spawns processes is used (not supported).
736733
737-
ckpt_path: Either ``best`` or path to the checkpoint you wish to use to predict.
738-
If ``None``, use the current weights of the model.
739-
When the model is given as argument, this parameter will not apply.
734+
ckpt_path: Either ``best`` or path to the checkpoint you wish to predict.
735+
If ``None`` and the model instance was passed, use the current weights.
736+
Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
740737
741738
Returns:
742739
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
@@ -770,8 +767,9 @@ def predict(
770767
# links data to the trainer
771768
self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
772769

773-
if not model_provided:
774-
self.predicted_ckpt_path = self.__load_ckpt_weights(ckpt_path)
770+
self.predicted_ckpt_path = self.__set_ckpt_path(
771+
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
772+
)
775773

776774
results = self._run(model)
777775

@@ -856,6 +854,15 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
856854
self.data_connector.prepare_data(model)
857855
self.callback_connector._attach_model_callbacks(model, self)
858856

857+
if self._ckpt_path:
858+
# only one process running at this point for TPUs, as spawn isn't triggered yet
859+
# todo: move this logic internally within the barrier.
860+
if not self._device_type == DeviceType.TPU:
861+
self.training_type_plugin.barrier()
862+
863+
rank_zero_info(f"Loading checkpoint from {self._ckpt_path}")
864+
self.checkpoint_connector.restore_model_weights(self._ckpt_path)
865+
859866
# ----------------------------
860867
# SET UP TRAINING
861868
# ----------------------------
@@ -910,7 +917,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
910917

911918
# plugin will setup fitting (e.g. ddp will launch child processes)
912919
self._pre_dispatch()
913-
914920
# restore optimizers, etc.
915921
self.checkpoint_connector.restore_training_state()
916922

@@ -1126,12 +1132,22 @@ def _run_sanity_check(self, ref_model):
11261132
# restore the previous stage when the sanity check if finished
11271133
self.state.stage = stage
11281134

1129-
def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
1130-
if ckpt_path is None:
1135+
def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]:
1136+
if model_provided and ckpt_path is None:
1137+
# use passed model to function without loading weights
11311138
return
11321139

11331140
fn = self.state.fn.value
11341141

1142+
if model_connected and ckpt_path is None:
1143+
rank_zero_warn(
1144+
f"`.{fn}(ckpt_path=None)` was called without a model. "
1145+
"The best model of the previous `fit` call will be used. "
1146+
f"You can pass `{fn}(ckpt_path='best')` to avoid this warning "
1147+
"or `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
1148+
)
1149+
ckpt_path = "best"
1150+
11351151
if ckpt_path == "best":
11361152
# if user requests the best checkpoint but we don't have it, error
11371153
if not self.checkpoint_callback.best_model_path:
@@ -1151,13 +1167,6 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
11511167
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
11521168
f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`"
11531169
)
1154-
1155-
# only one process running at this point for TPUs, as spawn isn't triggered yet
1156-
# todo: move this logic internally within the barrier.
1157-
if not self._device_type == DeviceType.TPU:
1158-
self.training_type_plugin.barrier()
1159-
1160-
self.checkpoint_connector.restore_model_weights(ckpt_path)
11611170
return ckpt_path
11621171

11631172
def _call_setup_hook(self, model: "pl.LightningModule") -> None:

tests/callbacks/test_callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,6 @@ def configure_callbacks(self):
9898
callbacks_after = trainer.callbacks.copy()
9999
assert callbacks_after == callbacks_after_fit
100100

101-
trainer_fn(ckpt_path=None)
101+
trainer_fn(model)
102102
callbacks_after = trainer.callbacks.copy()
103103
assert callbacks_after == callbacks_after_fit

tests/models/test_hooks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,7 @@ def predict_dataloader(self):
796796

797797
trainer.fit(model)
798798
assert trainer.state.finished, f"Training failed with {trainer.state}"
799-
trainer.test(ckpt_path=None)
799+
trainer.test(model)
800800

801801
preds = trainer.predict(model)
802802
assert len(preds) == 2

tests/models/test_restore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def test_load_model_from_checkpoint(tmpdir, model_template):
360360
# fit model
361361
trainer = Trainer(**trainer_options)
362362
trainer.fit(model)
363-
trainer.test(ckpt_path=None)
363+
trainer.test(model)
364364

365365
# correct result and ok accuracy
366366
assert trainer.state.finished, f"Training failed with {trainer.state}"

tests/trainer/flags/test_fast_dev_run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _make_fast_dev_run_assertions(trainer, model):
107107
train_val_step_model = FastDevRunModel()
108108
trainer = Trainer(**trainer_config)
109109
trainer.fit(train_val_step_model)
110-
trainer.test(ckpt_path=None)
110+
trainer.test(train_val_step_model)
111111

112112
assert trainer.state.finished, f"Training failed with {trainer.state}"
113113
_make_fast_dev_run_assertions(trainer, train_val_step_model)
@@ -120,7 +120,7 @@ def _make_fast_dev_run_assertions(trainer, model):
120120

121121
trainer = Trainer(**trainer_config)
122122
trainer.fit(train_step_only_model)
123-
trainer.test(ckpt_path=None)
123+
trainer.test(train_step_only_model)
124124

125125
assert trainer.state.finished, f"Training failed with {trainer.state}"
126126
_make_fast_dev_run_assertions(trainer, train_step_only_model)

tests/trainer/logging_/test_logger_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def test_dataloader(self):
201201
default_root_dir=tmpdir, accelerator="dp", gpus=2, limit_train_batches=2, limit_val_batches=2, max_epochs=1
202202
)
203203
trainer.fit(model)
204-
trainer.test(model, ckpt_path=None)
204+
trainer.test(model)
205205

206206

207207
def test_can_return_tensor_with_more_than_one_element(tmpdir):

tests/trainer/test_dataloaders.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
439439
assert trainer.num_training_batches == expected_train_batches
440440
assert trainer.num_val_batches == expected_val_batches
441441

442-
trainer.test(ckpt_path=None)
442+
trainer.test(model)
443443
expected_test_batches = [int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders]
444444
assert trainer.num_test_batches == expected_test_batches
445445

@@ -474,7 +474,7 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v
474474
# -------------------------------------------
475475
assert trainer.num_training_batches == limit_train_batches
476476
assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
477-
trainer.test(ckpt_path=None)
477+
trainer.test(model)
478478

479479
# when the limit is greater than the number of test batches it should be the num in loaders
480480
test_dataloader_lengths = [len(x) for x in model.test_dataloader()]
@@ -549,7 +549,7 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run):
549549
assert trainer.num_training_batches == fast_dev_run
550550
assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders)
551551

552-
trainer.test(ckpt_path=None)
552+
trainer.test(model)
553553
assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders)
554554

555555
# verify sanity check batches match as expected
@@ -685,6 +685,8 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
685685
match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers',
686686
):
687687
if stage == "test":
688+
if ckpt_path in ("specific", "best"):
689+
trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)
688690
ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path
689691
trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path)
690692
else:
@@ -722,6 +724,8 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
722724
match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers',
723725
):
724726
if stage == "test":
727+
if ckpt_path in ("specific", "best"):
728+
trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)
725729
ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path
726730
trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path)
727731
else:
@@ -950,7 +954,7 @@ def test_dataloader_distributed_sampler(tmpdir):
950954
callbacks=[DistribSamplerCallback(expected_seeds=(123, 123, 123))],
951955
)
952956
trainer.fit(model)
953-
trainer.test(ckpt_path=None)
957+
trainer.test(model)
954958

955959

956960
class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):
@@ -1444,7 +1448,7 @@ def predict_dataloader(self):
14441448

14451449
trainer.fit(model)
14461450
assert trainer.state.finished, f"Training failed with {trainer.state}"
1447-
trainer.test(ckpt_path=None)
1451+
trainer.test(model)
14481452

14491453
preds = trainer.predict(model)
14501454
assert len(preds) == 2

0 commit comments

Comments
 (0)