From 3043259c8135e8ee12500d0c1b8b7acb22ab6e82 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Jul 2021 01:49:05 +0200 Subject: [PATCH 1/5] Add `ModelCheckpoint(save_on_train_epoch_end)` --- .../callbacks/model_checkpoint.py | 25 ++++++++++++++++++- tests/callbacks/test_early_stopping.py | 4 +-- tests/callbacks/test_pruning.py | 25 +++++++++++++------ tests/checkpointing/test_model_checkpoint.py | 4 ++- tests/core/test_metric_result_integration.py | 2 +- tests/models/test_hooks.py | 10 ++++---- tests/trainer/optimization/test_optimizers.py | 9 ++++--- 7 files changed, 57 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4e006d81ce3b9..072306bd9a3f7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -122,6 +122,8 @@ class ModelCheckpoint(Callback): ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` will only save checkpoints at epochs 0 < E <= N where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E. + save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch. + If this is ``False``, then the check runs at the end of the validation. period: Interval (number of epochs) between checkpoints. .. warning:: @@ -213,6 +215,7 @@ def __init__( every_n_train_steps: Optional[int] = None, train_time_interval: Optional[timedelta] = None, every_n_epochs: Optional[int] = None, + save_on_train_epoch_end: Optional[bool] = None, period: Optional[int] = None, every_n_val_epochs: Optional[int] = None, ): @@ -223,6 +226,7 @@ def __init__( self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.auto_insert_metric_name = auto_insert_metric_name + self._save_on_train_epoch_end = save_on_train_epoch_end self._last_global_step_saved = -1 self._last_time_checked: Optional[float] = None self.current_score = None @@ -251,6 +255,10 @@ def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn """ self.__resolve_ckpt_dir(trainer) self._save_function = trainer.save_checkpoint + if self._save_on_train_epoch_end is None: + # if the user runs validation before multiple times per training epoch, we try to save checkpoint after + # validation instead of on train epoch end + self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._last_time_checked = time.monotonic() @@ -287,10 +295,25 @@ def on_train_batch_end( self.save_checkpoint(trainer) + def on_train_epoch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', unused: Optional = None + ) -> None: + """ Save a checkpoint at the end of the training epoch. """ + # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates + trainer.train_loop.global_step -= 1 + if ( + self._should_skip_saving_checkpoint(trainer) or not self._save_on_train_epoch_end + or self._every_n_epochs < 1 or (trainer.current_epoch + 1) % self._every_n_epochs != 0 + ): + trainer.train_loop.global_step += 1 + return + self.save_checkpoint(trainer) + trainer.train_loop.global_step += 1 + def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """ Save a checkpoint at the end of the validation stage. 
""" if ( - self._should_skip_saving_checkpoint(trainer) or self._every_n_epochs < 1 + self._should_skip_saving_checkpoint(trainer) or self._save_on_train_epoch_end or self._every_n_epochs < 1 or (trainer.current_epoch + 1) % self._every_n_epochs != 0 ): return diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 56252208a6f13..1582a8ed90c91 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -45,8 +45,8 @@ def on_train_start(self, trainer, pl_module): if self.expected_state: assert self.on_save_checkpoint(trainer, pl_module, {}) == self.expected_state - def on_validation_end(self, trainer, pl_module): - super().on_validation_end(trainer, pl_module) + def on_train_epoch_end(self, trainer, pl_module): + super().on_train_epoch_end(trainer, pl_module) self.saved_states.append(self.on_save_checkpoint(trainer, pl_module, {}).copy()) diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 1a5ddad64106e..2dec29e819e4b 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -304,22 +304,29 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool assert not has_pruning if make_pruning_permanent else has_pruning -@pytest.mark.parametrize("on_train_epoch_end", (False, True)) -def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog, on_train_epoch_end): +@pytest.mark.parametrize("prune_on_train_epoch_end", (False, True)) +@pytest.mark.parametrize("save_on_train_epoch_end", (False, True)) +def test_permanent_when_model_is_saved_multiple_times( + tmpdir, caplog, prune_on_train_epoch_end, save_on_train_epoch_end +): """ When a model is saved multiple times and make_permanent=True, we need to make sure a copy is pruned and not the trained model if we want to continue with the same pruning buffers. 
""" + if prune_on_train_epoch_end and save_on_train_epoch_end: + pytest.xfail( + "Pruning sets the `grad_fn` of the parameters so we can't save" + " right after as pruning has not been made permanent" + ) class TestPruning(ModelPruning): def on_save_checkpoint(self, trainer, pl_module, checkpoint): + had_buffers = hasattr(pl_module.layer.mlp_3, "weight_orig") super().on_save_checkpoint(trainer, pl_module, checkpoint) - if not on_train_epoch_end: - # these checks only work if pruning on `validation_epoch_end` - # because `on_save_checkpoint` is called before `on_train_epoch_end` - assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"] + assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"] + if had_buffers: assert hasattr(pl_module.layer.mlp_3, "weight_orig") model = TestModel() @@ -328,9 +335,11 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint): parameters_to_prune=[(model.layer.mlp_3, "weight")], verbose=1, make_pruning_permanent=True, - prune_on_train_epoch_end=on_train_epoch_end, + prune_on_train_epoch_end=prune_on_train_epoch_end, + ) + ckpt_callback = ModelCheckpoint( + monitor="test", save_top_k=2, save_last=True, save_on_train_epoch_end=save_on_train_epoch_end ) - ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True) trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0) with caplog.at_level(INFO): trainer.fit(model) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index a9e8f0578557a..173059ad33bcd 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -916,7 +916,9 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): ckpt_last_epoch = torch.load(path_last_epoch) ckpt_last = torch.load(path_last) - assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step")) + + assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] + assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] ch_type = type(model_checkpoint) assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type] diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 7471914886a27..86cfa35746cda 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -329,7 +329,7 @@ def on_save_checkpoint(self, checkpoint) -> None: assert new_results['validation_step.v'].value.device.type == 'cpu' model = LoggingModel() - ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True) + ckpt = ModelCheckpoint(dirpath=tmpdir, save_on_train_epoch_end=False) trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index d89fc090c401f..7df871aaf5fd2 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -550,14 +550,14 @@ def training_step(self, batch, batch_idx): dict(name='on_validation_start'), *model._eval_epoch('validation', trainer, model, val_batches, 'x', device=device), dict(name='Callback.on_validation_end', args=(trainer, model)), - # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end` - dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)), - dict(name='on_save_checkpoint', args=(saved_ckpt, )), dict(name='on_validation_end'), dict(name='train', args=(True, )), dict(name='on_validation_model_train'), 
dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )), dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)), + # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_train_epoch_end` + dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)), + dict(name='on_save_checkpoint', args=(saved_ckpt, )), dict(name='on_train_epoch_end', args=([dict(loss=ANY)] * train_batches, )), dict(name='Callback.on_epoch_end', args=(trainer, model)), dict(name='on_epoch_end'), @@ -662,11 +662,11 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): model, [dict(loss=ANY)] * train_batches, )), + dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)), + dict(name='on_save_checkpoint', args=(saved_ckpt, )), dict(name='on_train_epoch_end', args=([dict(loss=ANY)] * train_batches, )), dict(name='Callback.on_epoch_end', args=(trainer, model)), dict(name='on_epoch_end'), - dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)), - dict(name='on_save_checkpoint', args=(saved_ckpt, )), dict(name='Callback.on_train_end', args=(trainer, model)), dict(name='on_train_end'), dict(name='Callback.on_fit_end', args=(trainer, model)), diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 6165aa132153b..faf5434d6ba5a 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -662,7 +662,8 @@ def on_save_checkpoint(self, checkpoint): assert model.on_save_checkpoint_called -def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir): +@pytest.mark.parametrize("save_on_train_epoch_end", (False, True)) +def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir, save_on_train_epoch_end): batches = 4 trainer = Trainer( default_root_dir=tmpdir, @@ -671,7 +672,7 @@ def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir): max_epochs=1, limit_train_batches=batches, limit_val_batches=1, - callbacks=[ModelCheckpoint(dirpath=tmpdir)] + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_on_train_epoch_end=save_on_train_epoch_end)] ) class TestModel(BoringModel): @@ -693,8 +694,8 @@ def configure_optimizers(self): def on_save_checkpoint(self, checkpoint): lr_dict_1 = checkpoint['lr_schedulers'][0] - # since plateau schedulers are updated after saving checkpoint, last_epoch should be 3 - assert lr_dict_1['last_epoch'] == batches - 1 # last epoch starts at 0 + last_epoch = lr_dict_1['last_epoch'] + assert last_epoch == batches - (not save_on_train_epoch_end) # last epoch starts at 0 lr_dict_2 = checkpoint['lr_schedulers'][1] assert lr_dict_2['_step_count'] - 1 == batches # step count starts at 1 From 5fe6ad8ef5f8beaa58a7ee9bfd714142c4b67bcb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Jul 2021 02:03:30 +0200 Subject: [PATCH 2/5] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b35c08bdb198..353ddb513533e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -152,6 +152,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added support for `save_hyperparameters` in `LightningDataModule` ([#3792](https://github.com/PyTorchLightning/pytorch-lightning/pull/3792)) +- Added the `ModelCheckpoint(save_on_train_epoch_end)` to choose when to run the saving logic ([#8389](https://github.com/PyTorchLightning/pytorch-lightning/pull/8389)) + + - Added `LSFEnvironment` for distributed training with the LSF resource manager `jsrun` ([#5102](https://github.com/PyTorchLightning/pytorch-lightning/pull/5102)) From 38d3089185c6197debc34abdcd9d38dccd72d7ed Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Jul 2021 02:09:11 +0200 Subject: [PATCH 3/5] Invert if --- pytorch_lightning/callbacks/model_checkpoint.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 072306bd9a3f7..4ce7367637e0c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -302,12 +302,10 @@ def on_train_epoch_end( # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates trainer.train_loop.global_step -= 1 if ( - self._should_skip_saving_checkpoint(trainer) or not self._save_on_train_epoch_end - or self._every_n_epochs < 1 or (trainer.current_epoch + 1) % self._every_n_epochs != 0 + not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end + and self._every_n_epochs > 0 and (trainer.current_epoch + 1) % self._every_n_epochs == 0 ): - trainer.train_loop.global_step += 1 - return - self.save_checkpoint(trainer) + self.save_checkpoint(trainer) trainer.train_loop.global_step += 1 def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: From f04748289be4cef9b01bfcf9b4ac9022cbce4dea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 13 Jul 2021 16:13:19 +0200 Subject: [PATCH 4/5] Update pytorch_lightning/callbacks/model_checkpoint.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4ce7367637e0c..2186285fb67a4 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -256,7 +256,7 @@ def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn self.__resolve_ckpt_dir(trainer) self._save_function = trainer.save_checkpoint if self._save_on_train_epoch_end is None: - # if the user runs validation before multiple times per training epoch, we try to save checkpoint after + # if the user runs validation multiple times per training epoch, we try to save checkpoint after # validation instead of on train epoch end self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 From 277525ad709b999a0528293c6de0b9ffd2718368 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Jul 2021 16:15:23 +0200 Subject: [PATCH 5/5] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 353ddb513533e..e4d9261abb71e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -185,6 +185,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357)) +- `ModelCheckpoint` now runs at the end of the training epoch by default ([#8389](https://github.com/PyTorchLightning/pytorch-lightning/pull/8389)) + + - `EarlyStopping` now runs at the end of the training epoch by default ([#8286](https://github.com/PyTorchLightning/pytorch-lightning/pull/8286))
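
Usage sketch for the argument added in this series (illustrative only — the `dirpath`, the monitored metric name, and the commented-out `model` are assumptions, not values taken from these patches): with the default `save_on_train_epoch_end=None`, the resolution added in `on_pretrain_routine_start` derives the flag from `trainer.val_check_interval == 1.0`, so checkpointing moves to the end of validation whenever validation runs more than once per training epoch.

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Explicitly defer saving to the end of validation, matching what the
    # automatic resolution would pick for the Trainer configured below.
    checkpoint_cb = ModelCheckpoint(
        dirpath="checkpoints/",        # assumed output directory
        monitor="val_loss",            # assumed logged metric
        save_on_train_epoch_end=False,
    )

    trainer = pl.Trainer(
        max_epochs=3,
        val_check_interval=0.5,        # validation runs twice per training epoch
        callbacks=[checkpoint_cb],
    )
    # trainer.fit(model)  # `model` is any LightningModule that logs "val_loss"

Leaving the argument at its default (`None`) would give the same behaviour here, since `val_check_interval != 1.0`; passing `save_on_train_epoch_end=True` instead would move the save back into `on_train_epoch_end`.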