Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added


- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))


- Trigger a warning when a non-metric value logged with multiple processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))


Expand Down
47 changes: 27 additions & 20 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,31 +955,38 @@ def __load_ckpt_weights(
model,
ckpt_path: Optional[str] = None,
) -> Optional[str]:
# if user requests the best checkpoint but we don't have it, error
if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
if ckpt_path is None:
return

fn = self.state.value

if ckpt_path == 'best':
# if user requests the best checkpoint but we don't have it, error
if not self.checkpoint_callback.best_model_path:
if self.fast_dev_run:
raise MisconfigurationException(
f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do'
f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.'
)
raise MisconfigurationException(
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
)
# load best weights
ckpt_path = self.checkpoint_callback.best_model_path

if not ckpt_path:
raise MisconfigurationException(
'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.'
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
)

# load best weights
if ckpt_path is not None:
# ckpt_path is 'best' so load the best model
if ckpt_path == 'best':
ckpt_path = self.checkpoint_callback.best_model_path

if not ckpt_path:
fn = self.state.value
raise MisconfigurationException(
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
)
# only one process running at this point for TPUs, as spawn isn't triggered yet
if self._device_type != DeviceType.TPU:
self.training_type_plugin.barrier()

# only one process running at this point for TPUs, as spawn isn't triggered yet
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()
ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
return ckpt_path

def predict(
Expand Down
9 changes: 9 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,3 +1777,12 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None:

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()])
trainer.fit(model, datamodule=dm)


def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
    """Ensure `.validate()` and `.test()` raise a `MisconfigurationException`
    when the trainer was built with ``fast_dev_run=True`` (no checkpoint is
    produced in that mode, so there is nothing to evaluate)."""
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)

    # Both evaluation entry points must fail with the same explicit message.
    for fn_name in ("validate", "test"):
        pattern = rf"\.{fn_name}\(\)` with `fast_dev_run=True"
        with pytest.raises(MisconfigurationException, match=pattern):
            getattr(trainer, fn_name)()