
Conversation

@shuyingsunshine21 (Contributor) commented Mar 25, 2021

What does this PR do?

Note: #6997 will fix global_step and the current epoch at the end of training, which is useful for this PR; will rebase after that is checked in.
Master Issue: #6672

This PR consolidates the logic for model checkpointing at the end of training.

Currently, we checkpoint via the on_validation_end hook, which is

  1. confusing
  2. potentially buggy when the checkpoint monitors a validation metric but limit_train_batches prevents the validation loop from being called (see related issue: Validation not called when using an IterableDataset and limit_train_batches flag #6332), or when training fails before the validation loop runs while a validation metric is set as the monitor (see related issue: Errors within try/except of train(self) are misrepresented as checkpointing MisconfigurationException #5766). A rough configuration sketch of the first scenario follows this list.
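A rough sketch of the first scenario in item 2, with assumed values: the callback monitors a validation metric, but the validation loop may never run, so checkpointing from on_validation_end cannot save what it was configured for. The metric name and limits below are illustrative, not the exact reproduction from #6332.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# The callback depends on a metric that only the validation loop produces.
ckpt = ModelCheckpoint(monitor="val_loss")

# With an IterableDataset and limit_train_batches (per #6332), the validation loop
# may never be called, so "val_loss" never exists and the on_validation_end-based
# checkpointing has nothing to save.
trainer = Trainer(max_epochs=1, limit_train_batches=2, callbacks=[ckpt])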

(Note: consolidating the checkpoint at the end of each training epoch requires some dependency cleanup and will be done in a separate PR.)

What this PR does

  • Move the end-of-training checkpoint logic from the training loop to the ModelCheckpoint hook on_train_end.
  • Instead of relying on every_n_val_epochs, provide an option trigger_on_train_end to determine whether to checkpoint at the end of training. It is turned off by default (see the usage sketch after this list).
  • When trigger_on_train_end is turned on, to address the case where the monitor value is missing, relax the condition that checks for the existence of the monitor key at the end of training; in that case, fall back to saving the last checkpoint.
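A minimal usage sketch of the proposed behavior. trigger_on_train_end is the option introduced by this PR, not a released ModelCheckpoint argument, and the monitor and max_epochs values are assumptions for illustration.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Proposed behavior: save a final checkpoint in on_train_end, falling back to a
# "last" checkpoint if the monitored metric never became available.
ckpt = ModelCheckpoint(
    monitor="val_loss",
    save_last=True,
    trigger_on_train_end=True,  # proposed flag from this PR; off by default
)
trainer = Trainer(max_epochs=3, callbacks=[ckpt])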

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

Shuying Sun and others added 30 commits March 23, 2021 12:06
…oint_consolidate

Update test_all_gather_grad.py
…1-checkpoint_consolidate"

This reverts commit c5053da, reversing
changes made to 0d23d75.
This reverts commit 70fe5da.
This reverts commit a9aae99.
@shuyingsunshine21 shuyingsunshine21 marked this pull request as ready for review April 10, 2021 07:47
@mergify mergify bot removed the has conflicts label Apr 11, 2021
@shuyingsunshine21 (Contributor, Author) commented:
@ananthsub, do you think the following could be better?

trigger_on_train_end should be mutually exclusive with the rest of the trigger modes (every_n_train_steps, every_n_val_epochs, ...),

and for this trigger, we only allow save_last.
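A hypothetical sketch of such a mutual-exclusion check. The helper name and plain arguments are assumptions for illustration; inside ModelCheckpoint this would operate on the corresponding attributes, and it is not merged Lightning code.

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def validate_trigger_configuration(
    trigger_on_train_end: bool, every_n_train_steps: int, every_n_val_epochs: int
) -> None:
    # Hypothetical check: the end-of-training trigger is rejected whenever one of
    # the step- or epoch-based triggers is also configured.
    if trigger_on_train_end and (every_n_train_steps > 0 or every_n_val_epochs > 0):
        raise MisconfigurationException(
            "`trigger_on_train_end=True` cannot be combined with "
            "`every_n_train_steps` or `every_n_val_epochs`."
        )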

Comment on lines 246 to 259
def on_train_end(self, trainer, *args, **kwargs) -> None:
    """
    checkpoints can be saved at the end of the training
    """
    if not self._trigger_on_train_end:
        return
    # as we advance one step at end of training, we use global_step - 1
    # to avoid saving duplicates
    trainer.global_step -= 1
    if (not self._should_skip_saving_checkpoint(trainer) and trainer.checkpoint_connector.has_trained):
        if self.save_last and self.verbose:
            rank_zero_info("Saving last checkpoint...")
        self.save_checkpoint(trainer, is_on_train_end=True)
    trainer.global_step += 1
@ananthsub (Contributor) commented Apr 13, 2021:

@shuyingsunshine21 could this directly call self._save_last_checkpoint? I think the most common case will be saving a last.ckpt file at the end of training. That way we don't have to thread the is_on_train_end flag through everywhere.

@carmocca what do you think? would this go along with #6470 ?

if (not self._should_skip_saving_checkpoint(trainer) and trainer.checkpoint_connector.has_trained):
    if self.save_last and self.verbose:
        rank_zero_info("Saving last checkpoint...")
    self.save_checkpoint(trainer, is_on_train_end=True)
@shuyingsunshine21 (Contributor, Author) replied:

@ananthsub,

can this not directly call self._save_last_checkpoint? I think the most common case will be saving a last.ckpt file at the end of training. That way we don't have to thread the is_on_train_end flag through everywhere.

We could directly call self._save_last_checkpoint by ignoring the top-k setup.

One thing to discuss: if trigger_on_train_end is set, should we guarantee saving last.ckpt even if save_last is not set?

Contributor reply:

I think we should respect what's set on the callback. The other reason is that if we have multiple checkpoint callbacks, we don't need them all to save on train end; we'll configure only one of them to have save_last=True.
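A rough sketch of the multi-callback setup described above. The metric name and save_top_k value are illustrative, and trigger_on_train_end is the flag proposed in this PR rather than an existing Lightning argument.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Only the second callback is configured to write a final "last" checkpoint at
# the end of training; the first one keeps tracking the best validation scores.
best_ckpt = ModelCheckpoint(monitor="val_loss", save_top_k=3)
last_ckpt = ModelCheckpoint(save_last=True, trigger_on_train_end=True)  # proposed flag
trainer = Trainer(callbacks=[best_ckpt, last_ckpt])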

@shuyingsunshine21 shuyingsunshine21 changed the title Consolidate Training End Model Checkpoint [blocked by #6997]Consolidate Training End Model Checkpoint Apr 20, 2021
Comment on lines +246 to +261
def on_train_end(self, trainer, *args, **kwargs) -> None:
    """
    checkpoints can be saved at the end of the training
    """
    if not self._trigger_on_train_end:
        return
    # as we advance one step at end of training, we use global_step - 1
    # to avoid saving duplicates
    trainer.global_step -= 1
    if (not self._should_skip_saving_checkpoint(trainer) and trainer.checkpoint_connector.has_trained):
        if self.save_last and self.verbose:
            rank_zero_info("Saving last checkpoint...")
        monitor_candidates = self._monitor_candidates(trainer)
        self._save_last_checkpoint(trainer, monitor_candidates)
    trainer.global_step += 1

Contributor reply:

Suggested change (replacing the snippet above with):

def on_train_end(self, trainer, pl_module) -> None:
    """Save a checkpoint at the very end of training.

    This will only save a checkpoint if `save_last` is also enabled,
    as the monitor metrics produced by training or validation steps or epoch ends
    are not guaranteed to be available at this stage.
    """
    if self._should_skip_saving_checkpoint(trainer) or not trainer.checkpoint_connector.has_trained:
        return
    initial_save_last = self.save_last
    if self._save_on_train_end and not self.save_last:
        rank_zero_warn(
            "Requested to save a checkpoint at the end of training but save_last is not set."
            " Temporarily setting save_last=True to save."
        )
        self.save_last = True
    if self.verbose:
        rank_zero_info("Saving last checkpoint...")
    # as we advance one step at end of training, we use global_step - 1
    # to avoid saving duplicates
    trainer.global_step -= 1
    monitor_candidates = self._monitor_candidates(trainer)
    self._save_last_checkpoint(trainer, monitor_candidates)
    trainer.global_step += 1
    self.save_last = initial_save_last

what do you think of this?

Also, what should happen if save_last is not set to True? Should save-on-train-end take precedence and temporarily override it? Should we move the save_last check out of _save_last_checkpoint, so that the property has to be checked before we call _save_last_checkpoint?

@shuyingsunshine21 (Contributor, Author) replied:

I think the original thought was that save_on_train_end depends on save_last, so it is only enabled when save_last is also set. What you propose is to always enable it regardless of save_last. Making save_on_train_end an independent trigger also makes sense.

Contributor reply:

@carmocca @awaelchli what do you think?

Contributor reply:

I think I prefer the current implementation, maybe throwing a warning so people know they should set both.
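A hypothetical sketch of the warning being suggested. The helper name, arguments, and wording are assumptions; only the idea (warn when the end-of-training trigger is enabled without save_last) comes from the discussion above.

from pytorch_lightning.utilities import rank_zero_warn


def warn_if_save_last_missing(trigger_on_train_end: bool, save_last: bool) -> None:
    # Hypothetical helper: let users know both options should be enabled together
    # for a checkpoint to actually be written at the end of training.
    if trigger_on_train_end and not save_last:
        rank_zero_warn(
            "`trigger_on_train_end=True` will not save a checkpoint unless `save_last=True` is also set."
        )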

@awaelchli awaelchli added this to the v1.4 milestone May 3, 2021
@awaelchli awaelchli added the checkpointing (Related to checkpointing) and feature (Is an improvement or enhancement) labels May 3, 2021
@carmocca (Contributor) commented:
Do we need the trigger_on_train_end flag?

@carmocca carmocca mentioned this pull request May 26, 2021
@shuyingsunshine21 (Contributor, Author) commented:
Hi @carmocca, if my understanding is correct, your PR #7724 would include this change. Maybe I could abandon this one?

@carmocca (Contributor) replied:
I opened #7724 to have the full picture of what would be necessary to remove check_checkpoint_callback. But once that is validated, it can be split into smaller changes including the ones in this PR (after tweaking).

So let's keep this open for now. You can hold off on updating it and I can hijack it later and do it myself if necessary.

Thanks!


Labels

checkpointing (Related to checkpointing), feature (Is an improvement or enhancement), has conflicts

Projects

None yet


6 participants