diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 430744a01a1e7..e0153c05732fa 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -317,6 +317,23 @@ def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModul
             return
         self.save_checkpoint(trainer)
 
+    def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
+        """
+        Save a checkpoint when training stops.
+
+        This will only save a checkpoint if `save_last` is also enabled as the monitor metrics logged during
+        training/validation steps or end of epochs are not guaranteed to be available at this stage.
+        """
+        if self._should_skip_saving_checkpoint(trainer) or not self.save_last:
+            return
+        if self.verbose:
+            rank_zero_info("Saving latest checkpoint...")
+        # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
+        monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1)
+        trainer.train_loop.global_step -= 1
+        self._save_last_checkpoint(trainer, monitor_candidates)
+        trainer.train_loop.global_step += 1
+
     def on_save_checkpoint(
         self,
         trainer: 'pl.Trainer',
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 7df0d1445e3b3..ebded386458b3 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -22,7 +22,6 @@
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities import rank_zero_info
 
 log = logging.getLogger(__name__)
 
@@ -227,14 +226,6 @@ def advance(self) -> None:
             self.global_step += 1
 
     def on_advance_end(self) -> None:
-        """Updates the LR schedulers and does some internal bookkeeping"""
-        if self.epoch_loop.batches_seen != 0:
-            did_train_only = not self.trainer.enable_validation or self.epoch_loop.val_loop.skip
-            if did_train_only:
-                self.global_step -= 1
-                self._check_checkpoint_callback(True)
-                self.global_step += 1
-
         self.epoch_progress.increment_completed()
 
     def on_run_end(self) -> None:
@@ -245,13 +236,6 @@ def on_run_end(self) -> None:
         # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
         self.current_epoch -= 1
 
-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.epoch_loop.global_step -= 1
-        # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406
-        self._check_checkpoint_callback(should_update=True, is_last=True)
-        self.epoch_loop.global_step += 1
-
         # hook
         self.trainer.call_hook("on_train_end")
 
@@ -271,19 +255,5 @@ def should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated"""
         return self.epoch_loop.batch_loop.should_accumulate()
 
-    def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False):
-        """Checks if checkpointing needs to be done"""
-        # TODO: bake this logic into the ModelCheckpoint callback
-        if should_update:
-            callbacks = self.trainer.checkpoint_callbacks
-
-            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
-                rank_zero_info("Saving latest checkpoint...")
-
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
     def teardown(self) -> None:
         self.epoch_loop.teardown()
diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py
index 9e85d0cb0caf6..4aae78e622ee7 100644
--- a/tests/checkpointing/test_checkpoint_callback_frequency.py
+++ b/tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -50,7 +50,7 @@ def test_mc_called(tmpdir):
 @mock.patch('torch.save')
 @pytest.mark.parametrize(
     ['epochs', 'val_check_interval', 'expected'],
-    [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)],
+    [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 6)],
 )
 def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval: float, expected: int):
 
@@ -74,9 +74,10 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter
     (1, 1, 1.0, 1),
     (2, 2, 1.0, 2),
     (2, 1, 0.25, 4),
-    (2, 2, 0.3, 7),
+    (2, 2, 0.3, 6),
 ])
-def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int):
+@pytest.mark.parametrize("save_last", (False, True))
+def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int, save_last: bool):
 
     class TestModel(BoringModel):
 
@@ -94,7 +95,7 @@ def training_step(self, batch, batch_idx):
 
     model = TestModel()
     trainer = Trainer(
-        callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss', save_top_k=k)],
+        callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss', save_top_k=k, save_last=save_last)],
         default_root_dir=tmpdir,
         max_epochs=epochs,
         weights_summary=None,
@@ -102,7 +103,9 @@ def training_step(self, batch, batch_idx):
     )
     trainer.fit(model)
 
-    # make sure types are correct
+    if save_last:
+        # last epochs are saved every step (so double the save calls) and once `on_train_end`
+        expected = expected * 2 + 1
     assert save_mock.call_count == expected
 
 
@@ -115,7 +118,7 @@ def test_top_k_ddp_0(save_mock, tmpdir):
 @mock.patch('torch.save')
 @RunIf(special=True, min_gpus=2)
 def test_top_k_ddp_1(save_mock, tmpdir):
-    _top_k_ddp(save_mock, tmpdir, k=2, epochs=2, val_check_interval=0.3, expected=5)
+    _top_k_ddp(save_mock, tmpdir, k=2, epochs=2, val_check_interval=0.3, expected=4)
 
 
 def _top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected):
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 173059ad33bcd..a85b5aa8ed3ab 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -247,14 +247,7 @@ def configure_optimizers(self):
     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
 
-    # on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
-    # end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
-    additional_ckpt, additional_ckpt_path = False, None
-    if not epoch_aligned:
-        additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
-        additional_ckpt = True
-
-    assert len(ckpt_files) == len(model.scores) + additional_ckpt == per_epoch_val_checks * max_epochs + additional_ckpt
+    assert len(ckpt_files) == len(model.scores) == per_epoch_val_checks * max_epochs
     assert len(lr_scheduler_debug) == max_epochs
 
     def _make_assertions(epoch, ix, version=''):
@@ -297,10 +290,6 @@ def _make_assertions(epoch, ix, version=''):
         assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
         assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
 
-    # check the ckpt file saved on_train_end
-    if additional_ckpt_path:
-        _make_assertions(max_epochs - 1, per_epoch_val_checks - 1, version='-v1')
-
 
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
 def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k: int):
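
For reference, the user-facing effect of this change: the end-of-training `last.ckpt` save is now handled by `ModelCheckpoint.on_train_end` itself rather than by the fit loop, and it only happens when `save_last=True`. A minimal usage sketch follows; it is not part of the patch, `MyModel` is a hypothetical `LightningModule`, and the dirpath and monitor key are placeholders.

# Usage sketch only, not part of the patch. `MyModel` is a hypothetical LightningModule
# that logs "val_loss"; dirpath and monitor are placeholders.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",
    monitor="val_loss",   # top-k checkpoints are still saved from on_validation_end
    save_top_k=2,
    save_last=True,       # required: without it, on_train_end saves nothing
    verbose=True,         # logs "Saving latest checkpoint..." once when training stops
)

trainer = pl.Trainer(max_epochs=2, callbacks=[checkpoint_cb])
trainer.fit(MyModel())
# After fit() returns, "checkpoints/last.ckpt" exists. It is written using `global_step - 1`
# because the loop advances one extra step at the end of training, which would otherwise
# duplicate a checkpoint already saved at the last step.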