Merged
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -157,6 +157,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `save_hyperparameters` in `LightningDataModule` ([#3792](https://github.com/PyTorchLightning/pytorch-lightning/pull/3792))


- Added the `ModelCheckpoint(save_on_train_epoch_end)` to choose when to run the saving logic ([#8389](https://github.com/PyTorchLightning/pytorch-lightning/pull/8389))


- Added `LSFEnvironment` for distributed training with the LSF resource manager `jsrun` ([#5102](https://github.com/PyTorchLightning/pytorch-lightning/pull/5102))


@@ -187,6 +190,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357))


- `ModelCheckpoint` now runs at the end of the training epoch by default ([#8389](https://github.com/PyTorchLightning/pytorch-lightning/pull/8389))


- `EarlyStopping` now runs at the end of the training epoch by default ([#8286](https://github.com/PyTorchLightning/pytorch-lightning/pull/8286))


23 changes: 22 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -122,6 +122,8 @@ class ModelCheckpoint(Callback):
``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
will only save checkpoints at epochs 0 < E <= N
where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
If this is ``False``, then the check runs at the end of the validation.
period: Interval (number of epochs) between checkpoints.

.. warning::
@@ -213,6 +215,7 @@ def __init__(
every_n_train_steps: Optional[int] = None,
train_time_interval: Optional[timedelta] = None,
every_n_epochs: Optional[int] = None,
save_on_train_epoch_end: Optional[bool] = None,
period: Optional[int] = None,
every_n_val_epochs: Optional[int] = None,
):
@@ -223,6 +226,7 @@ def __init__(
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.auto_insert_metric_name = auto_insert_metric_name
self._save_on_train_epoch_end = save_on_train_epoch_end
self._last_global_step_saved = -1
self._last_time_checked: Optional[float] = None
self.current_score = None
@@ -251,6 +255,10 @@ def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn
"""
self.__resolve_ckpt_dir(trainer)
self._save_function = trainer.save_checkpoint
if self._save_on_train_epoch_end is None:
# if the user runs validation multiple times per training epoch, we try to save checkpoint after
# validation instead of on train epoch end
self._save_on_train_epoch_end = trainer.val_check_interval == 1.0

def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
self._last_time_checked = time.monotonic()
@@ -287,10 +295,23 @@ def on_train_batch_end(

self.save_checkpoint(trainer)

def on_train_epoch_end(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', unused: Optional = None
) -> None:
""" Save a checkpoint at the end of the training epoch. """
# as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
trainer.train_loop.global_step -= 1
if (
not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end
and self._every_n_epochs > 0 and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
self.save_checkpoint(trainer)
trainer.train_loop.global_step += 1

def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
""" Save a checkpoint at the end of the validation stage. """
if (
self._should_skip_saving_checkpoint(trainer) or self._every_n_epochs < 1
self._should_skip_saving_checkpoint(trainer) or self._save_on_train_epoch_end or self._every_n_epochs < 1
or (trainer.current_epoch + 1) % self._every_n_epochs != 0
):
return
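To make the new option concrete, here is a minimal usage sketch of the flag introduced in this file; the directory path and `Trainer` settings below are illustrative, not part of the PR:

```python
# Minimal usage sketch of the new `save_on_train_epoch_end` flag (paths/settings illustrative).
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpoint explicitly at the end of every training epoch.
ckpt_train = ModelCheckpoint(dirpath="checkpoints/", save_on_train_epoch_end=True)

# Or defer checkpointing to the end of validation instead.
ckpt_val = ModelCheckpoint(dirpath="checkpoints/", save_on_train_epoch_end=False)

# Left as None (the default), the callback resolves the flag in `on_pretrain_routine_start`:
# it saves after validation when validation runs more than once per training epoch
# (val_check_interval != 1.0), and at the end of the training epoch otherwise.
trainer = Trainer(
    max_epochs=3,
    val_check_interval=0.5,
    callbacks=[ModelCheckpoint(dirpath="checkpoints/")],
)
```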
4 changes: 2 additions & 2 deletions tests/callbacks/test_early_stopping.py
@@ -45,8 +45,8 @@ def on_train_start(self, trainer, pl_module):
if self.expected_state:
assert self.on_save_checkpoint(trainer, pl_module, {}) == self.expected_state

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
def on_train_epoch_end(self, trainer, pl_module):
super().on_train_epoch_end(trainer, pl_module)
self.saved_states.append(self.on_save_checkpoint(trainer, pl_module, {}).copy())


25 changes: 17 additions & 8 deletions tests/callbacks/test_pruning.py
@@ -304,22 +304,29 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
assert not has_pruning if make_pruning_permanent else has_pruning


@pytest.mark.parametrize("on_train_epoch_end", (False, True))
def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog, on_train_epoch_end):
@pytest.mark.parametrize("prune_on_train_epoch_end", (False, True))
@pytest.mark.parametrize("save_on_train_epoch_end", (False, True))
def test_permanent_when_model_is_saved_multiple_times(
tmpdir, caplog, prune_on_train_epoch_end, save_on_train_epoch_end
):
"""
When a model is saved multiple times and make_permanent=True, we need to
make sure a copy is pruned and not the trained model if we want to continue
with the same pruning buffers.
"""
if prune_on_train_epoch_end and save_on_train_epoch_end:
pytest.xfail(
"Pruning sets the `grad_fn` of the parameters so we can't save"
" right after as pruning has not been made permanent"
)

class TestPruning(ModelPruning):

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
had_buffers = hasattr(pl_module.layer.mlp_3, "weight_orig")
super().on_save_checkpoint(trainer, pl_module, checkpoint)
if not on_train_epoch_end:
# these checks only work if pruning on `validation_epoch_end`
# because `on_save_checkpoint` is called before `on_train_epoch_end`
assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
if had_buffers:
assert hasattr(pl_module.layer.mlp_3, "weight_orig")

model = TestModel()
@@ -328,9 +335,11 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
parameters_to_prune=[(model.layer.mlp_3, "weight")],
verbose=1,
make_pruning_permanent=True,
prune_on_train_epoch_end=on_train_epoch_end,
prune_on_train_epoch_end=prune_on_train_epoch_end,
)
ckpt_callback = ModelCheckpoint(
monitor="test", save_top_k=2, save_last=True, save_on_train_epoch_end=save_on_train_epoch_end
)
ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True)
trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0)
with caplog.at_level(INFO):
trainer.fit(model)
4 changes: 3 additions & 1 deletion tests/checkpointing/test_model_checkpoint.py
@@ -916,7 +916,9 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

ckpt_last_epoch = torch.load(path_last_epoch)
ckpt_last = torch.load(path_last)
assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))

assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]

ch_type = type(model_checkpoint)
assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]
2 changes: 1 addition & 1 deletion tests/core/test_metric_result_integration.py
@@ -329,7 +329,7 @@ def on_save_checkpoint(self, checkpoint) -> None:
assert new_results['validation_step.v'].value.device.type == 'cpu'

model = LoggingModel()
ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
ckpt = ModelCheckpoint(dirpath=tmpdir, save_on_train_epoch_end=False)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
10 changes: 5 additions & 5 deletions tests/models/test_hooks.py
@@ -550,14 +550,14 @@ def training_step(self, batch, batch_idx):
dict(name='on_validation_start'),
*model._eval_epoch('validation', trainer, model, val_batches, 'x', device=device),
dict(name='Callback.on_validation_end', args=(trainer, model)),
# `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
dict(name='on_save_checkpoint', args=(saved_ckpt, )),
dict(name='on_validation_end'),
dict(name='train', args=(True, )),
dict(name='on_validation_model_train'),
dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)),
# `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_train_epoch_end`
dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
dict(name='on_save_checkpoint', args=(saved_ckpt, )),
dict(name='on_train_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
dict(name='Callback.on_epoch_end', args=(trainer, model)),
dict(name='on_epoch_end'),
@@ -662,11 +662,11 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
model,
[dict(loss=ANY)] * train_batches,
)),
dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
dict(name='on_save_checkpoint', args=(saved_ckpt, )),
dict(name='on_train_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
dict(name='Callback.on_epoch_end', args=(trainer, model)),
dict(name='on_epoch_end'),
dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
dict(name='on_save_checkpoint', args=(saved_ckpt, )),
dict(name='Callback.on_train_end', args=(trainer, model)),
dict(name='on_train_end'),
dict(name='Callback.on_fit_end', args=(trainer, model)),
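The updated expectations above reflect that `ModelCheckpoint.save_checkpoint` is now invoked from `Callback.on_train_epoch_end` rather than `Callback.on_validation_end` by default. A small, hypothetical callback (sketched against the 1.4-era hook signatures used in this diff) that prints the relevant hooks makes the new ordering easy to observe:

```python
from pytorch_lightning.callbacks import Callback


class HookOrderLogger(Callback):
    """Print the hooks around checkpointing so the new default ordering is visible."""

    def on_validation_end(self, trainer, pl_module):
        print("on_validation_end")

    def on_train_epoch_end(self, trainer, pl_module, unused=None):
        print("on_train_epoch_end")

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # With `save_on_train_epoch_end` resolving to True, this fires during
        # the train-epoch-end phase instead of right after validation end.
        print("on_save_checkpoint")
```

Passing `HookOrderLogger()` in `Trainer(callbacks=[...])` alongside a default `ModelCheckpoint` would show `on_save_checkpoint` triggered during the train-epoch-end phase rather than immediately after `on_validation_end`.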
9 changes: 5 additions & 4 deletions tests/trainer/optimization/test_optimizers.py
@@ -662,7 +662,8 @@ def on_save_checkpoint(self, checkpoint):
assert model.on_save_checkpoint_called


def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir):
@pytest.mark.parametrize("save_on_train_epoch_end", (False, True))
def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir, save_on_train_epoch_end):
batches = 4
trainer = Trainer(
default_root_dir=tmpdir,
@@ -671,7 +672,7 @@ def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir):
max_epochs=1,
limit_train_batches=batches,
limit_val_batches=1,
callbacks=[ModelCheckpoint(dirpath=tmpdir)]
callbacks=[ModelCheckpoint(dirpath=tmpdir, save_on_train_epoch_end=save_on_train_epoch_end)]
)

class TestModel(BoringModel):
@@ -693,8 +694,8 @@ def configure_optimizers(self):

def on_save_checkpoint(self, checkpoint):
lr_dict_1 = checkpoint['lr_schedulers'][0]
# since plateau schedulers are updated after saving checkpoint, last_epoch should be 3
assert lr_dict_1['last_epoch'] == batches - 1 # last epoch starts at 0
last_epoch = lr_dict_1['last_epoch']
assert last_epoch == batches - (not save_on_train_epoch_end) # last epoch starts at 0

lr_dict_2 = checkpoint['lr_schedulers'][1]
assert lr_dict_2['_step_count'] - 1 == batches # step count starts at 1
Expand Down