Separate epoch validation from step validation #5208
Changes from all commits
.gitignore

@@ -145,3 +145,6 @@ pytorch\ lightning
 test-reports/
 wandb
 .forked/
+
+# ctags
+tags
pytorch_lightning/trainer/training_loop.py

@@ -18,7 +18,7 @@
 import torch
 import torch.distributed as torch_distrib

-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -153,7 +153,7 @@ def on_train_end(self):
         # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
         # when a checkpoint was saved at the last step
         self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_save=True, is_last=True)
+        self.check_checkpoint_callback(should_update=True, is_last=True)
         self.trainer.global_step += 1

         # hook
@@ -176,18 +176,27 @@ def on_train_end(self):
         model.cpu()
         torch.cuda.empty_cache()

-    def check_checkpoint_callback(self, should_save, is_last=False):
-        # TODO bake this logic into the checkpoint callback
-        if should_save and self.trainer.checkpoint_connector.has_trained:
-            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+    def check_checkpoint_callback(self, should_update, is_last=False):
+        # TODO bake this logic into the ModelCheckpoint callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = self.trainer.checkpoint_callbacks

-            if is_last and any(c.save_last for c in checkpoint_callbacks):
+            if is_last and any(cb.save_last for cb in callbacks):
                 rank_zero_info("Saving latest checkpoint...")

             model = self.trainer.get_model()

-            for callback in checkpoint_callbacks:
-                callback.on_validation_end(self.trainer, model)
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
+    def check_early_stopping_callback(self, should_update):
+        # TODO bake this logic into the EarlyStopping callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
+            model = self.trainer.get_model()
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)

     def on_train_epoch_start(self, epoch):
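The new `check_early_stopping_callback` mirrors `check_checkpoint_callback`; both are driven by the train-only branch added further down in this diff. As a purely illustrative, user-level sketch (not part of this PR, the module and metric names are invented, and it assumes a Lightning version that includes this change), a model trained with no validation loop can now have both `ModelCheckpoint` and `EarlyStopping` monitor a metric logged from `training_step`, because the train-only branch fires both callbacks at epoch end:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint


class TrainOnlyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # log on epoch so the epoch-end callbacks have a value to monitor
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 1))
    trainer = pl.Trainer(
        max_epochs=5,
        callbacks=[
            ModelCheckpoint(monitor="train_loss"),
            EarlyStopping(monitor="train_loss", patience=2),
        ],
    )
    # no val_dataloader is passed, so the epoch-end train-only path is taken
    trainer.fit(TrainOnlyModel(), DataLoader(dataset, batch_size=16))
```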
@@ -518,7 +527,6 @@ def tbptt_split_batch(self, batch):
         return splits

     def run_training_epoch(self):
-
         # get model
         model = self.trainer.get_model()
@@ -531,7 +539,6 @@ def run_training_epoch(self):
         # enable profiling for the dataloader
         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
-        should_check_val = False
         for batch_idx, (batch, is_last_batch) in train_dataloader:

             self.trainer.batch_idx = batch_idx
@@ -580,11 +587,12 @@ def run_training_epoch(self):
             self.trainer.checkpoint_connector.has_trained = True

             # max steps reached, end training
-            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
-                accumulation_done = self._accumulated_batches_reached()
-                # Ensure accumulation across batches has completed before breaking loop
-                if accumulation_done:
-                    break
+            if (
+                self.trainer.max_steps is not None
+                and self.trainer.max_steps == self.trainer.global_step + 1
+                and self._accumulated_batches_reached()
+            ):
+                break

             # end epoch early
             # stop when the flag is changed or we've gone past the amount
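The reshaped condition above keeps the original behaviour: `max_steps` only ends training once the current gradient-accumulation window has completed, otherwise a partially accumulated optimizer step would be thrown away. A tiny standalone sketch with illustrative numbers (not Lightning code):

```python
# pretend loop: one optimizer step per `accumulate_grad_batches` batches
accumulate_grad_batches = 4
max_steps = 3
global_step = 0

for batch_idx in range(100):
    accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
    if accumulation_done:
        global_step += 1  # optimizer.step() would happen here
    # stop on max_steps only at an accumulation boundary
    if global_step == max_steps and accumulation_done:
        break

print(batch_idx + 1, global_step)  # 12 batches consumed for 3 optimizer steps
```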
@@ -595,7 +603,7 @@ def run_training_epoch(self):
             self.trainer.total_batch_idx += 1

             # stop epoch if we limited the number of training batches
-            if (batch_idx + 1) >= self.trainer.num_training_batches:
+            if self._num_training_batches_reached(is_last_batch):
                 break

             # progress global step according to grads progress
@@ -612,8 +620,20 @@ def run_training_epoch(self):
             self.num_optimizers
         )

-        # when no val loop is present or fast-dev-run still need to call checkpoints
-        self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
+        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+        if should_check_val:
+            self.trainer.run_evaluation(on_epoch=True)
+            # reset stage to train
+            self.trainer.logger_connector.set_stage("train")
+
+        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
Contributor: Slightly confused about this part. Can you explain why we check val and then decide if we should skip it?

Contributor (Author): It's just to check whether there are any validation datasets available or not. If there aren't, then we should run the train-only check; otherwise not. There are two cases for no validation: one when there is no …

Contributor: @rohitgr7 I'm debugging an issue right now related to this. This check could also be pointing to a bug in how …
+        should_train_only = self.trainer.disable_validation or should_skip_eval
+
+        if should_train_only:
+            # update epoch level lr_schedulers
+            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
+            self.check_checkpoint_callback(True)
+            self.check_early_stopping_callback(True)

         # increment the global step once
         # progress global step according to grads progress
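A standalone sketch of the decision discussed in the review thread above. The function names mirror the diff, but the bodies are simplified assumptions rather than the library's internals: evaluation counts as skippable when every validation dataloader contributes zero batches, and the train-only branch runs when validation is disabled or would be skipped.

```python
from typing import List, Union


def should_skip_evaluation(num_val_batches: List[Union[int, float]]) -> bool:
    # no val dataloaders at all, or every one of them limited to zero batches
    return sum(num_val_batches) == 0


def should_train_only(disable_validation: bool, num_val_batches: List[Union[int, float]]) -> bool:
    return disable_validation or should_skip_evaluation(num_val_batches)


print(should_train_only(False, []))       # True  -> no validation dataloader defined
print(should_train_only(False, [0]))      # True  -> e.g. limit_val_batches=0
print(should_train_only(False, [10, 5]))  # False -> validation runs at epoch end
```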
@@ -853,25 +873,33 @@ def increment_accumulated_grad_global_step(self):
     def _accumulated_batches_reached(self):
         return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0

-    def _num_training_batches_reached(self):
-        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
+    def _num_training_batches_reached(self, is_last_batch=False):
+        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch

     def should_accumulate(self):
         # checks if backward or backward + optimizer step (via closure)
         accumulation_done = self._accumulated_batches_reached()
         is_final_batch = self._num_training_batches_reached()
         return not (accumulation_done or is_final_batch)

-    def should_check_val_fx(self, batch_idx, is_last_batch):
+    def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
         is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         can_check_val = self.trainer.enable_validation and is_val_check_epoch
-        should_check_val = is_val_check_batch or self.trainer.should_stop
         is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
+        epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
+
+        should_check_val = (
+            (is_val_check_batch and epoch_end_val_check)
+            or self.trainer.should_stop
+            or is_last_batch_for_infinite_dataset
+        ) if on_epoch else (
+            is_val_check_batch
+            and not epoch_end_val_check
+        )

-        return should_check_val
+        return should_check_val and can_check_val

     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
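To make the step/epoch split concrete, here is a simplified standalone sketch of the new `should_check_val_fx` logic. It leaves out the `can_check_val` / `check_val_every_n_epoch` gating, and the parameter names are assumptions: the in-loop call site only triggers validation strictly mid-epoch, while the epoch-end call site covers the default `val_check_interval=1.0` case, `should_stop`, and infinite datasets.

```python
def should_check_val(batch_idx, is_last_batch, *, val_check_batch,
                     num_training_batches, should_stop=False, on_epoch=False):
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    is_last_batch_for_infinite_dataset = is_last_batch and val_check_batch == float("inf")
    epoch_end_val_check = val_check_batch == num_training_batches

    if on_epoch:
        # epoch-end call site: validate only when validation falls exactly on the epoch boundary
        return (is_val_check_batch and epoch_end_val_check) or should_stop or is_last_batch_for_infinite_dataset
    # in-loop call site: validate only when validation is due strictly mid-epoch
    return is_val_check_batch and not epoch_end_val_check


# val_check_interval=1.0 with 10 batches -> val_check_batch == 10:
# nothing fires inside the loop; validation runs once from the epoch-end call site
print(should_check_val(9, True, val_check_batch=10, num_training_batches=10))                 # False
print(should_check_val(9, True, val_check_batch=10, num_training_batches=10, on_epoch=True))  # True

# val_check_interval=0.5 with 10 batches -> val_check_batch == 5:
# validation fires inside the loop at batches 5 and 10, and not again at epoch end
print(should_check_val(4, False, val_check_batch=5, num_training_batches=10))                 # True
print(should_check_val(9, True, val_check_batch=5, num_training_batches=10, on_epoch=True))   # False
```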
tests/callbacks/test_callbacks.py

@@ -86,15 +86,15 @@ def test_trainer_callback_system(torch_save):
         call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
         call.on_batch_end(trainer, model),
         call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
+        call.on_epoch_end(trainer, model),
+        call.on_train_epoch_end(trainer, model, ANY),
Comment on lines +89 to +90

Collaborator: I think @williamFalcon had a point some time ago that training should run up to validation, and the example was with validation running multiple times over a long training run...

Contributor (Author): Yes, it still works like that, only if …
         call.on_validation_start(trainer, model),
         call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
         call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
         call.on_validation_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
         call.on_save_checkpoint(trainer, model),
-        call.on_epoch_end(trainer, model),
-        call.on_train_epoch_end(trainer, model, ANY),
         call.on_train_end(trainer, model),
         call.on_fit_end(trainer, model),
         call.teardown(trainer, model, 'fit'),
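For readers unfamiliar with how this test asserts ordering, here is a self-contained sketch of the same `unittest.mock` pattern. The argument strings and the emitted sequence are illustrative only, mirroring the new expected order rather than reproducing the real test:

```python
from unittest.mock import ANY, MagicMock, call

callback = MagicMock()

# pretend trainer emitting hooks in the new order: train-epoch hooks first,
# then the epoch-end validation hooks, then checkpointing
callback.on_epoch_end("trainer", "model")
callback.on_train_epoch_end("trainer", "model", [{"loss": 0.1}])
callback.on_validation_start("trainer", "model")
callback.on_validation_end("trainer", "model")
callback.on_save_checkpoint("trainer", "model")

# method_calls records child calls in order; ANY matches any epoch output
assert callback.method_calls == [
    call.on_epoch_end("trainer", "model"),
    call.on_train_epoch_end("trainer", "model", ANY),
    call.on_validation_start("trainer", "model"),
    call.on_validation_end("trainer", "model"),
    call.on_save_checkpoint("trainer", "model"),
]
print("hook order matches")
```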