From 42b1f4e9137d6682803439f3e21a4faacd4980bd Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Fri, 20 Nov 2020 18:11:06 +0300
Subject: [PATCH 01/18] change validation_end to epoch_end

---
 pytorch_lightning/callbacks/model_checkpoint.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index c2a8c3a6ff859..995d16ac52486 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -178,9 +178,9 @@ def on_pretrain_routine_start(self, trainer, pl_module):
         self.__resolve_ckpt_dir(trainer, pl_module)
         self.save_function = trainer.save_checkpoint
 
-    def on_validation_end(self, trainer, pl_module):
+    def on_epoch_end(self, trainer, pl_module):
         """
-        checkpoints can be saved at the end of the val loop
+        checkpoints can be saved at the end of the epoch loop
         """
         self.save_checkpoint(trainer, pl_module)
 
@@ -349,7 +349,7 @@ def check_monitor_top_k(self, current) -> bool:
         if not isinstance(current, torch.Tensor):
             rank_zero_warn(
                 f"{current} is supposed to be a `torch.Tensor`. Saving checkpoint may not work correctly."
-                f" HINT: check the value of {self.monitor} in your validation loop",
+                f" HINT: check the value of {self.monitor} in your validation or training loop",
                 RuntimeWarning,
             )
             current = torch.tensor(current)

From 7f12291dfe3edf3a651af3cc64b4a1b2ea13b47a Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 22 Nov 2020 10:36:45 +0300
Subject: [PATCH 02/18] Revert "change validation_end to epoch_end"

This reverts commit 42b1f4e9
---
 pytorch_lightning/callbacks/model_checkpoint.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 995d16ac52486..c2a8c3a6ff859 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -178,9 +178,9 @@ def on_pretrain_routine_start(self, trainer, pl_module):
         self.__resolve_ckpt_dir(trainer, pl_module)
         self.save_function = trainer.save_checkpoint
 
-    def on_epoch_end(self, trainer, pl_module):
+    def on_validation_end(self, trainer, pl_module):
         """
-        checkpoints can be saved at the end of the epoch loop
+        checkpoints can be saved at the end of the val loop
         """
         self.save_checkpoint(trainer, pl_module)
 
@@ -349,7 +349,7 @@ def check_monitor_top_k(self, current) -> bool:
         if not isinstance(current, torch.Tensor):
             rank_zero_warn(
                 f"{current} is supposed to be a `torch.Tensor`. Saving checkpoint may not work correctly."
-                f" HINT: check the value of {self.monitor} in your validation or training loop",
+                f" HINT: check the value of {self.monitor} in your validation loop",
                 RuntimeWarning,
             )
             current = torch.tensor(current)

From e0aacea255ada357c98882eeb2c5c7cda1540b56 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 6 Dec 2020 10:18:39 +0300
Subject: [PATCH 03/18] add a failing test with boring model

---
 tests/checkpointing/test_model_checkpoint.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 33bc19a894d8f..7df80a7725de7 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -997,3 +997,23 @@ def __init__(self, hparams):
     else:
         # make sure it's not AttributeDict
         assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
+
+
+def test_model_checkpoint_no_val_loader_invocation(tmpdir):
+    """Test that the checkpoint callback is invoked every epoch when the model has no val dataloader."""
+    class NoValBoringModel(LogInTwoMethods):
+        def val_dataloader(self):
+            return None
+
+    model = NoValBoringModel()
+
+    num_epochs = 4
+    model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[model_checkpoint],
+        max_epochs=num_epochs,
+        gpus=0,
+    )
+    result = trainer.fit(model)
+    assert 1 == result

From c62a943a167216de39fba378291928043e3c00b0 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 6 Dec 2020 11:22:12 +0300
Subject: [PATCH 04/18] start porting logic inside ModelCheckpoint

---
 .../callbacks/model_checkpoint.py          | 14 +++++++++-
 pytorch_lightning/trainer/training_loop.py | 27 ++++++++++---------
 2 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index d41928cd55aea..b95911961a5d3 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -194,6 +194,18 @@ def on_validation_end(self, trainer, pl_module):
         """
         self.save_checkpoint(trainer, pl_module)
 
+    def on_epoch_end(self, trainer, pl_module):
+        """
+        checkpoints can be saved at the end of the val loop
+        """
+        self.save_checkpoint(trainer, pl_module)
+
+    def on_train_end(self, trainer, pl_module):
+        """
+        checkpoints can be saved at the end of the val loop
+        """
+        self.save_checkpoint(trainer, pl_module)
+
     def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
@@ -517,11 +529,11 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return
-        last_filepath = filepath
 
         # when user ALSO asked for the 'last.ckpt' change the name
        if self.save_last:
+            rank_zero_info("Saving latest checkpoint...")
             last_filepath = self._format_checkpoint_name(
                 self.CHECKPOINT_NAME_LAST,
                 trainer.current_epoch,
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 9a4f324033d39..e5da97a75ff24 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -185,11 +185,12 @@ def on_train_end(self):
 
         self._teardown_already_run = True
 
+        # TODO: clean up
         # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
         # when a checkpoint was saved at the last step
-        self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_save=True, is_last=True)
-        self.trainer.global_step += 1
+        # self.trainer.global_step -= 1
+        # self.check_checkpoint_callback(should_save=True, is_last=True)
+        # self.trainer.global_step += 1
 
         # hook
         self.trainer.call_hook("on_train_end")
@@ -210,15 +211,15 @@ def on_train_end(self):
             model = self.trainer.get_model()
             model.cpu()
             torch.cuda.empty_cache()
-
-    def check_checkpoint_callback(self, should_save, is_last=False):
-        # TODO bake this logic into the checkpoint callback
-        if should_save and self.trainer.checkpoint_connector.has_trained:
-            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
-            if is_last and any(c.save_last for c in checkpoint_callbacks):
-                rank_zero_info("Saving latest checkpoint...")
-            model = self.trainer.get_model()
-            [cb.on_validation_end(self.trainer, model) for cb in checkpoint_callbacks]
+        # # TODO: clean up
+        # def check_checkpoint_callback(self, should_save, is_last=False):
+        #     # TODO bake this logic into the checkpoint callback
+        #     if should_save and self.trainer.checkpoint_connector.has_trained:
+        #         checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+        #         if is_last and any(c.save_last for c in checkpoint_callbacks):
+        #             rank_zero_info("Saving latest checkpoint...")
+        #         model = self.trainer.get_model()
+        #         [cb.on_validation_end(self.trainer, model) for cb in checkpoint_callbacks]
 
     def on_train_epoch_start(self, epoch):
 
@@ -619,7 +620,7 @@ def run_training_epoch(self):
             )
 
             # when no val loop is present or fast-dev-run still need to call checkpoints
-            self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
+            # self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
 
             # increment the global step once
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()

From 34601980917bb2dcf0ab7b5efc360d63e6a2dfd3 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 6 Dec 2020 12:59:23 +0300
Subject: [PATCH 05/18] refactor code, some checks are still failing

---
 .../callbacks/model_checkpoint.py | 35 +++++++++++++------
 1 file changed, 25 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index b95911961a5d3..2385f42bb1754 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -196,15 +196,18 @@ def on_validation_end(self, trainer, pl_module):
 
     def on_epoch_end(self, trainer, pl_module):
         """
-        checkpoints can be saved at the end of the val loop
+        checkpoints can be saved at the end of the train loop
         """
         self.save_checkpoint(trainer, pl_module)
 
     def on_train_end(self, trainer, pl_module):
         """
-        checkpoints can be saved at the end of the val loop
+        checkpoints can be saved at the end of the epoch loop
         """
+        trainer.global_step -= 1
+        # self.check_checkpoint_callback(should_save=True, is_last=True)
         self.save_checkpoint(trainer, pl_module)
+        trainer.global_step += 1
 
     def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
@@ -218,6 +221,18 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
         self.best_model_score = checkpointed_state["best_model_score"]
         self.best_model_path = checkpointed_state["best_model_path"]
 
+    def should_save(self, trainer):
+        epoch = trainer.current_epoch
+        global_step = trainer.global_step
+        should_save = not (
+            self.save_top_k == 0  # no models are saved
+            or self.period < 1  # no models are saved
+            or (epoch + 1) % self.period  # skip epoch
+            or trainer.running_sanity_check  # don't save anything during sanity check
+            or self.last_global_step_saved == global_step  # already saved at the last step
+        )
+        return should_save
+
     def save_checkpoint(self, trainer, pl_module):
         """
         Performs the main logic around saving a checkpoint.
@@ -227,13 +242,7 @@ def save_checkpoint(self, trainer, pl_module):
         epoch = trainer.current_epoch
         global_step = trainer.global_step
 
-        if (
-            self.save_top_k == 0  # no models are saved
-            or self.period < 1  # no models are saved
-            or (epoch + 1) % self.period  # skip epoch
-            or trainer.running_sanity_check  # don't save anything during sanity check
-            or self.last_global_step_saved == global_step  # already saved at the last step
-        ):
+        if not self.should_save(trainer):
             return
 
         self._add_backward_monitor_support(trainer)
@@ -497,11 +506,17 @@ def _add_backward_monitor_support(self, trainer):
         if self.save_top_k is None and self.monitor is not None:
             self.save_top_k = 1
 
+    def _valid_monitor_key(self, trainer):
+        metrics = trainer.logger_connector.callback_metrics
+
+        # validate metric
+        return self.monitor is None or self._is_valid_monitor_key(metrics)
+
     def _validate_monitor_key(self, trainer):
         metrics = trainer.logger_connector.callback_metrics
 
         # validate metric
-        if self.monitor is not None and not self._is_valid_monitor_key(metrics):
+        if not self._valid_monitor_key(trainer):
             m = (
                 f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
                 f" {list(metrics.keys())}. "

From 05a121ecb3a130d7dbd9622fdbfaeaa3b333514a Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 6 Dec 2020 13:06:57 +0300
Subject: [PATCH 06/18] has trained condition

---
 pytorch_lightning/callbacks/model_checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 2385f42bb1754..4501f51e5b1d2 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -230,7 +230,7 @@ def should_save(self, trainer):
             or (epoch + 1) % self.period  # skip epoch
             or trainer.running_sanity_check  # don't save anything during sanity check
             or self.last_global_step_saved == global_step  # already saved at the last step
-        )
+        ) and trainer.checkpoint_connector.has_trained
         return should_save
 

From ddf6b629dc144d51440df1d332e82871a7a820fc Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 6 Dec 2020 13:39:56 +0300
Subject: [PATCH 07/18] save last flag, still strange errors

---
 pytorch_lightning/callbacks/model_checkpoint.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 4501f51e5b1d2..d65057b069709 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -205,8 +205,7 @@ def on_train_end(self, trainer, pl_module):
         checkpoints can be saved at the end of the epoch loop
         """
         trainer.global_step -= 1
-        # self.check_checkpoint_callback(should_save=True, is_last=True)
-        self.save_checkpoint(trainer, pl_module)
+        self.save_checkpoint(trainer, pl_module, is_last=True)
         trainer.global_step += 1
 
     def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
@@ -233,7 +232,7 @@ def should_save(self, trainer):
         ) and trainer.checkpoint_connector.has_trained
         return should_save
 
-    def save_checkpoint(self, trainer, pl_module):
+    def save_checkpoint(self, trainer, pl_module, is_last=True):
         """
         Performs the main logic around saving a checkpoint.
         This method runs on all ranks, it is the responsibility of `self.save_function`
@@ -242,7 +241,7 @@ def save_checkpoint(self, trainer, pl_module, is_last=True):
         epoch = trainer.current_epoch
         global_step = trainer.global_step
 
-        if not self.should_save(trainer):
+        if not (self.should_save(trainer) or is_last):
             return
 
         self._add_backward_monitor_support(trainer)

From 51db066b9718f4d8ad0c78dcc2c91f09be2344b8 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 6 Dec 2020 13:54:07 +0300
Subject: [PATCH 08/18] fix test

---
 pytorch_lightning/callbacks/model_checkpoint.py |  2 +-
 tests/checkpointing/test_model_checkpoint.py    | 11 +++++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index d65057b069709..43207a7cc83ef 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -232,7 +232,7 @@ def should_save(self, trainer):
         ) and trainer.checkpoint_connector.has_trained
         return should_save
 
-    def save_checkpoint(self, trainer, pl_module, is_last=True):
+    def save_checkpoint(self, trainer, pl_module, is_last=False):
         """
         Performs the main logic around saving a checkpoint.
         This method runs on all ranks, it is the responsibility of `self.save_function`
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 7df80a7725de7..45010e071e55b 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -543,14 +543,21 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
     model = LogInTwoMethods()
     if not should_validate:
         model.validation_step = None
+    model_checkpoint = ModelCheckpoint(
+        monitor='early_stop_on', dirpath=tmpdir,
+        save_top_k=0, save_last=save_last
+    )
     trainer = Trainer(
         default_root_dir=tmpdir,
-        callbacks=[ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir,
-                                   save_top_k=0, save_last=save_last)],
+        callbacks=[model_checkpoint],
        max_epochs=max_epochs,
     )
     trainer.fit(model)
     assert caplog.messages.count('Saving latest checkpoint...') == save_last
+    path_last = str(tmpdir / "last.ckpt")
+    if save_last:
+        assert path_last == model_checkpoint.last_model_path
+        assert os.path.isfile(path_last)
 
 
 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

From b37e51c28061ccd253c37fda52556ec24683c884 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 6 Dec 2020 14:06:14 +0300
Subject: [PATCH 09/18] clean up the code

---
 .../callbacks/model_checkpoint.py          | 18 +++++++++++++-----
 pytorch_lightning/trainer/training_loop.py | 19 -------------------
 2 files changed, 13 insertions(+), 24 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 43207a7cc83ef..bd48b9d0cc052 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -220,17 +220,25 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
         self.best_model_score = checkpointed_state["best_model_score"]
         self.best_model_path = checkpointed_state["best_model_path"]
 
-    def should_save(self, trainer):
+    def should_save(self, trainer, is_last=False):
         epoch = trainer.current_epoch
         global_step = trainer.global_step
         should_save = not (
+            # negative conditions
             self.save_top_k == 0  # no models are saved
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or trainer.running_sanity_check  # don't save anything during sanity check
-            or self.last_global_step_saved == global_step  # already saved at the last step
-        ) and trainer.checkpoint_connector.has_trained
-        return should_save
+        ) or (
+            # positive conditions
+            is_last and self.save_last  # user required to save the last model
+        )
+        # already saved at the last step
+        should_skip = (self.last_global_step_saved == global_step)
+        # used in this scenario:
+        # tests.checkpointing.test_model_checkpoint.test_checkpoint_repeated_strategy_extended
+        has_trained = trainer.checkpoint_connector.has_trained
+        return should_save and not should_skip and has_trained
 
     def save_checkpoint(self, trainer, pl_module, is_last=False):
@@ -241,7 +249,7 @@ def save_checkpoint(self, trainer, pl_module, is_last=False):
         epoch = trainer.current_epoch
         global_step = trainer.global_step
 
-        if not (self.should_save(trainer) or is_last):
+        if not self.should_save(trainer, is_last=is_last):
             return
 
         self._add_backward_monitor_support(trainer)
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index e5da97a75ff24..d7082fa12ead0 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -185,13 +185,6 @@ def on_train_end(self):
 
         self._teardown_already_run = True
 
-        # TODO: clean up
-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        # self.trainer.global_step -= 1
-        # self.check_checkpoint_callback(should_save=True, is_last=True)
-        # self.trainer.global_step += 1
-
         # hook
         self.trainer.call_hook("on_train_end")
@@ -211,15 +204,6 @@ def on_train_end(self):
             model = self.trainer.get_model()
             model.cpu()
             torch.cuda.empty_cache()
-        # # TODO: clean up
-        # def check_checkpoint_callback(self, should_save, is_last=False):
-        #     # TODO bake this logic into the checkpoint callback
-        #     if should_save and self.trainer.checkpoint_connector.has_trained:
-        #         checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
-        #         if is_last and any(c.save_last for c in checkpoint_callbacks):
-        #             rank_zero_info("Saving latest checkpoint...")
-        #         model = self.trainer.get_model()
-        #         [cb.on_validation_end(self.trainer, model) for cb in checkpoint_callbacks]
 
     def on_train_epoch_start(self, epoch):
@@ -619,9 +603,6 @@ def run_training_epoch(self):
                 self.num_optimizers
             )
 
-            # when no val loop is present or fast-dev-run still need to call checkpoints
-            # self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
-
             # increment the global step once
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()

From 653a4c2202264f087e13fe19e4ad94af988e3040 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 6 Dec 2020 14:11:23 +0300
Subject: [PATCH 10/18] reformat should_save

---
 pytorch_lightning/callbacks/model_checkpoint.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index bd48b9d0cc052..8035a589483ad 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -231,10 +231,11 @@ def should_save(self, trainer, is_last=False):
             or trainer.running_sanity_check  # don't save anything during sanity check
         ) or (
             # positive conditions
-            is_last and self.save_last  # user required to save the last model
+            is_last
+            and self.save_last  # user required to save the last model
         )
         # already saved at the last step
-        should_skip = (self.last_global_step_saved == global_step)
+        should_skip = self.last_global_step_saved == global_step
         # used in this scenario:
         # tests.checkpointing.test_model_checkpoint.test_checkpoint_repeated_strategy_extended
         has_trained = trainer.checkpoint_connector.has_trained

From 3a4be6d6460b62fd0c5936ecbc8b4de70999db14 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 6 Dec 2020 14:13:57 +0300
Subject: [PATCH 11/18] one more call to save checkpoint if fractional frequency

---
 tests/checkpointing/test_checkpoint_callback_frequency.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py
index 0662cf7677431..6c4847297f28a 100644
--- a/tests/checkpointing/test_checkpoint_callback_frequency.py
+++ b/tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -93,7 +93,7 @@ def test_mc_called(tmpdir):
 
 @mock.patch('torch.save')
 @pytest.mark.parametrize(['epochs', 'val_check_interval', 'expected'],
-                         [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)])
+                         [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 8)])
 def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval, expected):
 
     model = BoringModel()
@@ -111,7 +111,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval,
 
 @mock.patch('torch.save')
 @pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'],
-                         [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 7)])
+                         [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 8)])
 def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected):
 
     class TestModel(BoringModel):

From 8d01da8236144cdd2f850415bcbdfbb8d007a59b Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 20 Dec 2020 10:57:10 +0300
Subject: [PATCH 12/18] fix W503

---
 pytorch_lightning/callbacks/model_checkpoint.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 8035a589483ad..d30ea0c938422 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -225,14 +225,14 @@ def should_save(self, trainer, is_last=False):
         global_step = trainer.global_step
         should_save = not (
             # negative conditions
-            self.save_top_k == 0  # no models are saved
-            or self.period < 1  # no models are saved
-            or (epoch + 1) % self.period  # skip epoch
-            or trainer.running_sanity_check  # don't save anything during sanity check
+            self.save_top_k == 0 or
+            self.period < 1 or
+            (epoch + 1) % self.period or
+            # don't save anything during sanity check
+            trainer.running_sanity_check
         ) or (
             # positive conditions
-            is_last
-            and self.save_last  # user required to save the last model
+            is_last and self.save_last  # user required to save the last model
         )
         # already saved at the last step
         should_skip = self.last_global_step_saved == global_step
From 8550b5cde78c4d862fe33bc6421c557e8b6acb26 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 20 Dec 2020 11:17:31 +0300
Subject: [PATCH 13/18] add missing condition to the trainer test

---
 tests/trainer/test_trainer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 328b2c0a0f859..782f7d5cf08cc 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -440,6 +440,8 @@ def mock_save_function(filepath, *args):
         trainer.current_epoch = i
         trainer.global_step = i
         trainer.logger_connector.callback_metrics = {"checkpoint_on": torch.tensor(loss)}
+        # after forward-backward `has_trained` is set, this condition is also checked
+        trainer.checkpoint_connector.has_trained = True
         checkpoint_callback.on_validation_end(trainer, trainer.get_model())
 
     file_lists = set(os.listdir(tmpdir))

From 92d3b4a9e8eeb47c074ad18fb902719ebaaa95cf Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 20 Dec 2020 11:48:33 +0300
Subject: [PATCH 14/18] a better comment

---
 pytorch_lightning/callbacks/model_checkpoint.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 33d26b2400a0c..7a973dce1aadc 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -245,8 +245,7 @@ def should_save(self, trainer, is_last=False):
         )
         # already saved at the last step
         should_skip = self.last_global_step_saved == global_step
-        # used in this scenario:
-        # tests.checkpointing.test_model_checkpoint.test_checkpoint_repeated_strategy_extended
+        # it is true after forward-backward pass
         has_trained = trainer.checkpoint_connector.has_trained
         return should_save and not should_skip and has_trained
 

From bb47a220faa1afcc400df0ade0ee93ea21be8fe9 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 20 Dec 2020 12:04:32 +0300
Subject: [PATCH 15/18] zero rank save

---
 pytorch_lightning/callbacks/model_checkpoint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 7a973dce1aadc..76ac88a8b24d8 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -376,13 +376,13 @@ def _del_model(self, filepath: str):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")
 
+    @rank_zero_only
     def _save_model(self, filepath: str, trainer, pl_module):
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
         # make paths
-        if trainer.is_global_zero:
-            self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
+        self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
 
         # delegate the saving to the trainer
         if self.save_function is not None:

From b3e4c5aa52dbe313de8867b42f9f671020cdb8a2 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sun, 20 Dec 2020 17:49:33 +0300
Subject: [PATCH 16/18] do not use rank_zero

_save_model may have side effects in on_save_checkpoint
---
 pytorch_lightning/callbacks/model_checkpoint.py | 5 ++---
 tests/checkpointing/test_model_checkpoint.py    | 6 +++---
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 76ac88a8b24d8..bdfc93e2128f5 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -255,7 +255,6 @@ def save_checkpoint(self, trainer, pl_module, is_last=False):
         This method runs on all ranks, it is the responsibility of `self.save_function`
         to handle correct behaviour in distributed training, i.e., saving only on rank 0.
         """
-        epoch = trainer.current_epoch
         global_step = trainer.global_step
 
         if not self.should_save(trainer, is_last=is_last):
             return
@@ -376,13 +375,13 @@ def _del_model(self, filepath: str):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")
 
-    @rank_zero_only
     def _save_model(self, filepath: str, trainer, pl_module):
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
         # make paths
-        self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if trainer.is_global_zero:
+            self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
 
         # delegate the saving to the trainer
         if self.save_function is not None:
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index fa340932f0feb..677f7b04f9b07 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -167,11 +167,11 @@ def on_train_end(self, trainer, pl_module):
         super().on_train_end(trainer, pl_module)
         assert self.best_model_path
         assert self.best_model_score
-        assert self.on_save_checkpoint_count == self.expected_count
+        assert self.on_save_checkpoint_count == self.expected_count, (self.on_save_checkpoint_count, self.expected_count)
         if trainer.is_global_zero:
-            assert torch.save.call_count == self.expected_count
+            assert torch.save.call_count == self.expected_count, (torch.save.call_count, self.expected_count)
         else:
-            assert torch.save.call_count == 0
+            assert torch.save.call_count == 0, torch.save.call_count
 
 
 @pytest.mark.skipif(

From 02254269df34a03f11ab1598dafe05a0d3277ce8 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Mon, 21 Dec 2020 10:33:27 +0300
Subject: [PATCH 17/18] black a piece of code

---
 pytorch_lightning/callbacks/model_checkpoint.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index bdfc93e2128f5..490b98375cba0 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -234,14 +234,14 @@ def should_save(self, trainer, is_last=False):
         global_step = trainer.global_step
         should_save = not (
             # negative conditions
-            self.save_top_k == 0 or
-            self.period < 1 or
-            (epoch + 1) % self.period or
-            # don't save anything during sanity check
-            trainer.running_sanity_check
+            self.save_top_k == 0
+            or self.period < 1
+            or (epoch + 1) % self.period
+            or trainer.running_sanity_check
         ) or (
             # positive conditions
-            is_last and self.save_last  # user required to save the last model
+            is_last
+            and self.save_last  # user required to save the last model
         )
         # already saved at the last step
         should_skip = self.last_global_step_saved == global_step

From 0540544f356f69e4dcf93bb6b37be2dba5536436 Mon Sep 17 00:00:00 2001
From: Maxim Kochurov
Date: Sat, 16 Jan 2021 12:05:44 +0300
Subject: [PATCH 18/18] info about saving latest checkpoint only in on_train_end callback stage

---
 pytorch_lightning/callbacks/model_checkpoint.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index a3c4c4c9649ac..6f2a8219f8921 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -277,7 +277,7 @@ def save_checkpoint(self, trainer, pl_module, is_last=False):
         self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)
 
         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, monitor_candidates)
+        self._save_last_checkpoint(trainer, pl_module, monitor_candidates, is_last=is_last)
 
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
@@ -571,14 +571,15 @@ def _monitor_candidates(self, trainer):
         ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
         return ckpt_name_metrics
 
-    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
+    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, is_last=False):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return
 
         # when user ALSO asked for the 'last.ckpt' change the name
         if self.save_last:
-            rank_zero_info("Saving latest checkpoint...")
+            if is_last:
+                rank_zero_info("Saving latest checkpoint...")
             last_filepath = self._format_checkpoint_name(
                 self.CHECKPOINT_NAME_LAST,
                 trainer.current_epoch,