Skip to content

Commit 7b45bcf

Browse files
[2/2] Remove outputs from evaluation epoch end hooks (#7338)
* Remove outputs from on_train_epoch_end
* iterate
* Update callback_hook.py
* update
* early stop?
* fix
* Update pytorch_lightning/trainer/training_loop.py
  Co-authored-by: Ethan Harris <[email protected]>
* Update trainer.py
* update
* Update training_loop.py
* early stop?
* fix
* Remove outputs from evaluation epoch end hooks
* update
* Update test_remove_1-5.py
* fix lints
* Update base.py
* rm-outputs
* Update evaluation_loop.py
* try-save-more-memory
* Update trainer.py
* Update trainer.py
* cache-at-start
* Update evaluation_loop.py
* Update training_loop.py
* Update training_loop.py
  Co-authored-by: Ethan Harris <[email protected]>
1 parent fbcd63a commit 7b45bcf

File tree

11 files changed

+39
-252
lines changed

11 files changed

+39
-252
lines changed

CHANGELOG.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8888
- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618))
8989

9090

91-
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
92-
93-
9491
- Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679))
9592

9693

@@ -213,6 +210,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
213210
- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
214211

215212

213+
- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))
214+
215+
216216
- Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292))
217217

218218

pytorch_lightning/callbacks/base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torch.optim import Optimizer
2323

2424
import pytorch_lightning as pl
25-
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
25+
from pytorch_lightning.utilities.types import STEP_OUTPUT
2626

2727

2828
class Callback(abc.ABC):
@@ -108,17 +108,15 @@ def on_validation_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn
108108
"""Called when the val epoch begins."""
109109
pass
110110

111-
def on_validation_epoch_end(
112-
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT
113-
) -> None:
111+
def on_validation_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
114112
"""Called when the val epoch ends."""
115113
pass
116114

117115
def on_test_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
118116
"""Called when the test epoch begins."""
119117
pass
120118

121-
def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None:
119+
def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
122120
"""Called when the test epoch ends."""
123121
pass
124122

pytorch_lightning/core/hooks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from torch.utils.data import DataLoader
2121

2222
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
23-
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
23+
from pytorch_lightning.utilities.types import STEP_OUTPUT
2424

2525

2626
class ModelHooks:
@@ -245,7 +245,7 @@ def on_validation_epoch_start(self) -> None:
245245
Called in the validation loop at the very beginning of the epoch.
246246
"""
247247

248-
def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
248+
def on_validation_epoch_end(self) -> None:
249249
"""
250250
Called in the validation loop at the very end of the epoch.
251251
"""
@@ -255,7 +255,7 @@ def on_test_epoch_start(self) -> None:
255255
Called in the test loop at the very beginning of the epoch.
256256
"""
257257

258-
def on_test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
258+
def on_test_epoch_end(self) -> None:
259259
"""
260260
Called in the test loop at the very end of the epoch.
261261
"""

pytorch_lightning/trainer/callback_hook.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -111,44 +111,20 @@ def on_validation_epoch_start(self):
111111
for callback in self.callbacks:
112112
callback.on_validation_epoch_start(self, self.lightning_module)
113113

114-
def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT):
115-
"""Called when the epoch ends.
116-
117-
Args:
118-
outputs: List of outputs on each ``validation`` epoch
119-
"""
114+
def on_validation_epoch_end(self):
115+
"""Called when the validation epoch ends."""
120116
for callback in self.callbacks:
121-
if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"):
122-
callback.on_validation_epoch_end(self, self.lightning_module, outputs)
123-
else:
124-
warning_cache.warn(
125-
"`Callback.on_validation_epoch_end` signature has changed in v1.3."
126-
" `outputs` parameter has been added."
127-
" Support for the old signature will be removed in v1.5", DeprecationWarning
128-
)
129-
callback.on_validation_epoch_end(self, self.lightning_module)
117+
callback.on_validation_epoch_end(self, self.lightning_module)
130118

131119
def on_test_epoch_start(self):
132120
"""Called when the epoch begins."""
133121
for callback in self.callbacks:
134122
callback.on_test_epoch_start(self, self.lightning_module)
135123

136-
def on_test_epoch_end(self, outputs: EPOCH_OUTPUT):
137-
"""Called when the epoch ends.
138-
139-
Args:
140-
outputs: List of outputs on each ``test`` epoch
141-
"""
124+
def on_test_epoch_end(self):
125+
"""Called when the test epoch ends."""
142126
for callback in self.callbacks:
143-
if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"):
144-
callback.on_test_epoch_end(self, self.lightning_module, outputs)
145-
else:
146-
warning_cache.warn(
147-
"`Callback.on_test_epoch_end` signature has changed in v1.3."
148-
" `outputs` parameter has been added."
149-
" Support for the old signature will be removed in v1.5", DeprecationWarning
150-
)
151-
callback.on_test_epoch_end(self, self.lightning_module)
127+
callback.on_test_epoch_end(self, self.lightning_module)
152128

153129
def on_predict_epoch_start(self) -> None:
154130
"""Called when the epoch begins."""

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Dict, List, Optional, Tuple, Union
14+
from typing import Any, List, Optional, Tuple, Union
1515

1616
from torch.utils.data import DataLoader
1717

@@ -20,7 +20,6 @@
2020
from pytorch_lightning.trainer.states import TrainerFn
2121
from pytorch_lightning.trainer.supporters import PredictionCollection
2222
from pytorch_lightning.utilities.model_helpers import is_overridden
23-
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
2423
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
2524
from pytorch_lightning.utilities.warnings import WarningCache
2625

@@ -76,6 +75,7 @@ def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool:
7675
return sum(max_batches) == 0
7776

7877
def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
78+
self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end()
7979
if self.trainer.testing:
8080
self.trainer.call_hook('on_test_start', *args, **kwargs)
8181
else:
@@ -188,6 +188,13 @@ def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT
188188
output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
189189
return output
190190

191+
def _should_track_batch_outputs_for_epoch_end(self) -> bool:
192+
model = self.trainer.lightning_module
193+
if self.trainer.testing:
194+
return is_overridden('test_epoch_end', model=model)
195+
else:
196+
return is_overridden('validation_epoch_end', model=model)
197+
191198
def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
192199
# unset dataloader_idx in model
193200
self.trainer.logger_connector.evaluation_epoch_end()
@@ -241,7 +248,7 @@ def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, datal
241248
# track debug metrics
242249
self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)
243250

244-
def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) -> None:
251+
def on_evaluation_epoch_end(self) -> None:
245252
model_ref = self.trainer.lightning_module
246253
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
247254

@@ -251,18 +258,11 @@ def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]])
251258

252259
if hasattr(self.trainer, hook_name):
253260
on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
254-
on_evaluation_epoch_end_hook(outputs)
261+
on_evaluation_epoch_end_hook()
255262

256263
if is_overridden(hook_name, model_ref):
257264
model_hook_fx = getattr(model_ref, hook_name)
258-
if is_param_in_hook_signature(model_hook_fx, "outputs"):
259-
model_hook_fx(outputs)
260-
else:
261-
self.warning_cache.warn(
262-
f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added."
263-
" Support for the old signature will be removed in v1.5", DeprecationWarning
264-
)
265-
model_hook_fx()
265+
model_hook_fx()
266266

267267
self.trainer._cache_logged_metrics()
268268

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -972,22 +972,23 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
972972
dl_outputs = self.track_output_for_epoch_end(dl_outputs, output)
973973

974974
# store batch level output per dataloader
975-
self.evaluation_loop.outputs.append(dl_outputs)
975+
if self.evaluation_loop.should_track_batch_outputs_for_epoch_end:
976+
self.evaluation_loop.outputs.append(dl_outputs)
976977

977978
outputs = self.evaluation_loop.outputs
978979

979980
# reset outputs
980981
self.evaluation_loop.outputs = []
981982

982983
# with a single dataloader don't pass a 2D list
983-
if self.evaluation_loop.num_dataloaders == 1:
984+
if len(outputs) > 0 and self.evaluation_loop.num_dataloaders == 1:
984985
outputs = outputs[0]
985986

986987
# lightning module method
987988
self.evaluation_loop.evaluation_epoch_end(outputs)
988989

989990
# hook
990-
self.evaluation_loop.on_evaluation_epoch_end(outputs)
991+
self.evaluation_loop.on_evaluation_epoch_end()
991992

992993
# update epoch-level lr_schedulers
993994
if on_epoch:
@@ -1212,8 +1213,8 @@ def _cache_logged_metrics(self):
12121213

12131214
def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
12141215
# Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook
1215-
# This was done to manage the deprecation of an argument to on_train_epoch_end
1216-
# If making chnages to this function, ensure that those changes are also made to
1216+
# This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end
1217+
# If making changes to this function, ensure that those changes are also made to
12171218
# TrainLoop._on_train_epoch_end_hook
12181219

12191220
# set hook_name to model + reset Result obj

tests/callbacks/test_callback_hook_outputs.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -65,48 +65,6 @@ def training_epoch_end(self, outputs) -> None:
6565
trainer.fit(model)
6666

6767

68-
def test_on_val_epoch_end_outputs(tmpdir):
69-
70-
class CB(Callback):
71-
72-
def on_validation_epoch_end(self, trainer, pl_module, outputs):
73-
if trainer.running_sanity_check:
74-
assert len(outputs) == trainer.num_sanity_val_batches[0]
75-
else:
76-
assert len(outputs) == trainer.num_val_batches[0]
77-
78-
model = BoringModel()
79-
80-
trainer = Trainer(
81-
callbacks=CB(),
82-
default_root_dir=tmpdir,
83-
limit_train_batches=2,
84-
limit_val_batches=2,
85-
max_epochs=1,
86-
weights_summary=None,
87-
)
88-
89-
trainer.fit(model)
90-
91-
92-
def test_on_test_epoch_end_outputs(tmpdir):
93-
94-
class CB(Callback):
95-
96-
def on_test_epoch_end(self, trainer, pl_module, outputs):
97-
assert len(outputs) == trainer.num_test_batches[0]
98-
99-
model = BoringModel()
100-
101-
trainer = Trainer(
102-
callbacks=CB(),
103-
default_root_dir=tmpdir,
104-
weights_summary=None,
105-
)
106-
107-
trainer.test(model)
108-
109-
11068
def test_free_memory_on_eval_outputs(tmpdir):
11169

11270
class CB(Callback):

tests/callbacks/test_callbacks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
5858
call.on_validation_epoch_start(trainer, model),
5959
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
6060
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
61-
call.on_validation_epoch_end(trainer, model, ANY),
61+
call.on_validation_epoch_end(trainer, model),
6262
call.on_epoch_end(trainer, model),
6363
call.on_validation_end(trainer, model),
6464
call.on_sanity_check_end(trainer, model),
@@ -90,7 +90,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
9090
call.on_validation_epoch_start(trainer, model),
9191
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
9292
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
93-
call.on_validation_epoch_end(trainer, model, ANY),
93+
call.on_validation_epoch_end(trainer, model),
9494
call.on_epoch_end(trainer, model),
9595
call.on_validation_end(trainer, model),
9696
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
@@ -128,7 +128,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
128128
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
129129
call.on_test_batch_start(trainer, model, ANY, 1, 0),
130130
call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
131-
call.on_test_epoch_end(trainer, model, ANY),
131+
call.on_test_epoch_end(trainer, model),
132132
call.on_epoch_end(trainer, model),
133133
call.on_test_end(trainer, model),
134134
call.teardown(trainer, model, 'test'),
@@ -163,7 +163,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
163163
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
164164
call.on_validation_batch_start(trainer, model, ANY, 1, 0),
165165
call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0),
166-
call.on_validation_epoch_end(trainer, model, ANY),
166+
call.on_validation_epoch_end(trainer, model),
167167
call.on_epoch_end(trainer, model),
168168
call.on_validation_end(trainer, model),
169169
call.teardown(trainer, model, 'validate'),

tests/core/test_hooks.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

0 commit comments

Comments
 (0)