
Commit ba722e2

swethmandava authored and Borda committed
passing batch outputs to on_train_batch_end (#4369)
* passing batch outputs to on_train_batch_end
* styling
* updating epoch end logic
* also condition on on_train_epoch_end hooks
* more readable
* pep8
* pep8
* readability suggestion accepted
  Co-authored-by: Jirka Borovec <[email protected]>
* adding test_training_epoch_end_metrics_collection_on_override test
* fix formatting
* fix formatting

Co-authored-by: Swetha Mandava <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>

(cherry picked from commit 5fcca4e)
1 parent cc7410b commit ba722e2
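For context, the user-facing effect of this change: the per-batch outputs returned by training_step are now forwarded to the on_train_batch_end hook even when training_epoch_end is not overridden. Below is a minimal, hypothetical sketch (not part of the commit; the model, dataset, and names are illustrative, and the hook signature follows the test added in this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class BatchOutputInspector(pl.LightningModule):
    """Illustrative model that records what reaches on_train_batch_end."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.seen_batch_outputs = []

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {"loss": torch.nn.functional.mse_loss(self.layer(x), y)}

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # with this commit, `outputs` carries the batch's training_step results
        # (a list per optimizer in this version) even though training_epoch_end
        # is not overridden here
        self.seen_batch_outputs.append(outputs)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# e.g. pl.Trainer(max_epochs=1).fit(BatchOutputInspector())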

File tree

2 files changed: +79 −19 lines

  pytorch_lightning/trainer/training_loop.py
  tests/models/test_hooks.py

pytorch_lightning/trainer/training_loop.py

Lines changed: 25 additions & 18 deletions
@@ -241,13 +241,13 @@ def on_train_epoch_start(self, epoch):
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")
 
-    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
         # hook
-        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)
         self.trainer.call_hook('on_batch_end')
 
         # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
+        self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)
 
         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
@@ -259,12 +259,27 @@ def reset_train_val_dataloaders(self, model):
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_val_dataloader(model)
 
-    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
+    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
+
         # track the outputs to reduce at the end of the epoch
-        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
+        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
+            sample_output = opt_outputs[-1]
+
+            # decide if we need to reduce at the end of the epoch automatically
+            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
+            hook_overridden = (
+                is_overridden("training_epoch_end", model=self.trainer.get_model()) or
+                is_overridden("on_train_epoch_end", model=self.trainer.get_model())
+            )
+
+            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
+            if not(hook_overridden or auto_reduce_tng_result):
+                continue
+
             # with 1 step (no tbptt) don't use a sequence at epoch end
             if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
                 opt_outputs = opt_outputs[0]
+
             epoch_output[opt_idx].append(opt_outputs)
 
     def get_optimizers_iterable(self):
@@ -548,17 +563,14 @@ def run_training_epoch(self):
             if batch_output.signal == -1:
                 break
 
-            # only track outputs when user implements training_epoch_end
-            # otherwise we will build up unnecessary memory
-            epoch_end_outputs = self.process_train_step_outputs(
+            batch_end_outputs = self.process_train_step_outputs(
                 batch_output.training_step_output_for_epoch_end,
                 self.early_stopping_accumulator,
                 self.checkpoint_accumulator,
            )
-
            # hook
            # TODO: add outputs to batches
-            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
+            self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)
 
             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
@@ -896,7 +908,7 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
         # the training step outputs a list per optimizer. The list contains the outputs at each time step
         # when no TBPTT is used, then the list has 1 item per batch
         # when TBPTT IS used, then the list has n items (1 per time step)
-        epoch_end_outputs = []
+        batch_end_outputs = []
         for optimizer_idx_outputs in all_train_step_outputs:
             # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
             if len(optimizer_idx_outputs) == 0:
@@ -911,14 +923,9 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
             if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
                 checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])
 
-            # decide if we need to reduce at the end of the epoch automatically
-            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
-
-            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
-            if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result:
-                epoch_end_outputs.append(optimizer_idx_outputs)
+            batch_end_outputs.append(optimizer_idx_outputs)
 
-        return epoch_end_outputs
+        return batch_end_outputs
 
     def prepare_optimizers(self):
         # in manual optimization we loop over all optimizers at once
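Net effect of the training_loop.py changes above: every batch's outputs now reach on_train_batch_end, and the decision about whether to also keep them for epoch-end reduction moves into track_epoch_end_reduce_metrics. A self-contained sketch of that gate (should_track_for_epoch_end is a hypothetical helper, not part of the diff; plain booleans stand in for the is_overridden checks and duck-typing for Result):

def should_track_for_epoch_end(sample_output, training_epoch_end_overridden, on_train_epoch_end_overridden):
    """Return True when a batch's outputs should be kept for epoch-end reduction."""
    # auto-reduce when the step output asks for it (stand-in for Result.should_reduce_on_epoch_end)
    auto_reduce = bool(getattr(sample_output, "should_reduce_on_epoch_end", False))
    # manual reduction when the user overrode either epoch-end hook
    hook_overridden = training_epoch_end_overridden or on_train_epoch_end_overridden
    return hook_overridden or auto_reduce


# a plain dict output with no epoch-end hooks overridden is no longer tracked for the epoch ...
assert should_track_for_epoch_end({"loss": 0.1}, False, False) is False
# ... although it still reaches on_train_batch_end; overriding training_epoch_end re-enables tracking
assert should_track_for_epoch_end({"loss": 0.1}, True, False) is True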

tests/models/test_hooks.py

Lines changed: 54 additions & 1 deletion
@@ -18,7 +18,8 @@
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
+
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator
 from pytorch_lightning.trainer.states import TrainerState
 from tests.base import BoringModel, EvalModelTemplate, RandomDataset
@@ -91,6 +92,58 @@ def training_epoch_end(self, outputs):
         assert metrics[f'epoch_metric_{i}'] == i
 
 
+def test_training_epoch_end_metrics_collection_on_override(tmpdir):
+    """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """
+    num_epochs = 1
+
+    class LoggingCallback(Callback):
+
+        def on_train_epoch_end(self, trainer, pl_module):
+            self.len_outputs = 0
+
+        def on_train_epoch_end(self, trainer, pl_module, outputs):
+            self.len_outputs = len(outputs[0])
+
+    class OverriddenModel(EvalModelTemplate):
+
+        def on_train_epoch_start(self):
+            self.num_train_batches = 0
+
+        def training_epoch_end(self, outputs):  # Overridden
+            pass
+            return
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            self.num_train_batches += 1
+
+    class NotOverriddenModel(EvalModelTemplate):
+
+        def on_train_epoch_start(self):
+            self.num_train_batches = 0
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            self.num_train_batches += 1
+
+    overridden_model = OverriddenModel()
+    not_overridden_model = NotOverriddenModel()
+
+    callback = LoggingCallback()
+    trainer = Trainer(
+        max_epochs=num_epochs,
+        default_root_dir=tmpdir,
+        overfit_batches=2,
+        callbacks=[callback],
+    )
+
+    result = trainer.fit(overridden_model)
+    assert callback.len_outputs == overridden_model.num_train_batches
+    # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden
+
+    result = trainer.fit(not_overridden_model)
+    assert callback.len_outputs == 0
+    # outputs from on_train_batch_end should be empty
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_transfer_batch_hook():