diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index e5c960b3c002b..6f2a8219f8921 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -202,6 +202,20 @@ def on_validation_end(self, trainer, pl_module):
         """
         self.save_checkpoint(trainer, pl_module)
 
+    def on_epoch_end(self, trainer, pl_module):
+        """
+        checkpoints can be saved at the end of every training epoch
+        """
+        self.save_checkpoint(trainer, pl_module)
+
+    def on_train_end(self, trainer, pl_module):
+        """
+        the last checkpoint can be saved when training ends
+        """
+        trainer.global_step -= 1
+        self.save_checkpoint(trainer, pl_module, is_last=True)
+        trainer.global_step += 1
+
     def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
@@ -215,23 +229,36 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
         self.best_model_score = checkpointed_state["best_model_score"]
         self.best_model_path = checkpointed_state["best_model_path"]
 
-    def save_checkpoint(self, trainer, pl_module):
+    def should_save(self, trainer, is_last=False):
+        epoch = trainer.current_epoch
+        global_step = trainer.global_step
+        should_save = not (
+            # conditions that disable saving
+            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            or self.save_top_k == 0
+            or self.period < 1
+            or (epoch + 1) % self.period
+            or trainer.running_sanity_check
+        ) or (
+            # conditions that force a save
+            is_last
+            and self.save_last  # the user requested that the last model be saved
+        )
+        # already saved at the last step
+        should_skip = self.last_global_step_saved == global_step
+        # `has_trained` is True once at least one forward-backward pass has run
+        has_trained = trainer.checkpoint_connector.has_trained
+        return should_save and not should_skip and has_trained
+
+    def save_checkpoint(self, trainer, pl_module, is_last=False):
         """
         Performs the main logic around saving a checkpoint.
         This method runs on all ranks, it is the responsibility of `self.save_function`
         to handle correct behaviour in distributed training, i.e., saving only on rank 0.
""" - epoch = trainer.current_epoch global_step = trainer.global_step - if ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run - or self.save_top_k == 0 # no models are saved - or self.period < 1 # no models are saved - or (epoch + 1) % self.period # skip epoch - or trainer.running_sanity_check # don't save anything during sanity check - or self.last_global_step_saved == global_step # already saved at the last step - ): + if not self.should_save(trainer, is_last=is_last): return self._add_backward_monitor_support(trainer) @@ -250,7 +277,7 @@ def save_checkpoint(self, trainer, pl_module): self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates) # Mode 2: save the last checkpoint - self._save_last_checkpoint(trainer, pl_module, monitor_candidates) + self._save_last_checkpoint(trainer, pl_module, monitor_candidates, is_last=is_last) def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: @@ -503,11 +530,17 @@ def _add_backward_monitor_support(self, trainer): if self.save_top_k is None and self.monitor is not None: self.save_top_k = 1 + def _valid_monitor_key(self, trainer): + metrics = trainer.logger_connector.callback_metrics + + # validate metric + return self.monitor is None or self._is_valid_monitor_key(metrics) + def _validate_monitor_key(self, trainer): metrics = trainer.logger_connector.callback_metrics # validate metric - if self.monitor is not None and not self._is_valid_monitor_key(metrics): + if not self._valid_monitor_key(trainer): m = ( f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:" f" {list(metrics.keys())}. " @@ -538,13 +571,15 @@ def _monitor_candidates(self, trainer): ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch}) return ckpt_name_metrics - def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics): + def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, is_last=False): should_save_last = self.monitor is None or self.save_last if not should_save_last: return # when user ALSO asked for the 'last.ckpt' change the name if self.save_last: + if is_last: + rank_zero_info("Saving latest checkpoint...") last_filepath = self._format_checkpoint_name( self.CHECKPOINT_NAME_LAST, trainer.current_epoch, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 47e254606af93..1d66242070bd8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -156,12 +156,6 @@ def on_train_end(self): self._teardown_already_run = True - # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_save=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") @@ -182,19 +176,6 @@ def on_train_end(self): model.cpu() torch.cuda.empty_cache() - def check_checkpoint_callback(self, should_save, is_last=False): - # TODO bake this logic into the checkpoint callback - if should_save and self.trainer.checkpoint_connector.has_trained: - checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] - - if is_last and any(c.save_last for c in checkpoint_callbacks): - rank_zero_info("Saving latest checkpoint...") - - model = self.trainer.get_model() - - for callback in checkpoint_callbacks: - callback.on_validation_end(self.trainer, model) - def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -606,9 +587,6 @@ def run_training_epoch(self): self.num_optimizers ) - # when no val loop is present or fast-dev-run still need to call checkpoints - self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model))) - # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index f9686dce159dd..e29e291c02b23 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -48,7 +48,7 @@ def test_mc_called(tmpdir): @mock.patch('torch.save') @pytest.mark.parametrize(['epochs', 'val_check_interval', 'expected'], - [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)]) + [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 8)]) def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval, expected): model = BoringModel() @@ -66,7 +66,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval, @mock.patch('torch.save') @pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], - [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 7)]) + [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 8)]) def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected): class TestModel(BoringModel): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 3de26ef1a6fb6..748bf88783a39 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -167,11 +167,11 @@ def on_train_end(self, trainer, pl_module): super().on_train_end(trainer, pl_module) assert self.best_model_path assert self.best_model_score - assert self.on_save_checkpoint_count == self.expected_count + assert self.on_save_checkpoint_count == self.expected_count, (self.on_save_checkpoint_count, self.expected_count) if trainer.is_global_zero: - assert torch.save.call_count == self.expected_count + assert torch.save.call_count == self.expected_count, (torch.save.call_count, self.expected_count) else: - assert torch.save.call_count == 0 + assert torch.save.call_count == 0, torch.save.call_count @pytest.mark.skipif( @@ -564,14 +564,21 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v model = LogInTwoMethods() if not should_validate: model.validation_step 
= None + model_checkpoint = ModelCheckpoint( + monitor='early_stop_on', dirpath=tmpdir, + save_top_k=0, save_last=save_last + ) trainer = Trainer( default_root_dir=tmpdir, - callbacks=[ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, - save_top_k=0, save_last=save_last)], + callbacks=[model_checkpoint], max_epochs=max_epochs, ) trainer.fit(model) assert caplog.messages.count('Saving latest checkpoint...') == save_last + path_last = str(tmpdir / "last.ckpt") + if save_last: + assert path_last == model_checkpoint.last_model_path + assert os.path.isfile(path_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): @@ -937,6 +944,26 @@ def __init__(self, hparams): assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type +def test_model_checkpoint_no_val_loader_invocation(tmpdir): + """Test to ensure that the model callback saves the checkpoints only once in distributed mode.""" + class NoValBoringModel(LogInTwoMethods): + def val_dataloader(self): + return None + + model = NoValBoringModel() + + num_epochs = 4 + model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[model_checkpoint], + max_epochs=num_epochs, + gpus=0, + ) + result = trainer.fit(model) + assert 1 == result + + @pytest.mark.parametrize('max_epochs', [3, 4]) @pytest.mark.parametrize( 'save_top_k, expected', diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 8b66e7141957e..2aeed0e95e72d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -442,6 +442,8 @@ def mock_save_function(filepath, *args): trainer.current_epoch = i trainer.global_step = i trainer.logger_connector.callback_metrics = {"checkpoint_on": torch.tensor(loss)} + # after forward-backward `has_trained` is set, this condition is also checked + trainer.checkpoint_connector.has_trained = True checkpoint_callback.on_validation_end(trainer, trainer.get_model()) file_lists = set(os.listdir(tmpdir))
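
Reviewer note: a minimal usage sketch of the behaviour this patch targets, not part of the diff. `NoValModel` and the `checkpoints` directory are placeholders; the sketch only relies on the `ModelCheckpoint(dirpath=..., save_top_k=..., save_last=...)` and `Trainer(callbacks=..., max_epochs=...)` arguments already used in the tests. With this change, per-epoch checkpoints come from `ModelCheckpoint.on_epoch_end` and `last.ckpt` from `ModelCheckpoint.on_train_end`, even when the model defines no validation loop.

    import os

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint


    class NoValModel(LightningModule):
        """Placeholder module with a train loop only (no validation_step / val_dataloader)."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, = batch
            return self(x).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


    # save_top_k=-1 keeps a checkpoint per epoch; save_last=True additionally writes `last.ckpt`
    checkpoint_cb = ModelCheckpoint(dirpath="checkpoints", save_top_k=-1, save_last=True)
    trainer = Trainer(callbacks=[checkpoint_cb], max_epochs=2)
    trainer.fit(NoValModel())

    # with this patch, `last.ckpt` is produced by the callback's own `on_train_end` hook
    # rather than by the removed `check_checkpoint_callback` helper in the training loop
    assert os.path.isfile(os.path.join("checkpoints", "last.ckpt"))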