
Commit 47c47fa

Remove outputs in on_train_epoch_end hooks (#8587)
1 parent 470842f commit 47c47fa

File tree

11 files changed: +16 -142 lines

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 -
 
 
--
+- Removed the `outputs` argument in both the `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#8587](https://github.com/PyTorchLightning/pytorch-lightning/pull/8587))
 
 
 -
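For user code, the migration this entry describes is mechanical: drop the extra parameter from any overridden hook. A before/after sketch (the callback class names are illustrative):

    import pytorch_lightning as pl

    # Before (deprecated in v1.3, removed by this commit):
    class MyCallbackOld(pl.Callback):
        def on_train_epoch_end(self, trainer, pl_module, outputs):
            ...

    # After:
    class MyCallback(pl.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            ...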

docs/source/starter/new-project.rst

Lines changed: 1 addition & 1 deletion
@@ -602,7 +602,7 @@ Here's an example adding a not-so-fancy learning rate decay rule:
             group = [param_group['lr'] for param_group in optimizer.param_groups]
             self.old_lrs.append(group)
 
-    def on_train_epoch_end(self, trainer, pl_module, outputs):
+    def on_train_epoch_end(self, trainer, pl_module):
         for opt_idx, optimizer in enumerate(trainer.optimizers):
             old_lr_group = self.old_lrs[opt_idx]
             new_lr_group = []
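For context, this hunk touches the docs' learning-rate decay callback. A self-contained sketch of that example under the new signature (reconstructed around the visible fragment; the 2% decay factor is illustrative):

    from pytorch_lightning.callbacks import Callback

    class DecayLearningRate(Callback):
        def __init__(self):
            self.old_lrs = []

        def on_train_start(self, trainer, pl_module):
            # remember the initial learning rate of every param group
            for opt_idx, optimizer in enumerate(trainer.optimizers):
                group = [param_group['lr'] for param_group in optimizer.param_groups]
                self.old_lrs.append(group)

        def on_train_epoch_end(self, trainer, pl_module):
            # decay each group's learning rate by 2% per epoch
            for opt_idx, optimizer in enumerate(trainer.optimizers):
                old_lr_group = self.old_lrs[opt_idx]
                new_lr_group = []
                for p_idx, param_group in enumerate(optimizer.param_groups):
                    new_lr = old_lr_group[p_idx] * 0.98
                    new_lr_group.append(new_lr)
                    param_group['lr'] = new_lr
                self.old_lrs[opt_idx] = new_lr_group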

pytorch_lightning/callbacks/base.py

Lines changed: 1 addition & 3 deletions
@@ -94,9 +94,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when the train epoch begins."""
         pass
 
-    def on_train_epoch_end(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
-    ) -> None:
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when the train epoch ends.
 
         To access all batch outputs at the end of the epoch, either:
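The remainder of that docstring points users at alternatives; one is to cache batch results inside the callback itself. A minimal sketch (hook signatures per the v1.4 `Callback` API; the class is hypothetical):

    import pytorch_lightning as pl

    class CacheOutputs(pl.Callback):
        def __init__(self):
            self.batch_outputs = []

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            # collect what each training batch returned
            self.batch_outputs.append(outputs)

        def on_train_epoch_end(self, trainer, pl_module):
            # post-process the cached outputs, then reset for the next epoch
            print(f"epoch ended with {len(self.batch_outputs)} batch outputs")
            self.batch_outputs = []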

pytorch_lightning/callbacks/pruning.py

Lines changed: 1 addition & 1 deletion
@@ -398,7 +398,7 @@ def _run_pruning(self, current_epoch: int) -> None:
         ):
             self.apply_lottery_ticket_hypothesis()
 
-    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None:  # type: ignore
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None:
         if self._prune_on_train_epoch_end:
             rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning")
             self._run_pruning(pl_module.current_epoch)
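Since `ModelPruning` applies pruning from this hook, enabling it stays a one-liner; a usage sketch (the pruning function and amount are illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelPruning

    # prune 50% of weights with L1 unstructured pruning after each train epoch
    trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=0.5)])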

pytorch_lightning/core/hooks.py

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ def on_train_epoch_start(self) -> None:
         Called in the training loop at the very beginning of the epoch.
         """
 
-    def on_train_epoch_end(self, unused: Optional = None) -> None:
+    def on_train_epoch_end(self) -> None:
         """
         Called in the training loop at the very end of the epoch.
 
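On the `LightningModule` side, per-batch outputs remain reachable through `training_epoch_end`, which still receives them; a sketch of the split between the two hooks (`compute_loss` stands in for user code):

    import torch
    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):
        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # assumed user-defined
            return {"loss": loss}

        def training_epoch_end(self, outputs):
            # `outputs` still collects every `training_step` return value
            avg = torch.stack([o["loss"] for o in outputs]).mean()
            self.log("train_loss_epoch", avg)

        def on_train_epoch_end(self):
            # now purely a timing hook, with no arguments
            ...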

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 2 additions & 62 deletions
@@ -22,7 +22,6 @@
 from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -227,7 +226,7 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         self.trainer.fit_loop.epoch_progress.increment_processed()
 
         # call train epoch end hooks
-        self._on_train_epoch_end_hook(processed_outputs)
+        self.trainer.call_hook("on_train_epoch_end")
         self.trainer.call_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
 
@@ -250,47 +249,6 @@ def _run_validation(self):
         with torch.no_grad():
             self.val_loop.run()
 
-    def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None:
-        """Runs ``on_train_epoch_end hook``."""
-        # We cannot rely on Trainer.call_hook because the signatures might be different across
-        # lightning module and callback
-        # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`
-
-        # This implementation is copied from Trainer.call_hook
-        hook_name = "on_train_epoch_end"
-        prev_fx_name = self.trainer.lightning_module._current_fx_name
-        self.trainer.lightning_module._current_fx_name = hook_name
-
-        # always profile hooks
-        with self.trainer.profiler.profile(hook_name):
-
-            # first call trainer hook
-            if hasattr(self.trainer, hook_name):
-                trainer_hook = getattr(self.trainer, hook_name)
-                trainer_hook(processed_epoch_output)
-
-            # next call hook in lightningModule
-            model_ref = self.trainer.lightning_module
-            if is_overridden(hook_name, model_ref):
-                hook_fx = getattr(model_ref, hook_name)
-                if is_param_in_hook_signature(hook_fx, "outputs"):
-                    self._warning_cache.deprecation(
-                        "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
-                        " `outputs` parameter has been deprecated."
-                        " Support for the old signature will be removed in v1.5"
-                    )
-                    model_ref.on_train_epoch_end(processed_epoch_output)
-                else:
-                    model_ref.on_train_epoch_end()
-
-            # call the accelerator hook
-            if hasattr(self.trainer.accelerator, hook_name):
-                accelerator_hook = getattr(self.trainer.accelerator, hook_name)
-                accelerator_hook()
-
-            # restore current_fx when nested context
-            self.trainer.lightning_module._current_fx_name = prev_fx_name
-
     def _accumulated_batches_reached(self) -> bool:
         """Determine if accumulation will be finished by the end of the current batch."""
         return self.batch_progress.current.ready % self.trainer.accumulate_grad_batches == 0
@@ -313,7 +271,7 @@ def _track_epoch_end_reduce_metrics(
         self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT
     ) -> None:
         """Adds the batch outputs to the epoch outputs and prepares reduction"""
-        hook_overridden = self._should_add_batch_output_to_epoch_output()
+        hook_overridden = is_overridden("training_epoch_end", self.trainer.lightning_module)
         if not hook_overridden:
             return
 
@@ -329,24 +287,6 @@ def _track_epoch_end_reduce_metrics(
 
             epoch_output[opt_idx].append(opt_outputs)
 
-    def _should_add_batch_output_to_epoch_output(self) -> bool:
-        """
-        We add to the epoch outputs if
-        1. The model defines training_epoch_end OR
-        2. The model overrides on_train_epoch_end which has `outputs` in the signature
-        """
-        # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
-        lightning_module = self.trainer.lightning_module
-        if is_overridden("training_epoch_end", lightning_module):
-            return True
-
-        if is_overridden("on_train_epoch_end", lightning_module):
-            model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
-            if is_param_in_hook_signature(model_hook_fx, "outputs"):
-                return True
-
-        return False
-
     @staticmethod
     def _prepare_outputs(
         outputs: List[List[List["ResultCollection"]]], batch_mode: bool
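The net effect of these hunks: whether per-batch outputs are retained for the epoch now hinges on a single check instead of signature inspection. Conceptually (a simplified restatement, not the literal loop code):

    from pytorch_lightning.utilities.model_helpers import is_overridden

    def should_track_epoch_outputs(lightning_module) -> bool:
        # outputs are kept only when the user implements `training_epoch_end`;
        # the signature of `on_train_epoch_end` no longer plays a role
        return is_overridden("training_epoch_end", lightning_module)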

pytorch_lightning/trainer/callback_hook.py

Lines changed: 4 additions & 20 deletions
@@ -22,11 +22,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
-from pytorch_lightning.utilities.warnings import WarningCache
-
-warning_cache = WarningCache()
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 
 class TrainerCallbackHookMixin(ABC):
@@ -91,22 +87,10 @@ def on_train_epoch_start(self):
         for callback in self.callbacks:
             callback.on_train_epoch_start(self, self.lightning_module)
 
-    def on_train_epoch_end(self, outputs: EPOCH_OUTPUT):
-        """Called when the epoch ends.
-
-        Args:
-            outputs: List of outputs on each ``train`` epoch
-        """
+    def on_train_epoch_end(self):
+        """Called when the epoch ends."""
         for callback in self.callbacks:
-            if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"):
-                warning_cache.deprecation(
-                    "The signature of `Callback.on_train_epoch_end` has changed in v1.3."
-                    " `outputs` parameter has been removed."
-                    " Support for the old signature will be removed in v1.5"
-                )
-                callback.on_train_epoch_end(self, self.lightning_module, outputs)
-            else:
-                callback.on_train_epoch_end(self, self.lightning_module)
+            callback.on_train_epoch_end(self, self.lightning_module)
 
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""

pytorch_lightning/trainer/trainer.py

Lines changed: 0 additions & 4 deletions
@@ -1209,10 +1209,6 @@ def _call_teardown_hook(self, model: "pl.LightningModule") -> None:
         model._metric_attributes = None
 
     def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
-        # Note this implementation is copy/pasted into the TrainLoop class in TrainingEpochLoop._on_train_epoch_end_hook
-        # This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end
-        # If making changes to this function, ensure that those changes are also made to
-        # TrainingEpochLoop._on_train_epoch_end_hook
         if self.lightning_module:
             prev_fx_name = self.lightning_module._current_fx_name
             self.lightning_module._current_fx_name = hook_name
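The deleted comment referenced the duplicate removed from `TrainingEpochLoop`; for orientation, the dispatch order that copy implemented (and that `call_hook` keeps) was trainer hook, then module hook, then accelerator hook, all under the profiler. A simplified sketch (the function name is illustrative):

    from pytorch_lightning.utilities.model_helpers import is_overridden

    def call_hook_sketch(trainer, hook_name, *args, **kwargs):
        with trainer.profiler.profile(hook_name):
            if hasattr(trainer, hook_name):
                getattr(trainer, hook_name)(*args, **kwargs)   # 1. trainer mixin, fans out to callbacks
            model = trainer.lightning_module
            if is_overridden(hook_name, model):
                getattr(model, hook_name)(*args, **kwargs)     # 2. LightningModule hook
            if hasattr(trainer.accelerator, hook_name):
                getattr(trainer.accelerator, hook_name)()      # 3. accelerator hook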

tests/deprecated_api/test_remove_1-5.py

Lines changed: 0 additions & 44 deletions
@@ -27,7 +27,6 @@
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.plugins import DeepSpeedPlugin
 from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler
-from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache
 from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.imports import _compare_version
 from tests.deprecated_api import no_deprecated_call
@@ -194,49 +193,6 @@ def test_v1_5_0_model_checkpoint_period(tmpdir):
     ModelCheckpoint(dirpath=tmpdir, period=1)
 
 
-def test_v1_5_0_old_on_train_epoch_end(tmpdir):
-    callback_warning_cache.clear()
-
-    class OldSignature(Callback):
-        def on_train_epoch_end(self, trainer, pl_module, outputs):  # noqa
-            ...
-
-    class OldSignatureModel(BoringModel):
-        def on_train_epoch_end(self, outputs):  # noqa
-            ...
-
-    model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
-
-    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
-        trainer.fit(model)
-
-    callback_warning_cache.clear()
-
-    model = OldSignatureModel()
-
-    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
-        trainer.fit(model)
-
-    trainer.fit_loop.epoch_loop._warning_cache.clear()
-
-    class NewSignature(Callback):
-        def on_train_epoch_end(self, trainer, pl_module):
-            ...
-
-    trainer.callbacks = [NewSignature()]
-    with no_deprecated_call(match="`Callback.on_train_epoch_end` signature has changed in v1.3."):
-        trainer.fit(model)
-
-    class NewSignatureModel(BoringModel):
-        def on_train_epoch_end(self):
-            ...
-
-    model = NewSignatureModel()
-    with no_deprecated_call(match="`ModelHooks.on_train_epoch_end` signature has changed in v1.3."):
-        trainer.fit(model)
-
-
 @pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
 def test_v1_5_0_profiler_output_filename(tmpdir, cls):
     filepath = str(tmpdir / "test.txt")
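What the deleted test leaves behind is the expectation that only the new signatures run; a condensed sketch of that remaining contract (reusing the test's own class names; the import path follows the PL test suite layout):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback
    from tests.helpers import BoringModel  # path as used in the PL test suite

    class NewSignature(Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            ...

    class NewSignatureModel(BoringModel):
        def on_train_epoch_end(self):
            ...

    def test_new_on_train_epoch_end(tmpdir):
        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=NewSignature())
        trainer.fit(NewSignatureModel())  # runs without deprecation warnings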

tests/models/test_hooks.py

Lines changed: 4 additions & 4 deletions
@@ -534,11 +534,11 @@ def training_step(self, batch, batch_idx):
             dict(name="train", args=(True,)),
             dict(name="on_validation_model_train"),
             dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
-            dict(name="Callback.on_train_epoch_end", args=(trainer, model, [dict(loss=ANY)] * train_batches)),
+            dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
             # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_train_epoch_end`
             dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
             dict(name="on_save_checkpoint", args=(saved_ckpt,)),
-            dict(name="on_train_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
+            dict(name="on_train_epoch_end"),
             dict(name="Callback.on_epoch_end", args=(trainer, model)),
             dict(name="on_epoch_end"),
             dict(name="Callback.on_train_end", args=(trainer, model)),
@@ -635,10 +635,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         # TODO: wrong current epoch after reload
         *model._train_batch(trainer, model, train_batches, current_epoch=1),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
-        dict(name="Callback.on_train_epoch_end", args=(trainer, model, [dict(loss=ANY)] * train_batches)),
+        dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
         dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
         dict(name="on_save_checkpoint", args=(saved_ckpt,)),
-        dict(name="on_train_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
+        dict(name="on_train_epoch_end"),
         dict(name="Callback.on_epoch_end", args=(trainer, model)),
         dict(name="on_epoch_end"),
         dict(name="Callback.on_train_end", args=(trainer, model)),
