[WIP] Check monitor for checkpoints every epoch even if there is no validation #4793
Changes from all commits: 42b1f4e, 7f12291, b2e06bb, e4e825a, e0aacea, c62a943, 3460198, 05a121e, ddf6b62, 51db066, b37e51c, 653a4c2, 3a4be6d, 8d01da8, 8550b5c, 68cafd4, 92d3b4a, bb47a22, b3e4c5a, 0225426, b880ad8, 0540544
```diff
@@ -156,12 +156,6 @@ def on_train_end(self):

         self._teardown_already_run = True

-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_save=True, is_last=True)
-        self.trainer.global_step += 1
-
         # hook
         self.trainer.call_hook("on_train_end")
```

```diff
@@ -182,19 +176,6 @@ def on_train_end(self):
         model.cpu()
         torch.cuda.empty_cache()

-    def check_checkpoint_callback(self, should_save, is_last=False):
-        # TODO bake this logic into the checkpoint callback
-        if should_save and self.trainer.checkpoint_connector.has_trained:
-            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
-
-            if is_last and any(c.save_last for c in checkpoint_callbacks):
-                rank_zero_info("Saving latest checkpoint...")
-
-            model = self.trainer.get_model()
-
-            for callback in checkpoint_callbacks:
-                callback.on_validation_end(self.trainer, model)
-
     def on_train_epoch_start(self, epoch):

         # update training progress in trainer
```

```diff
@@ -606,9 +587,6 @@ def run_training_epoch(self):
             self.num_optimizers
         )

-        # when no val loop is present or fast-dev-run still need to call checkpoints
-        self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
-
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()
```
Review thread on the removed `self.check_checkpoint_callback(...)` line:

Contributor: you can just fix your use case by adding

Author:

Contributor: I'd suggest yes, since I'm doing a bit of a refactor there to fix more issues. Your use case is already fixed there. Mind checking if it works for you?
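For reference, here is a minimal sketch of the use case this thread is about, assuming a PyTorch Lightning 1.x-style API: a model with no `validation_step` that checkpoints on a metric logged in `training_step`. The model, metric name, and data below are illustrative, not taken from this PR.

```python
# Illustrative sketch (assumed PL 1.x API, not code from this PR):
# checkpoint on a train-time metric when no validation loop exists.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class TrainOnlyModel(LightningModule):
    """A model that defines training_step but no validation_step."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        # log the metric that the checkpoint callback monitors
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# without a val loop, saving hinges on the trainer checking the
# checkpoint callback at the end of each training epoch
ckpt = ModelCheckpoint(monitor="train_loss", save_last=True)
trainer = Trainer(max_epochs=2, callbacks=[ckpt])
trainer.fit(TrainOnlyModel(), DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8))
```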
Contributor: Maybe merge this function and `_is_valid_monitor_key` together?
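The function that comment refers to isn't visible in this capture, but as a rough illustration of the suggestion, a merged helper on `ModelCheckpoint` could look something like the sketch below; the name, the `logger_connector` lookup, and the raised exception are assumptions, not code from this PR.

```python
# Hypothetical sketch of the suggested merge; names and structure are
# assumed, not taken from this PR.
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _validate_monitor_key(self, trainer) -> bool:
    metrics = trainer.logger_connector.callback_metrics
    # an empty metrics dict usually means nothing has been logged yet,
    # so the monitor key cannot be judged invalid at this point
    if self.monitor is None or len(metrics) == 0 or self.monitor in metrics:
        return True
    raise MisconfigurationException(
        f"ModelCheckpoint(monitor='{self.monitor}') not found in the "
        f"returned metrics: {list(metrics)}"
    )
```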