Merged

37 commits
f26c9ab  PoC (carmocca, May 26, 2021)
1806272  Merge branch 'master' into refactor/remove-check-ckpt-callback (carmocca, Jul 4, 2021)
e2acb78  Update code to new loops (carmocca, Jul 4, 2021)
7b348db  Pass through function (carmocca, Jul 4, 2021)
db2a6e5  Update after loop refactor (carmocca, Jul 4, 2021)
fdd4a56  Merge branch 'refactor/remove-check-ckpt-callback' of https://github.… (carmocca, Jul 4, 2021)
cbc1136  Fix test (carmocca, Jul 4, 2021)
56e9d89  Fix test (carmocca, Jul 4, 2021)
bbac98b  Fix tests (carmocca, Jul 4, 2021)
a0afd13  Fix test (carmocca, Jul 4, 2021)
5241864  Fix test (carmocca, Jul 4, 2021)
9f5d886  Remove debug statement (carmocca, Jul 5, 2021)
45156ee  Fix test (carmocca, Jul 5, 2021)
bf3c483  Merge branch 'master' into refactor/remove-check-ckpt-callback (carmocca, Jul 5, 2021)
3da369f  Merge branch 'master' into refactor/remove-check-ckpt-callback (carmocca, Jul 7, 2021)
76c6be7  Docs and deprecation (carmocca, Jul 7, 2021)
f9ee8b8  fix test (carmocca, Jul 7, 2021)
6e1ecf6  Merge branch 'master' into refactor/remove-check-ckpt-callback (carmocca, Jul 12, 2021)
420a7cd  Merge branch 'master' into refactor/remove-check-ckpt-callback (carmocca, Jul 13, 2021)
15a8575  Docs (carmocca, Jul 13, 2021)
e14a80d  Update pytorch_lightning/callbacks/model_checkpoint.py (carmocca, Jul 14, 2021)
6a0f13c  Parametrize with save last (carmocca, Jul 14, 2021)
206eefc  Fix ddp test (carmocca, Jul 14, 2021)
2380228  Fix pre-commit (carmocca, Jul 14, 2021)
d613729  Merge branch 'master' into refactor/remove-check-ckpt-callback (carmocca, Jul 14, 2021)
e830627  Merge branch 'master' into refactor/remove-check-ckpt-callback (awaelchli, Jul 15, 2021)
7b104ee  Merge branch 'master' into refactor/remove-check-ckpt-callback (carmocca, Jul 15, 2021)
b709a8f  Avoid file not found (carmocca, Jul 16, 2021)
5fcd3d7  Debug (carmocca, Jul 18, 2021)
8d978cc  Increase SHM size (carmocca, Jul 18, 2021)
d9118c5  Debug (carmocca, Jul 19, 2021)
b3748c4  Refactor MNIST imports (carmocca, Jul 19, 2021)
45b0d51  Undo debugging (carmocca, Jul 19, 2021)
bdae378  Prints (carmocca, Jul 19, 2021)
4df2ac2  Revert "Avoid file not found" (carmocca, Jul 19, 2021)
e41de44  Merge branch 'ci/debug-deepspeed-nccl-error' into refactor/remove-che… (carmocca, Jul 19, 2021)
6ad0a5e  Merge branch 'master' into refactor/remove-check-ckpt-callback (carmocca, Jul 19, 2021)
17 changes: 17 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -317,6 +317,23 @@ def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModul
return
self.save_checkpoint(trainer)

def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""
Save a checkpoint when training stops.

This will only save a checkpoint if `save_last` is also enabled, since the monitor metrics logged during
training/validation steps or at the end of epochs are not guaranteed to be available at this stage.
"""
if self._should_skip_saving_checkpoint(trainer) or not self.save_last:
return
if self.verbose:
rank_zero_info("Saving latest checkpoint...")
# as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1)
trainer.train_loop.global_step -= 1
self._save_last_checkpoint(trainer, monitor_candidates)
trainer.train_loop.global_step += 1

def on_save_checkpoint(
self,
trainer: 'pl.Trainer',
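Usage sketch (not part of the diff): with this change, the final "last" checkpoint is written by `ModelCheckpoint.on_train_end` itself rather than by the fit loop, and only when `save_last=True`. The snippet below illustrates the intended usage under stated assumptions; `LitModel` and the "val_loss" key are placeholders for any `LightningModule` and monitored metric.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# `LitModel` is a hypothetical LightningModule that logs "val_loss" during validation.
model = LitModel()

checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",
    monitor="val_loss",   # top-k checkpoints still follow the monitored metric
    save_top_k=2,
    save_last=True,       # enables the extra checkpoint written in `on_train_end`
    verbose=True,         # logs "Saving latest checkpoint..." when training stops
)

trainer = pl.Trainer(max_epochs=2, callbacks=[checkpoint_cb])
trainer.fit(model)        # the "last" checkpoint is written once more when training ends

Without `save_last=True`, `on_train_end` returns early, since the monitored metrics are not guaranteed to be available at that point.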
30 changes: 0 additions & 30 deletions pytorch_lightning/loops/fit_loop.py
@@ -22,7 +22,6 @@
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_info

log = logging.getLogger(__name__)

@@ -227,14 +226,6 @@ def advance(self) -> None:
self.global_step += 1

def on_advance_end(self) -> None:
"""Updates the LR schedulers and does some internal bookkeeping"""
if self.epoch_loop.batches_seen != 0:
did_train_only = not self.trainer.enable_validation or self.epoch_loop.val_loop.skip
if did_train_only:
self.global_step -= 1
self._check_checkpoint_callback(True)
self.global_step += 1

self.epoch_progress.increment_completed()

def on_run_end(self) -> None:
@@ -245,13 +236,6 @@ def on_run_end(self) -> None:
# TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
self.current_epoch -= 1

# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
# when a checkpoint was saved at the last step
self.epoch_loop.global_step -= 1
# TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406
self._check_checkpoint_callback(should_update=True, is_last=True)
self.epoch_loop.global_step += 1

# hook
self.trainer.call_hook("on_train_end")

@@ -271,19 +255,5 @@ def should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated"""
return self.epoch_loop.batch_loop.should_accumulate()

def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False):
"""Checks if checkpointing needs to be done"""
# TODO: bake this logic into the ModelCheckpoint callback
if should_update:
callbacks = self.trainer.checkpoint_callbacks

if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
rank_zero_info("Saving latest checkpoint...")

model = self.trainer.lightning_module

for cb in callbacks:
cb.on_validation_end(self.trainer, model)

def teardown(self) -> None:
self.epoch_loop.teardown()
15 changes: 9 additions & 6 deletions tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -50,7 +50,7 @@ def test_mc_called(tmpdir):
@mock.patch('torch.save')
@pytest.mark.parametrize(
['epochs', 'val_check_interval', 'expected'],
[(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)],
[(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 6)],
)
def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval: float, expected: int):

@@ -74,9 +74,10 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter
(1, 1, 1.0, 1),
(2, 2, 1.0, 2),
(2, 1, 0.25, 4),
(2, 2, 0.3, 7),
(2, 2, 0.3, 6),
])
def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int):
@pytest.mark.parametrize("save_last", (False, True))
def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int, save_last: bool):

class TestModel(BoringModel):

@@ -94,15 +95,17 @@ def training_step(self, batch, batch_idx):

model = TestModel()
trainer = Trainer(
callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss', save_top_k=k)],
callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss', save_top_k=k, save_last=save_last)],
default_root_dir=tmpdir,
max_epochs=epochs,
weights_summary=None,
val_check_interval=val_check_interval
)
trainer.fit(model)

# make sure types are correct
if save_last:
# last epochs are saved every step (so double the save calls) and once `on_train_end`
expected = expected * 2 + 1
assert save_mock.call_count == expected


Expand All @@ -115,7 +118,7 @@ def test_top_k_ddp_0(save_mock, tmpdir):
@mock.patch('torch.save')
@RunIf(special=True, min_gpus=2)
def test_top_k_ddp_1(save_mock, tmpdir):
_top_k_ddp(save_mock, tmpdir, k=2, epochs=2, val_check_interval=0.3, expected=5)
_top_k_ddp(save_mock, tmpdir, k=2, epochs=2, val_check_interval=0.3, expected=4)


def _top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected):
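For reference, the adjusted call count in the parametrized `test_top_k` above can be checked by hand. This is an illustrative sketch using one parametrization from the table above, mirroring the comment in the test; it is not part of the diff:

# parametrization (k=2, epochs=2, val_check_interval=0.3) from the table above
expected_top_k_saves = 6
save_last = True
if save_last:
    # each top-k save also rewrites the "last" checkpoint (doubling the torch.save calls),
    # plus one extra save from `ModelCheckpoint.on_train_end`
    expected_calls = expected_top_k_saves * 2 + 1   # 13
else:
    expected_calls = expected_top_k_saves           # 6
assert expected_calls == 13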
13 changes: 1 addition & 12 deletions tests/checkpointing/test_model_checkpoint.py
@@ -247,14 +247,7 @@ def configure_optimizers(self):
ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates

# on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
# end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
additional_ckpt, additional_ckpt_path = False, None
if not epoch_aligned:
additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
additional_ckpt = True

assert len(ckpt_files) == len(model.scores) + additional_ckpt == per_epoch_val_checks * max_epochs + additional_ckpt
assert len(ckpt_files) == len(model.scores) == per_epoch_val_checks * max_epochs
assert len(lr_scheduler_debug) == max_epochs

def _make_assertions(epoch, ix, version=''):
Expand Down Expand Up @@ -297,10 +290,6 @@ def _make_assertions(epoch, ix, version=''):
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)

# check the ckpt file saved on_train_end
if additional_ckpt_path:
_make_assertions(max_epochs - 1, per_epoch_val_checks - 1, version='-v1')


@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k: int):