61 changes: 48 additions & 13 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -202,6 +202,20 @@ def on_validation_end(self, trainer, pl_module):
"""
self.save_checkpoint(trainer, pl_module)

def on_epoch_end(self, trainer, pl_module):
"""
checkpoints can be saved at the end of every training epoch
"""
self.save_checkpoint(trainer, pl_module)

def on_train_end(self, trainer, pl_module):
"""
a final checkpoint can be saved at the end of training
"""
trainer.global_step -= 1
self.save_checkpoint(trainer, pl_module, is_last=True)
trainer.global_step += 1

def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
return {
"monitor": self.monitor,
@@ -215,23 +229,36 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
self.best_model_score = checkpointed_state["best_model_score"]
self.best_model_path = checkpointed_state["best_model_path"]

def save_checkpoint(self, trainer, pl_module):
def should_save(self, trainer, is_last=False):
epoch = trainer.current_epoch
global_step = trainer.global_step
should_save = not (
# negative conditions
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or self.save_top_k == 0
or self.period < 1
or (epoch + 1) % self.period
or trainer.running_sanity_check
) or (
# positive conditions
is_last
and self.save_last  # the user asked to save the last model
)
# already saved at the last step
should_skip = self.last_global_step_saved == global_step
# it is True only after a forward-backward pass has run
has_trained = trainer.checkpoint_connector.has_trained
return should_save and not should_skip and has_trained
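
For illustration (not part of the diff): the value returned above is "not (any negative condition) or (is_last and save_last)", further gated by should_skip and has_trained. A tiny standalone sketch of that precedence, with hypothetical boolean inputs:

def should_save_sketch(negative_condition, is_last, save_last, already_saved_at_this_step, has_trained):
    # mirrors the boolean structure of should_save above
    should_save = not negative_condition or (is_last and save_last)
    return should_save and not already_saved_at_this_step and has_trained

# a save_last final checkpoint overrides a negative condition such as a skipped period
assert should_save_sketch(True, True, True, False, True)
# but it is still skipped when this global step was already saved
assert not should_save_sketch(True, True, True, True, True)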

def save_checkpoint(self, trainer, pl_module, is_last=False):
"""
Performs the main logic around saving a checkpoint.
This method runs on all ranks, it is the responsibility of `self.save_function`
to handle correct behaviour in distributed training, i.e., saving only on rank 0.
"""
epoch = trainer.current_epoch
global_step = trainer.global_step

if (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or self.save_top_k == 0 # no models are saved
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
or self.last_global_step_saved == global_step # already saved at the last step
):
if not self.should_save(trainer, is_last=is_last):
return

self._add_backward_monitor_support(trainer)
@@ -250,7 +277,7 @@ def save_checkpoint(self, trainer, pl_module):
self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)

# Mode 2: save the last checkpoint
self._save_last_checkpoint(trainer, pl_module, monitor_candidates)
self._save_last_checkpoint(trainer, pl_module, monitor_candidates, is_last=is_last)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
@@ -503,11 +530,17 @@ def _add_backward_monitor_support(self, trainer):
if self.save_top_k is None and self.monitor is not None:
self.save_top_k = 1

def _valid_monitor_key(self, trainer):
Contributor: Maybe merge together this function and _is_valid_monitor_key?

metrics = trainer.logger_connector.callback_metrics

# validate metric
return self.monitor is None or self._is_valid_monitor_key(metrics)
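
For illustration (not part of the diff), one possible shape for the merge suggested in the review comment above, assuming _is_valid_monitor_key boils down to checking that self.monitor is among the logged callback metrics:

def _valid_monitor_key(self, trainer):
    metrics = trainer.logger_connector.callback_metrics
    # None means no monitored metric, which is always valid;
    # otherwise the monitored key must be present in the callback metrics
    return self.monitor is None or self.monitor in metrics

_validate_monitor_key would then call this single helper, just as it already does below.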

def _validate_monitor_key(self, trainer):
metrics = trainer.logger_connector.callback_metrics

# validate metric
if self.monitor is not None and not self._is_valid_monitor_key(metrics):
if not self._valid_monitor_key(trainer):
m = (
f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
f" {list(metrics.keys())}. "
@@ -538,13 +571,15 @@ def _monitor_candidates(self, trainer):
ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
return ckpt_name_metrics

def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, is_last=False):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return

# when the user ALSO asked for 'last.ckpt', change the name
if self.save_last:
if is_last:
rank_zero_info("Saving latest checkpoint...")
last_filepath = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST,
trainer.current_epoch,
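
Taken together, the callback changes above mean a ModelCheckpoint configured with save_last=True now writes last.ckpt from its own on_train_end hook instead of relying on the training loop. A minimal end-to-end sketch (illustration only; TinyModel and the "checkpoints" dirpath are made-up examples, and the callback arguments mirror the tests further below):

import os
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

class TinyModel(pl.LightningModule):
    # hypothetical minimal module with no validation loop
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

checkpoint_cb = ModelCheckpoint(dirpath="checkpoints", save_last=True)
trainer = pl.Trainer(max_epochs=2, callbacks=[checkpoint_cb])
trainer.fit(TinyModel())
# with the on_train_end hook above, last.ckpt is written by the callback itself
assert os.path.isfile(checkpoint_cb.last_model_path)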
22 changes: 0 additions & 22 deletions pytorch_lightning/trainer/training_loop.py
@@ -156,12 +156,6 @@ def on_train_end(self):

self._teardown_already_run = True

# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
# when a checkpoint was saved at the last step
self.trainer.global_step -= 1
self.check_checkpoint_callback(should_save=True, is_last=True)
self.trainer.global_step += 1

# hook
self.trainer.call_hook("on_train_end")

@@ -182,19 +176,6 @@ def on_train_end(self):
model.cpu()
torch.cuda.empty_cache()

def check_checkpoint_callback(self, should_save, is_last=False):
# TODO bake this logic into the checkpoint callback
if should_save and self.trainer.checkpoint_connector.has_trained:
checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]

if is_last and any(c.save_last for c in checkpoint_callbacks):
rank_zero_info("Saving latest checkpoint...")

model = self.trainer.get_model()

for callback in checkpoint_callbacks:
callback.on_validation_end(self.trainer, model)

def on_train_epoch_start(self, epoch):

# update training progress in trainer
@@ -606,9 +587,6 @@ def run_training_epoch(self):
self.num_optimizers
)

# when no val loop is present or fast-dev-run still need to call checkpoints
self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
Contributor: You can just fix your use case by adding sum(self.trainer.num_val_batches) == 0 here. Working on a PR #5208 fixing more issues there.

Author: @rohitgr7, does it make sense to close the issue in #5208, not here? I'm fine with closing the PR if a nicer solution is proposed.

Contributor: I'd suggest yes, since I'm doing a bit of a refactor there to fix more issues. Your use case is already fixed there. Mind checking if it works for you?
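
For illustration (not part of the diff), the guard suggested above would read roughly like this inside run_training_epoch, where the removed call sat; it assumes trainer.num_val_batches is the per-dataloader list of validation batch counts, and the contributor points to #5208 for the fuller fix:

# also trigger the checkpoint callback when the val loop would run zero batches
self.check_checkpoint_callback(
    not (should_check_val or is_overridden('validation_step', model))
    or sum(self.trainer.num_val_batches) == 0
)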


# increment the global step once
# progress global step according to grads progress
self.increment_accumulated_grad_global_step()
4 changes: 2 additions & 2 deletions tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -48,7 +48,7 @@ def test_mc_called(tmpdir):

@mock.patch('torch.save')
@pytest.mark.parametrize(['epochs', 'val_check_interval', 'expected'],
[(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)])
[(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 8)])
def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval, expected):

model = BoringModel()
@@ -66,7 +66,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval,

@mock.patch('torch.save')
@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'],
[(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 7)])
[(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 8)])
def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected):

class TestModel(BoringModel):
37 changes: 32 additions & 5 deletions tests/checkpointing/test_model_checkpoint.py
@@ -167,11 +167,11 @@ def on_train_end(self, trainer, pl_module):
super().on_train_end(trainer, pl_module)
assert self.best_model_path
assert self.best_model_score
assert self.on_save_checkpoint_count == self.expected_count
assert self.on_save_checkpoint_count == self.expected_count, (self.on_save_checkpoint_count, self.expected_count)
if trainer.is_global_zero:
assert torch.save.call_count == self.expected_count
assert torch.save.call_count == self.expected_count, (torch.save.call_count, self.expected_count)
else:
assert torch.save.call_count == 0
assert torch.save.call_count == 0, torch.save.call_count


@pytest.mark.skipif(
@@ -564,14 +564,21 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
model = LogInTwoMethods()
if not should_validate:
model.validation_step = None
model_checkpoint = ModelCheckpoint(
monitor='early_stop_on', dirpath=tmpdir,
save_top_k=0, save_last=save_last
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir,
save_top_k=0, save_last=save_last)],
callbacks=[model_checkpoint],
max_epochs=max_epochs,
)
trainer.fit(model)
assert caplog.messages.count('Saving latest checkpoint...') == save_last
path_last = str(tmpdir / "last.ckpt")
if save_last:
assert path_last == model_checkpoint.last_model_path
assert os.path.isfile(path_last)


def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
@@ -937,6 +944,26 @@ def __init__(self, hparams):
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type


def test_model_checkpoint_no_val_loader_invocation(tmpdir):
"""Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
class NoValBoringModel(LogInTwoMethods):
def val_dataloader(self):
return None

model = NoValBoringModel()

num_epochs = 4
model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[model_checkpoint],
max_epochs=num_epochs,
gpus=0,
)
result = trainer.fit(model)
assert 1 == result


@pytest.mark.parametrize('max_epochs', [3, 4])
@pytest.mark.parametrize(
'save_top_k, expected',
2 changes: 2 additions & 0 deletions tests/trainer/test_trainer.py
@@ -442,6 +442,8 @@ def mock_save_function(filepath, *args):
trainer.current_epoch = i
trainer.global_step = i
trainer.logger_connector.callback_metrics = {"checkpoint_on": torch.tensor(loss)}
# after forward-backward `has_trained` is set, this condition is also checked
trainer.checkpoint_connector.has_trained = True
checkpoint_callback.on_validation_end(trainer, trainer.get_model())

file_lists = set(os.listdir(tmpdir))