-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[1/2] Deprecate outputs in on_train_epoch_end hooks
#7339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
be39d5f
082c3cc
495c0b7
03fa99d
2f55ff1
8d262e1
b0c02cb
274d5f8
5c6a3d8
01aed79
b719bef
66f310a
d18455c
f2f8b58
c84577d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ | |
| from pytorch_lightning.utilities.grads import grad_norm | ||
| from pytorch_lightning.utilities.model_helpers import is_overridden | ||
| from pytorch_lightning.utilities.parsing import AttributeDict | ||
| from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature | ||
| from pytorch_lightning.utilities.warnings import WarningCache | ||
|
|
||
|
|
||
|
|
@@ -197,16 +198,14 @@ def reset_train_val_dataloaders(self, model) -> None: | |
|
|
||
| def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): | ||
|
|
||
| hook_overridden = self._should_add_batch_output_to_epoch_output() | ||
|
|
||
| # track the outputs to reduce at the end of the epoch | ||
| for opt_idx, opt_outputs in enumerate(batch_end_outputs): | ||
| sample_output = opt_outputs[-1] | ||
|
|
||
| # decide if we need to reduce at the end of the epoch automatically | ||
| auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end | ||
| hook_overridden = ( | ||
| is_overridden("training_epoch_end", model=self.trainer.lightning_module) | ||
| or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module) | ||
| ) | ||
|
|
||
| # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end | ||
| if not (hook_overridden or auto_reduce_tng_result): | ||
|
|
@@ -218,6 +217,22 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): | |
|
|
||
| epoch_output[opt_idx].append(opt_outputs) | ||
|
|
||
| def _should_add_batch_output_to_epoch_output(self) -> bool: | ||
| # We add to the epoch outputs if | ||
| # 1. The model defines training_epoch_end OR | ||
| # 2. The model overrides on_train_epoch_end which has `outputs` in the signature | ||
| # TODO: in v1.5 this only needs to check if training_epoch_end is overridden | ||
| lightning_module = self.trainer.lightning_module | ||
| if is_overridden("training_epoch_end", model=lightning_module): | ||
| return True | ||
|
|
||
| if is_overridden("on_train_epoch_end", model=lightning_module): | ||
| model_hook_fx = getattr(lightning_module, "on_train_epoch_end") | ||
| if is_param_in_hook_signature(model_hook_fx, "outputs"): | ||
| return True | ||
|
|
||
| return False | ||
|
|
||
| def get_optimizers_iterable(self, batch_idx=None): | ||
| """ | ||
| Generates an iterable with (idx, optimizer) for each optimizer. | ||
|
|
@@ -593,9 +608,51 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: | |
| self.trainer.logger_connector.cache_logged_metrics() | ||
|
|
||
| # call train epoch end hooks | ||
| self.trainer.call_hook('on_train_epoch_end', processed_epoch_output) | ||
| self._on_train_epoch_end_hook(processed_epoch_output) | ||
| self.trainer.call_hook('on_epoch_end') | ||
|
|
||
| def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: | ||
| # We cannot rely on Trainer.call_hook because the signatures might be different across | ||
| # lightning module and callback | ||
| # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end` | ||
|
|
||
| # This implementation is copied from Trainer.call_hook | ||
| hook_name = "on_train_epoch_end" | ||
|
|
||
| # set hook_name to model + reset Result obj | ||
| skip = self.trainer._reset_result_and_set_hook_fx_name(hook_name) | ||
|
|
||
| # always profile hooks | ||
| with self.trainer.profiler.profile(hook_name): | ||
|
|
||
| # first call trainer hook | ||
| if hasattr(self.trainer, hook_name): | ||
| trainer_hook = getattr(self.trainer, hook_name) | ||
| trainer_hook(processed_epoch_output) | ||
|
|
||
| # next call hook in lightningModule | ||
| model_ref = self.trainer.lightning_module | ||
| if is_overridden(hook_name, model_ref): | ||
| hook_fx = getattr(model_ref, hook_name) | ||
| if is_param_in_hook_signature(hook_fx, "outputs"): | ||
| self.warning_cache.warn( | ||
| "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3." | ||
| " `outputs` parameter has been deprecated." | ||
| " Support for the old signature will be removed in v1.5", DeprecationWarning | ||
| ) | ||
| model_ref.on_train_epoch_end(processed_epoch_output) | ||
| else: | ||
| model_ref.on_train_epoch_end() | ||
|
|
||
| # if the PL module doesn't have the hook then call the accelerator | ||
| # used to auto-reduce things for the user with Results obj | ||
| elif hasattr(self.trainer.accelerator, hook_name): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not a huge fan of this. Better to use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. From the comment: call_hook enforces that the accelerator/trainer/module all take the exact same arguments for the hook, which might not be the case here. This was the same pattern @kaushikb11 followed in #6120. I'm not really a fan either, but maybe this is something we can look at for v1.4 — how to simplify/strengthen this? Maybe the techniques @SkafteNicki used for metrics collections could apply here, but that seems beyond the scope of this PR. One thing I can do is add comments to Trainer.call_hook to indicate that this override is being applied in the training loop, and that any changes to call_hook must also be applied here. |
||
| accelerator_hook = getattr(self.trainer.accelerator, hook_name) | ||
| accelerator_hook() | ||
|
|
||
| if not skip: | ||
| self.trainer._cache_logged_metrics() | ||
|
|
||
| def run_training_batch(self, batch, batch_idx, dataloader_idx): | ||
| # track grad norms | ||
| grad_norm_dic = {} | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.