-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[feat] Support iteration-based checkpointing in model checkpoint callback #6146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] Support iteration-based checkpointing in model checkpoint callback #6146
Conversation
|
Hello @ananthsub! Thanks for updating this PR.
Comment last updated at 2021-03-11 16:58:30 UTC |
Codecov Report
@@ Coverage Diff @@
## master #6146 +/- ##
=======================================
- Coverage 94% 92% -2%
=======================================
Files 161 161
Lines 11500 11512 +12
=======================================
- Hits 10756 10547 -209
- Misses 744 965 +221 |
34bcdc1 to
14b4b7b
Compare
carmocca
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got interrupted during my review. Will re-review in another moment
| save_top_k: Optional[int] = None, | ||
| save_weights_only: bool = False, | ||
| every_n_epochs: int = 1, | ||
| every_n_batches: int = -1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should None be the default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for every_n_batches? would we want every_n_epochs to be optional too?
one choice is to make both None by default, and then in the constructor, if neither are set, then we default to every_n_epochs=1 so it runs after each validation epoch.
And if someone wants to checkpoint after batches only, they can do ModelCheckpoint(..., every_n_batches=N) without needing to toggle off the every_n_epochs flag. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, both optional. The question is whether to assume they are exclusive and we go with
make both None by default, and then in the constructor, if neither are set, then we default to every_n_epochs=1 so it runs after each validation epoch.
Or not, have every_n_epochs=1 by default and have the user toggle off if necessary.
Does allowing having both active make sense to you?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's unlikely that both will be used together, but they aren't mutually exclusive. if you want to checkpoint every N batches, you can. if you want to checkpoint every M epochs, you can. if you set both, we'll do both
Since this regular int is more restrictive than Optional[int] wdyt about going with this approach for now, and if we get feedback that this is confusing, then we can make them both optional to maintain backwards compatibility?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good to me. To recap:
<0: Misconfiguration0: disabled / acts as identityn: every n
both can be set at the same time
defaults:
every_n_epochs=1every_n_steps=0
Correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
correct!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@carmocca @awaelchli coming back to this, after looking at the complexity in supporting monitor and save_top_k, it makes more sense to keep them as mutually exclusive.
- So we can make both
every_n_train_stepsandevery_n_val_epochsarguments asOptional[int] - users toggle on whichever they're going to use
- we raise an exception if both are set and positive
- if neither are set, we default to every_n_train_steps = 0, every_n_val_epochs = 1
<0: Misconfiguration0: Disabledn: every n
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's fine with me, especially if we support many callbacks
d737694 to
d826f97
Compare
0d5bb43 to
6fffdec
Compare
|
@carmocca @tchaton @SkafteNicki i think this is mostly ready for review now. the main question i have is whether we should be using the total batch idx or the global step in the check for when to save during training. depending on that, we can set the appropriate name for the parameter (every_n_batches vs every_n_steps) |
- Make names explicit as to which hooks they apply to - Use step instead of batch for consistency with global step
Make every_n_train_steps and every_n_val_epochs mutually exclusive
make attributes private to the class
5e8de9a to
ab4012d
Compare
…ter) to github/third-party/PyTorchLightning/pytorch-lightning Summary: ### New commit log messages ## [UnReleased] - 2021-MM-DD ### Added - Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](Lightning-AI/pytorch-lightning#6667)) - Added `LightningCLI` class to provide simple reproducibility with minimum boilerplate training cli. ([#4492](Lightning-AI/pytorch-lightning#4492)) - Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](Lightning-AI/pytorch-lightning#6417)) - Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#6123](Lightning-AI/pytorch-lightning#6123)). - Added a way to print to terminal without breaking up the progress bar ([#5470](Lightning-AI/pytorch-lightning#5470)) - Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](Lightning-AI/pytorch-lightning#6146)) - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](Lightning-AI/pytorch-lightning#6072)) - Added `RunningStage.SANITY_CHECKING` ([#4945](Lightning-AI/pytorch-lightning#4945)) - Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](Lightning-AI/pytorch-lightning#4945)) - Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](Lightning-AI/pytorch-lightning#4948)) - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](Lightning-AI/pytorch-lightning#5915)) - Added `teardown()` hook to LightningDataModule ([#4673](Lightning-AI/pytorch-lightning#4673)) - Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](Lightning-AI/pytorch-lightning#6277)) - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](Lightning-AI/pytorch-lightning#6274)) - Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](Lightning-AI/pytorch-lightning#6370)) - Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](Lightning-AI/pytorch-lightning#6633)) - Added no return warning to predict ([#6139](Lightning-AI/pytorch-lightning#6139)) - Added `Trainer.predict` config validation ([#6543](Lightning-AI/pytorch-lightning#6543)) - Added `AbstractProfiler` interface ([#6621](Lightning-AI/pytorch-lightning#6621)) - Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](Lightning-AI/pytorch-lightning#6349)) - Added support for the PyTorch 1.8.1 autograd profiler ([#6618](Lightning-AI/pytorch-lightning#6618)) - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](Lightning-AI/pytorch-lightning#6120)) - Added `configure_sharded_model` hook ([#6679](Lightning-AI/pytorch-lightning#6679)) - Added support for `precision=64`, enabling training with double precision ([#6595](Lightning-AI/pytorch-lightning#6595)) - Added support for DDP communication hooks ([#6736](Lightning-AI/pytorch-lightning#6736)) - Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](Lightning-AI/pytorch-lightning#6677)) - Added `model` parameter to precision plugins' `clip_gradients` signature ([#6764](Lightning-AI/pytorch-lightning#6764)) ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](Lightning-AI/pytorch-lightning#6259)) - Refactor `RunningStage` and `TrainerState` usage ([#4945](Lightning-AI/pytorch-lightning#4945)) - Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](Lightning-AI/pytorch-lightning#4945)) - Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](Lightning-AI/pytorch-lightning#6386)) - Changed profilers to save separate report files per state and rank ([#6621](Lightning-AI/pytorch-lightning#6621)) - Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](Lightning-AI/pytorch-lightning#6349)) ### Deprecated - `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](Lightning-AI/pytorch-lightning#6146)) - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](Lightning-AI/pytorch-lightning#4945)) - Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](Lightning-AI/pytorch-lightning#6621)) - Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](Lightning-AI/pytorch-lightning#6349)) - Deprecated metrics in favor of `torchmetrics` ([#6505](Lightning-AI/pytorch-lightning#6505), [#6530](Lightning-AI/pytorch-lightning#6530), [#6540](Lightning-AI/pytorch-lightning#6540), [#6547](Lightning-AI/pytorch-lightning#6547), [#6515](Lightning-AI/pytorch-lightning#6515), [#6572](Lightning-AI/pytorch-lightning#6572), [#6573](Lightning-AI/pytorch-lightning#6573), [#6584](Lightning-AI/pytorch-lightning#6584), [#6636](Lightning-AI/pytorch-lightning#6636), [#6637](Lightning-AI/pytorch-lightning#6637), [#6649](Lightning-AI/pytorch-lightning#6649), [#6659](Lightning-AI/pytorch-lightning#6659), ) ### Removed - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](Lightning-AI/pytorch-lightning#6164)) - Removed no return warning from val/test step ([#6139](Lightning-AI/pytorch-lightning#6139)) - Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](Lightning-AI/pytorch-lightning#6166)) - Removed deprecated Trainer argument `enable_pl_optimizer` and `automatic_optimization` ([#6163](Lightning-AI/pytorch-lightning#6163)) - Removed deprecated metrics ([#6161](Lightning-AI/pytorch-lightning#6161)) * from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve` * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce` - Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](Lightning-AI/pytorch-lightning#6162)) - Removed `mode='auto'` from `EarlyStopping` ([#6167](Lightning-AI/pytorch-lightning#6167)) - Removed legacy references for magic keys in the `Result` object ([#6016](Lightning-AI/pytorch-lightning#6016)) - Removed deprecated `LightningModule` `hparams` setter ([#6207](Lightning-AI/pytorch-lightning#6207)) - Removed legacy code to log or include metrics in the progress bar by returning them in a dict with the `"log"/"progress_bar"` magic keys. Use `self.log` instead ([#6734](Lightning-AI/pytorch-lightning#6734)) - Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](Lightning-AI/pytorch-lightning#6093)) ### Fixed - Set better defaults for `rank_zero_only.rank` when training is launched with SLURM and torchelastic ([#6802](Lightning-AI/pytorch-lightning#6802)) - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](Lightning-AI/pytorch-lightning#6011)) - Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](Lightning-AI/pytorch-lightning#6070)) - Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](Lightning-AI/pytorch-lightning#6109)) - Fixed csv extension check ([#6436](Lightning-AI/pytorch-lightning#6436)) - Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](Lightning-AI/pytorch-lightning#6136)) - Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](Lightning-AI/pytorch-lightning#6136)) - Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](Lightning-AI/pytorch-lightning#6386)) - Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](Lightning-AI/pytorch-lightning#6386)) - Fixed LightningModule `all_gather` on cpu tensors ([#6416](Lightning-AI/pytorch-lightning#6416)) - Fixed torch distributed not available in setup hook for DDP ([#6506](Lightning-AI/pytorch-lightning#6506)) - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](Lightning-AI/pytorch-lightning#6705)) ## [1.2.7] - 2021-04-06 ### Fixed - Fixed resolve a bug with omegaconf and xm.save ([#6741](Lightning-AI/pytorch-lightning#6741)) - Fixed an issue with IterableDataset when __len__ is not defined ([#6828](Lightning-AI/pytorch-lightning#6828)) - Sanitize None params during pruning ([#6836](Lightning-AI/pytorch-lightning#6836)) - Enforce an epoch scheduler interval when using SWA ([#6588](Lightning-AI/pytorch-lightning#6588)) - Fixed TPU Colab hang issue, post training ([#6816](Lightning-AI/pytorch-lightning#6816)) - Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](Lightning-AI/pytorch-lightning#6730)) ## [1.2.6] - 2021-03-30 ### Changed - Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](Lightning-AI/pytorch-lightning#6498)) ### Removed - Removed legacy code to include `step` dictionary returns in `callback_metrics`. Use `self.log_dict` instead. ([#6682](Lightning-AI/pytorch-lightning#6682)) ### Fixed - Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](Lightning-AI/pytorch-lightning#6398)) - Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](Lightning-AI/pytorch-lightning#6654)) - Fixed `trainer.test` freeze on TPUs ([#6654](Lightning-AI/pytorch-lightning#6654)) - Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](Lightning-AI/pytorch-lightning#6657)) - Fixed bug where no TPUs were detected in a TPU pod env ([#6719](Lightning-AI/pytorch-lightning#6719)) ## [1.2.5] - 2021-03-23 ### Changed - Update Gradient Clipping for the TPU Accelerator ([#6576](Lightning-AI/pytorch-lightning#6576)) - Refactored setup for typing friendly ([#6590](Lightning-AI/pytorch-lightning#6590)) ### Fixed - Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](Lightning-AI/pytorch-lightning#6587)) - Fixed comparing required versions ([#6434](Lightning-AI/pytorch-lightning#6434)) - Fixed duplicate logs appearing in console when using the python logging module ([#6275](Lightning-AI/pytorch-lightning#6275)) - Added Autocast in validation, test and predict modes for Native AMP ([#6565](Lightning-AI/pytorch-lightning#6565)) Reviewed By: shuyingsunshine21 Differential Revision: D27528929 fbshipit-source-id: 311c88f71461c2c79bbf185e28d7a6d683ccc26f
What does this PR do?
Fixes #2534
This supports checkpointing periodically after a user-specified number of training batches to complement the end of validation epoch-based checkpointing currently supported. This is useful for large training runs and for purely iteration-based training.
We checkpoint after training batches because the model state doesn't change during validation. This means the training metrics will be used for the monitor keys if configured to checkpoint this way.
To do so:
every_n_train_stepsto the callback constructor which dictates the frequency with which we'll checkpoint. Ifevery_n_train_steps = 0then this feature is disabled.periodname is ambiguous with this addition: therefore we introduce a new argumentevery_n_val_epochsto be more clear as to when it's applied and markperiodas deprecatedBefore submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:
Did you have fun?
Make sure you had fun coding 🙃