
Commit 6104a63

ananthsub, ethanwharris, and Borda authored
[1/2] Deprecate outputs in on_train_epoch_end hooks (#7339)
* Remove outputs from on_train_epoch_end
* iterate
* Update callback_hook.py
* update
* Update training_loop.py
* Update test_training_loop.py
* early stop?
* fix
* update tests
* Update test_hooks.py
* Update pytorch_lightning/trainer/callback_hook.py
  Co-authored-by: Ethan Harris <[email protected]>
* Update pytorch_lightning/trainer/training_loop.py
  Co-authored-by: Ethan Harris <[email protected]>
* Update trainer.py
* Update pytorch_lightning/trainer/trainer.py
  Co-authored-by: Jirka Borovec <[email protected]>

Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent f9ff354 commit 6104a63

16 files changed: 148 additions, 51 deletions

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -207,6 +207,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Deprecated
 
+- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))
+
 
 - Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
 
@@ -217,7 +219,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the `save_function` property from the `ModelCheckpoint` callback ([#7201](https://github.com/PyTorchLightning/pytorch-lightning/pull/7201))
 
 
-- Deprecated `LightningModule.write_predictions` and `LigtningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066))
+- Deprecated `LightningModule.write_predictions` and `LightningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066))
 
 
 - Deprecated `TrainerLoggingMixin` in favor of a separate utilities module for metric handling ([#7180](https://github.com/PyTorchLightning/pytorch-lightning/pull/7180))
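
For downstream users, this deprecation boils down to dropping `outputs` from `on_train_epoch_end` overrides. A minimal sketch of the migration follows; the `LitModel` class and its bodies are illustrative only and not part of this commit. Logic that still needs the epoch's outputs can move to `training_epoch_end`, which continues to receive them:

    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):
        # Before v1.3 (now deprecated): the hook received the collected step outputs.
        # def on_train_epoch_end(self, outputs):
        #     ...

        # After: override the hook without `outputs` ...
        def on_train_epoch_end(self):
            print("train epoch finished")

        # ... and keep output-dependent logic in `training_epoch_end`, which still receives them.
        def training_epoch_end(self, outputs):
            pass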

pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 7 deletions
@@ -28,7 +28,7 @@
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
-from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _NATIVE_AMP_AVAILABLE:
     from torch.cuda.amp import GradScaler
@@ -354,12 +354,8 @@ def clip_gradients(
             model=self.model,
         )
 
-    def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
-        """Hook to do something on the end of an training epoch
-
-        Args:
-            outputs: the outputs of the training steps
-        """
+    def on_train_epoch_end(self) -> None:
+        """Hook to do something on the end of an training epoch."""
         pass
 
     def on_train_end(self) -> None:

pytorch_lightning/callbacks/base.py

Lines changed: 3 additions & 1 deletion
@@ -98,7 +98,9 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo
         """Called when the train epoch begins."""
         pass
 
-    def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None:
+    def on_train_epoch_end(
+        self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', unused: Optional = None
+    ) -> None:
         """Called when the train epoch ends."""
         pass
 
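
The `unused: Optional = None` placeholder keeps the base class call-compatible while third-party callbacks migrate. A rough sketch of a callback written against the new signature (`MyCallback` and its body are illustrative, not part of this commit):

    from pytorch_lightning.callbacks import Callback

    class MyCallback(Callback):
        # Deprecated signature: on_train_epoch_end(self, trainer, pl_module, outputs)
        def on_train_epoch_end(self, trainer, pl_module):
            print(f"finished epoch {trainer.current_epoch}")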

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def _should_skip_check(self, trainer) -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
         return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking
 
-    def on_train_epoch_end(self, trainer, pl_module, outputs) -> None:
+    def on_train_epoch_end(self, trainer, pl_module) -> None:
         if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
             return
         self._run_early_stopping_check(trainer)

pytorch_lightning/callbacks/pruning.py

Lines changed: 1 addition & 1 deletion
@@ -373,7 +373,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul
             self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
             self._original_layers[id_]["names"].append((i, name))
 
-    def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
+    def on_train_epoch_end(self, trainer, pl_module: LightningModule):
         current_epoch = trainer.current_epoch
         prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
         amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount

pytorch_lightning/core/hooks.py

Lines changed: 1 addition & 1 deletion
@@ -235,7 +235,7 @@ def on_train_epoch_start(self) -> None:
         Called in the training loop at the very beginning of the epoch.
         """
 
-    def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
+    def on_train_epoch_end(self, unused: Optional = None) -> None:
         """
         Called in the training loop at the very end of the epoch.
         """

pytorch_lightning/trainer/callback_hook.py

Lines changed: 9 additions & 1 deletion
@@ -96,7 +96,15 @@ def on_train_epoch_end(self, outputs: EPOCH_OUTPUT):
             outputs: List of outputs on each ``train`` epoch
         """
         for callback in self.callbacks:
-            callback.on_train_epoch_end(self, self.lightning_module, outputs)
+            if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"):
+                warning_cache.warn(
+                    "The signature of `Callback.on_train_epoch_end` has changed in v1.3."
+                    " `outputs` parameter has been removed."
+                    " Support for the old signature will be removed in v1.5", DeprecationWarning
+                )
+                callback.on_train_epoch_end(self, self.lightning_module, outputs)
+            else:
+                callback.on_train_epoch_end(self, self.lightning_module)
 
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""

pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 0 deletions
@@ -1211,6 +1211,11 @@ def _cache_logged_metrics(self):
         self.logger_connector.cache_logged_metrics()
 
     def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
+        # Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook
+        # This was done to manage the deprecation of an argument to on_train_epoch_end
+        # If making chnages to this function, ensure that those changes are also made to
+        # TrainLoop._on_train_epoch_end_hook
+
         # set hook_name to model + reset Result obj
         skip = self._reset_result_and_set_hook_fx_name(hook_name)

pytorch_lightning/trainer/training_loop.py

Lines changed: 62 additions & 5 deletions
@@ -31,6 +31,7 @@
 from pytorch_lightning.utilities.grads import grad_norm
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.parsing import AttributeDict
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.warnings import WarningCache
 
 
@@ -197,16 +198,14 @@ def reset_train_val_dataloaders(self, model) -> None:
 
     def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
 
+        hook_overridden = self._should_add_batch_output_to_epoch_output()
+
         # track the outputs to reduce at the end of the epoch
         for opt_idx, opt_outputs in enumerate(batch_end_outputs):
             sample_output = opt_outputs[-1]
 
             # decide if we need to reduce at the end of the epoch automatically
             auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
-            hook_overridden = (
-                is_overridden("training_epoch_end", model=self.trainer.lightning_module)
-                or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module)
-            )
 
             # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
             if not (hook_overridden or auto_reduce_tng_result):
@@ -218,6 +217,22 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
 
             epoch_output[opt_idx].append(opt_outputs)
 
+    def _should_add_batch_output_to_epoch_output(self) -> bool:
+        # We add to the epoch outputs if
+        # 1. The model defines training_epoch_end OR
+        # 2. The model overrides on_train_epoch_end which has `outputs` in the signature
+        # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
+        lightning_module = self.trainer.lightning_module
+        if is_overridden("training_epoch_end", model=lightning_module):
+            return True
+
+        if is_overridden("on_train_epoch_end", model=lightning_module):
+            model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
+            if is_param_in_hook_signature(model_hook_fx, "outputs"):
+                return True
+
+        return False
+
     def get_optimizers_iterable(self, batch_idx=None):
         """
         Generates an iterable with (idx, optimizer) for each optimizer.
@@ -593,9 +608,51 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
         self.trainer.logger_connector.cache_logged_metrics()
 
         # call train epoch end hooks
-        self.trainer.call_hook('on_train_epoch_end', processed_epoch_output)
+        self._on_train_epoch_end_hook(processed_epoch_output)
         self.trainer.call_hook('on_epoch_end')
 
+    def _on_train_epoch_end_hook(self, processed_epoch_output) -> None:
+        # We cannot rely on Trainer.call_hook because the signatures might be different across
+        # lightning module and callback
+        # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`
+
+        # This implementation is copied from Trainer.call_hook
+        hook_name = "on_train_epoch_end"
+
+        # set hook_name to model + reset Result obj
+        skip = self.trainer._reset_result_and_set_hook_fx_name(hook_name)
+
+        # always profile hooks
+        with self.trainer.profiler.profile(hook_name):
+
+            # first call trainer hook
+            if hasattr(self.trainer, hook_name):
+                trainer_hook = getattr(self.trainer, hook_name)
+                trainer_hook(processed_epoch_output)
+
+            # next call hook in lightningModule
+            model_ref = self.trainer.lightning_module
+            if is_overridden(hook_name, model_ref):
+                hook_fx = getattr(model_ref, hook_name)
+                if is_param_in_hook_signature(hook_fx, "outputs"):
+                    self.warning_cache.warn(
+                        "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
+                        " `outputs` parameter has been deprecated."
+                        " Support for the old signature will be removed in v1.5", DeprecationWarning
+                    )
+                    model_ref.on_train_epoch_end(processed_epoch_output)
+                else:
+                    model_ref.on_train_epoch_end()
+
+            # if the PL module doesn't have the hook then call the accelerator
+            # used to auto-reduce things for the user with Results obj
+            elif hasattr(self.trainer.accelerator, hook_name):
+                accelerator_hook = getattr(self.trainer.accelerator, hook_name)
+                accelerator_hook()
+
+        if not skip:
+            self.trainer._cache_logged_metrics()
+
     def run_training_batch(self, batch, batch_idx, dataloader_idx):
         # track grad norms
         grad_norm_dic = {}
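
During the deprecation window, a LightningModule that still declares `outputs` keeps receiving the processed epoch output but triggers the warning above. A hedged sketch of that behaviour; `OldStyleModel` and the helper import path are assumptions, not taken from this commit:

    import pytorch_lightning as pl
    from tests.helpers.boring_model import BoringModel  # assumed helper path

    class OldStyleModel(BoringModel):
        def on_train_epoch_end(self, outputs):  # old signature: warned, but still called with outputs
            assert isinstance(outputs, list)

    trainer = pl.Trainer(fast_dev_run=True)
    # Expected to emit: "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3. ..."
    trainer.fit(OldStyleModel())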

tests/callbacks/test_callback_hook_outputs.py

Lines changed: 1 addition & 4 deletions
@@ -34,9 +34,6 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
         def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
             assert 'x' in outputs
 
-        def on_train_epoch_end(self, trainer, pl_module, outputs):
-            assert len(outputs) == trainer.num_training_batches
-
     class TestModel(BoringModel):
 
         def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
@@ -48,7 +45,7 @@ def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx
         def on_test_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
             assert 'x' in outputs
 
-        def on_train_epoch_end(self, outputs) -> None:
+        def training_epoch_end(self, outputs) -> None:
             assert len(outputs) == self.trainer.num_training_batches
 
     model = TestModel()
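
A regression test along these lines (a sketch, not the exact test added by this commit; the helper import path is an assumption) can assert that old-signature callbacks still work but raise the deprecation warning:

    import pytest
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback
    from tests.helpers.boring_model import BoringModel  # assumed helper path

    class OldSignatureCallback(Callback):
        def on_train_epoch_end(self, trainer, pl_module, outputs):
            assert len(outputs) == trainer.num_training_batches

    def test_on_train_epoch_end_old_signature_warns(tmpdir):
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[OldSignatureCallback()])
        with pytest.deprecated_call(match="old signature will be removed in v1.5"):
            trainer.fit(BoringModel())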
