diff --git a/CHANGELOG.md b/CHANGELOG.md
index a0d229887b283..a37180ba305e3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -236,9 +236,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
 
-
 - Fixed `AttributeError for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
 
+- Fixed model checkpointing at end of training ([#6671](https://github.com/PyTorchLightning/pytorch-lightning/pull/6671))
+
 - Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index a1a44fd70b139..7b0bc860192bb 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -111,6 +111,9 @@ class ModelCheckpoint(Callback):
             This argument has been deprecated in v1.3 and will be removed in v1.5.
             Use ``every_n_val_epochs`` instead.
+        trigger_on_train_end: Whether to run ``save_checkpoint`` at the end of training.
+            Disabled by default. When enabled, the model is saved to the file ``last.ckpt``.
+
     Note:
         For extra customization, ModelCheckpoint includes the following attributes:
@@ -186,6 +189,7 @@ def __init__(
         every_n_train_steps: Optional[int] = None,
         every_n_val_epochs: Optional[int] = None,
         period: Optional[int] = None,
+        trigger_on_train_end: bool = False,
     ):
         super().__init__()
         self.monitor = monitor
@@ -205,7 +209,7 @@ def __init__(
 
         self.__init_monitor_mode(monitor, mode)
         self.__init_ckpt_dir(dirpath, filename, save_top_k)
-        self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
+        self.__init_triggers(every_n_train_steps, every_n_val_epochs, period, trigger_on_train_end)
         self.__validate_init_configuration()
 
     def on_pretrain_routine_start(self, trainer, pl_module):
@@ -239,6 +243,22 @@ def on_validation_end(self, trainer, pl_module) -> None:
             return
         self.save_checkpoint(trainer)
 
+    def on_train_end(self, trainer, *args, **kwargs) -> None:
+        """
+        Save a checkpoint at the end of training, if enabled.
+        """
+        if not self._trigger_on_train_end:
+            return
+        # as the loop advances one extra step at the end of training, use global_step - 1
+        # to avoid saving a duplicate checkpoint
+        trainer.global_step -= 1
+        if not self._should_skip_saving_checkpoint(trainer) and trainer.checkpoint_connector.has_trained:
+            if self.save_last and self.verbose:
+                rank_zero_info("Saving last checkpoint...")
+            monitor_candidates = self._monitor_candidates(trainer)
+            self._save_last_checkpoint(trainer, monitor_candidates)
+        trainer.global_step += 1
+
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
@@ -286,6 +306,7 @@ def save_checkpoint(self, trainer, unused: Optional = None):
 
     def _should_skip_saving_checkpoint(self, trainer) -> bool:
         from pytorch_lightning.trainer.states import TrainerState
+
         return (
             trainer.fast_dev_run  # disable checkpointing with fast_dev_run
             or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
@@ -357,7 +378,11 @@ def __init_monitor_mode(self, monitor, mode):
 
         self.kth_value, self.mode = mode_dict[mode]
 
     def __init_triggers(
-        self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int]
+        self,
+        every_n_train_steps: Optional[int],
+        every_n_val_epochs: Optional[int],
+        period: Optional[int],
+        trigger_on_train_end: bool,
     ) -> None:
 
         # Default to running once after each validation epoch if neither
@@ -379,6 +404,7 @@ def __init_triggers(
             self._every_n_val_epochs = period
 
         self._period = self._every_n_val_epochs
+        self._trigger_on_train_end = trigger_on_train_end
 
     @property
     def period(self) -> Optional[int]:
@@ -585,11 +611,10 @@ def _add_backward_monitor_support(self, trainer):
 
     def _validate_monitor_key(self, trainer):
         metrics = trainer.logger_connector.callback_metrics
-
         # validate metric
         if self.monitor is not None and not self._is_valid_monitor_key(metrics):
             m = (
-                f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
+                f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics "
                 f" {list(metrics.keys())}. "
                 f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
             )
@@ -618,6 +643,7 @@ def _monitor_candidates(self, trainer):
         return monitor_candidates
 
     def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+
         if not self.save_last:
             return
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d2749733812b3..589c48526b87d 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -26,10 +26,9 @@
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
-from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.parsing import AttributeDict
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -112,12 +111,6 @@ def on_train_end(self):
             return
         self._teardown_already_run = True
 
-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_update=True, is_last=True)
-        self.trainer.global_step += 1
-
         # hook
         self.trainer.call_hook("on_train_end")
 
@@ -141,9 +134,6 @@ def check_checkpoint_callback(self, should_update, is_last=False):
         if should_update and self.trainer.checkpoint_connector.has_trained:
             callbacks = self.trainer.checkpoint_callbacks
 
-            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
-                rank_zero_info("Saving latest checkpoint...")
-
             model = self.trainer.lightning_module
 
             for cb in callbacks:
diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py
index 7926bc46dd290..7f2c8d19984f0 100644
--- a/tests/checkpointing/test_checkpoint_callback_frequency.py
+++ b/tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -50,7 +50,7 @@ def test_mc_called(tmpdir):
 @mock.patch('torch.save')
 @pytest.mark.parametrize(
     ['epochs', 'val_check_interval', 'expected'],
-    [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)],
+    [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 6)],
 )
 def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval: float, expected: int):
 
@@ -73,7 +73,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter
         (1, 1, 1.0, 1),
         (2, 2, 1.0, 2),
         (2, 1, 0.25, 4),
-        (2, 2, 0.3, 7),
+        (2, 2, 0.3, 6),
     ])
 def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int):
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index f58ff768759e8..ec71c57e30389 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -452,6 +452,7 @@ def test_model_checkpoint_file_extension(tmpdir):
         dirpath=tmpdir,
         save_top_k=1,
         save_last=True,
+        trigger_on_train_end=True,
     )
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -460,8 +461,7 @@ def test_model_checkpoint_file_extension(tmpdir):
         logger=False,
     )
     trainer.fit(model)
-
-    expected = ['epoch=0-step=0.tpkc', 'last.tpkc']
+    expected = ['last.tpkc']
     assert set(expected) == set(os.listdir(tmpdir))
 
 
@@ -593,10 +593,19 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
 
 @pytest.mark.parametrize("period", list(range(4)))
-def test_model_checkpoint_period(tmpdir, period: int):
+@pytest.mark.parametrize('trigger_on_train_end', [False, True])
+@pytest.mark.parametrize('save_last', [False, True])
+def test_model_checkpoint_period(tmpdir, period: int, trigger_on_train_end: bool, save_last: bool):
     model = LogInTwoMethods()
     epochs = 5
-    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period)
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmpdir,
+        filename='{epoch}',
+        save_top_k=-1,
+        save_last=save_last,
+        period=period,
+        trigger_on_train_end=trigger_on_train_end,
+    )
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[checkpoint_callback],
@@ -608,16 +617,25 @@ def test_model_checkpoint_period(tmpdir, period: int):
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
+    expected = ([f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % period == 0] if period > 0 else [])
+    if save_last and (period > 0 or trigger_on_train_end):
+        expected.append("last.ckpt")
     assert set(os.listdir(tmpdir)) == set(expected)
 
 
 @pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
-def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
+@pytest.mark.parametrize('trigger_on_train_end', [False, True])
+@pytest.mark.parametrize('save_last', [False, True])
+def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs, trigger_on_train_end: bool, save_last: bool):
     model = LogInTwoMethods()
     epochs = 5
     checkpoint_callback = ModelCheckpoint(
-        dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
+        dirpath=tmpdir,
+        filename='{epoch}',
+        save_top_k=-1,
+        save_last=save_last,
+        every_n_val_epochs=every_n_val_epochs,
+        trigger_on_train_end=trigger_on_train_end,
     )
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -630,13 +648,20 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    expected = ([f"epoch={e}.ckpt" for e in range(epochs)
+                 if (e + 1) % every_n_val_epochs == 0] if every_n_val_epochs > 0 else [])
+
+    if save_last and (every_n_val_epochs > 0 or trigger_on_train_end):
+        expected.append("last.ckpt")
     assert set(os.listdir(tmpdir)) == set(expected)
 
 
 @pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
-def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
+@pytest.mark.parametrize('trigger_on_train_end', [False, True])
+@pytest.mark.parametrize('save_last', [False, True])
+def test_model_checkpoint_every_n_val_epochs_and_period(
+    tmpdir, every_n_val_epochs, trigger_on_train_end: bool, save_last: bool
+):
     """
     Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility.
""" model = LogInTwoMethods() epochs = 5 @@ -644,8 +669,10 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc dirpath=tmpdir, filename='{epoch}', save_top_k=-1, + save_last=save_last, every_n_val_epochs=(2 * every_n_val_epochs), - period=every_n_val_epochs + period=every_n_val_epochs, + trigger_on_train_end=trigger_on_train_end, ) trainer = Trainer( default_root_dir=tmpdir, @@ -658,8 +685,10 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + expected = ([f"epoch={e}.ckpt" for e in range(epochs) + if (e + 1) % every_n_val_epochs == 0] if every_n_val_epochs > 0 else []) + if save_last and (every_n_val_epochs > 0 or trigger_on_train_end): + expected.append("last.ckpt") assert set(os.listdir(tmpdir)) == set(expected) @@ -794,18 +823,60 @@ def test_default_checkpoint_behavior(tmpdir): assert ckpts[0] == 'epoch=2-step=14.ckpt' +@pytest.mark.parametrize('save_last', [False, True]) +def test_ckpt_on_train_end_with_invalid_monitor(tmpdir, save_last: bool): + """ Tests that the checkpoints are saved at end of training with invalid monitor.""" + + model = LogInTwoMethods() + model_cpt = ModelCheckpoint( + filename="{epoch}", + dirpath=tmpdir, + every_n_val_epochs=2, + monitor="invalid", # monitor is invalid, save_last is not set + save_last=save_last, + trigger_on_train_end=True, + ) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + progress_bar_refresh_rate=0, + callbacks=[model_cpt], + logger=False, + ) + trainer.fit(model) + expected = ['last.ckpt'] if save_last else [] + assert set(expected) == set(os.listdir(tmpdir)) + + @pytest.mark.parametrize('max_epochs', [1, 2]) +@pytest.mark.parametrize('every_n_val_epochs', [2, 3]) @pytest.mark.parametrize('should_validate', [True, False]) @pytest.mark.parametrize('save_last', [True, False]) @pytest.mark.parametrize('verbose', [True, False]) +@pytest.mark.parametrize('trigger_on_train_end', [False, True]) def test_model_checkpoint_save_last_warning( - tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool + tmpdir, + caplog, + max_epochs: int, + every_n_val_epochs: int, + should_validate: bool, + save_last: bool, + verbose: bool, + trigger_on_train_end: bool, ): - """Tests 'Saving latest checkpoint...' log""" + """Tests 'Saving last checkpoint...' 
log""" model = LogInTwoMethods() if not should_validate: model.validation_step = None - ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose) + ckpt = ModelCheckpoint( + monitor='early_stop_on', + dirpath=tmpdir, + every_n_val_epochs=every_n_val_epochs, + save_top_k=0, + save_last=save_last, + verbose=verbose, + trigger_on_train_end=trigger_on_train_end, + ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[ckpt], @@ -813,7 +884,10 @@ def test_model_checkpoint_save_last_warning( ) with caplog.at_level(logging.INFO): trainer.fit(model) - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + expected = False + if save_last and verbose and trigger_on_train_end: + expected = (max_epochs % every_n_val_epochs != 0) + assert caplog.messages.count('Saving last checkpoint...') == expected def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 5d4d36db8aa4d..9e96fd6e44985 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -55,7 +55,7 @@ def __init__(self, b1=0.5, b2=0.999): assert len(yaml_params.keys()) == 2 # verify artifacts - assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1 + assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 0 # verify tb logs event_acc = EventAccumulator(folder_path) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 34149e2231bf5..1ae96f8e4bd9f 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -57,7 +57,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): callback0 = StatefulCallback0() callback1 = StatefulCallback1() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states") + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True, trigger_on_train_end=True) model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, @@ -67,7 +67,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): ) trainer.fit(model) - ckpt = torch.load(str(tmpdir / "all_states.ckpt")) + ckpt = torch.load(str(tmpdir / "last.ckpt")) state0 = ckpt["callbacks"][type(callback0)] state1 = ckpt["callbacks"][type(callback1)] assert "content0" in state0 and state0["content0"] == 0