Skip to content

Commit 5b72d6a

Browse files
committed
Remove outputs from evaluation epoch end hooks
1 parent 6c75467 commit 5b72d6a

File tree

9 files changed

+18
-239
lines changed

9 files changed

+18
-239
lines changed

pytorch_lightning/callbacks/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,15 @@ def on_validation_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn
108108
"""Called when the val epoch begins."""
109109
pass
110110

111-
def on_validation_epoch_end(
112-
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT
113-
) -> None:
111+
def on_validation_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
114112
"""Called when the val epoch ends."""
115113
pass
116114

117115
def on_test_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
118116
"""Called when the test epoch begins."""
119117
pass
120118

121-
def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None:
119+
def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
122120
"""Called when the test epoch ends."""
123121
pass
124122

pytorch_lightning/core/hooks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def on_validation_epoch_start(self) -> None:
245245
Called in the validation loop at the very beginning of the epoch.
246246
"""
247247

248-
def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
248+
def on_validation_epoch_end(self) -> None:
249249
"""
250250
Called in the validation loop at the very end of the epoch.
251251
"""
@@ -255,7 +255,7 @@ def on_test_epoch_start(self) -> None:
255255
Called in the test loop at the very beginning of the epoch.
256256
"""
257257

258-
def on_test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
258+
def on_test_epoch_end(self) -> None:
259259
"""
260260
Called in the test loop at the very end of the epoch.
261261
"""

pytorch_lightning/trainer/callback_hook.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -111,44 +111,20 @@ def on_validation_epoch_start(self):
111111
for callback in self.callbacks:
112112
callback.on_validation_epoch_start(self, self.lightning_module)
113113

114-
def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT):
115-
"""Called when the epoch ends.
116-
117-
Args:
118-
outputs: List of outputs on each ``validation`` epoch
119-
"""
114+
def on_validation_epoch_end(self):
115+
"""Called when the validation epoch ends."""
120116
for callback in self.callbacks:
121-
if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"):
122-
callback.on_validation_epoch_end(self, self.lightning_module, outputs)
123-
else:
124-
warning_cache.warn(
125-
"`Callback.on_validation_epoch_end` signature has changed in v1.3."
126-
" `outputs` parameter has been added."
127-
" Support for the old signature will be removed in v1.5", DeprecationWarning
128-
)
129-
callback.on_validation_epoch_end(self, self.lightning_module)
117+
callback.on_validation_epoch_end(self, self.lightning_module)
130118

131119
def on_test_epoch_start(self):
132120
"""Called when the epoch begins."""
133121
for callback in self.callbacks:
134122
callback.on_test_epoch_start(self, self.lightning_module)
135123

136-
def on_test_epoch_end(self, outputs: EPOCH_OUTPUT):
137-
"""Called when the epoch ends.
138-
139-
Args:
140-
outputs: List of outputs on each ``test`` epoch
141-
"""
124+
def on_test_epoch_end(self):
125+
"""Called when the test epoch ends."""
142126
for callback in self.callbacks:
143-
if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"):
144-
callback.on_test_epoch_end(self, self.lightning_module, outputs)
145-
else:
146-
warning_cache.warn(
147-
"`Callback.on_test_epoch_end` signature has changed in v1.3."
148-
" `outputs` parameter has been added."
149-
" Support for the old signature will be removed in v1.5", DeprecationWarning
150-
)
151-
callback.on_test_epoch_end(self, self.lightning_module)
127+
callback.on_test_epoch_end(self, self.lightning_module)
152128

153129
def on_predict_epoch_start(self) -> None:
154130
"""Called when the epoch begins."""

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -255,14 +255,7 @@ def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]])
255255

256256
if is_overridden(hook_name, model_ref):
257257
model_hook_fx = getattr(model_ref, hook_name)
258-
if is_param_in_hook_signature(model_hook_fx, "outputs"):
259-
model_hook_fx(outputs)
260-
else:
261-
self.warning_cache.warn(
262-
f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added."
263-
" Support for the old signature will be removed in v1.5", DeprecationWarning
264-
)
265-
model_hook_fx()
258+
model_hook_fx()
266259

267260
self.trainer._cache_logged_metrics()
268261

tests/callbacks/test_callback_hook_outputs.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -68,48 +68,6 @@ def on_train_epoch_end(self, outputs) -> None:
6868
trainer.fit(model)
6969

7070

71-
def test_on_val_epoch_end_outputs(tmpdir):
72-
73-
class CB(Callback):
74-
75-
def on_validation_epoch_end(self, trainer, pl_module, outputs):
76-
if trainer.running_sanity_check:
77-
assert len(outputs) == trainer.num_sanity_val_batches[0]
78-
else:
79-
assert len(outputs) == trainer.num_val_batches[0]
80-
81-
model = BoringModel()
82-
83-
trainer = Trainer(
84-
callbacks=CB(),
85-
default_root_dir=tmpdir,
86-
limit_train_batches=2,
87-
limit_val_batches=2,
88-
max_epochs=1,
89-
weights_summary=None,
90-
)
91-
92-
trainer.fit(model)
93-
94-
95-
def test_on_test_epoch_end_outputs(tmpdir):
96-
97-
class CB(Callback):
98-
99-
def on_test_epoch_end(self, trainer, pl_module, outputs):
100-
assert len(outputs) == trainer.num_test_batches[0]
101-
102-
model = BoringModel()
103-
104-
trainer = Trainer(
105-
callbacks=CB(),
106-
default_root_dir=tmpdir,
107-
weights_summary=None,
108-
)
109-
110-
trainer.test(model)
111-
112-
11371
def test_free_memory_on_eval_outputs(tmpdir):
11472

11573
class CB(Callback):

tests/callbacks/test_callbacks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
5858
call.on_validation_epoch_start(trainer, model),
5959
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
6060
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
61-
call.on_validation_epoch_end(trainer, model, ANY),
61+
call.on_validation_epoch_end(trainer, model),
6262
call.on_epoch_end(trainer, model),
6363
call.on_validation_end(trainer, model),
6464
call.on_sanity_check_end(trainer, model),
@@ -90,7 +90,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
9090
call.on_validation_epoch_start(trainer, model),
9191
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
9292
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
93-
call.on_validation_epoch_end(trainer, model, ANY),
93+
call.on_validation_epoch_end(trainer, model),
9494
call.on_epoch_end(trainer, model),
9595
call.on_validation_end(trainer, model),
9696
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
@@ -128,7 +128,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
128128
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
129129
call.on_test_batch_start(trainer, model, ANY, 1, 0),
130130
call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
131-
call.on_test_epoch_end(trainer, model, ANY),
131+
call.on_test_epoch_end(trainer, model),
132132
call.on_epoch_end(trainer, model),
133133
call.on_test_end(trainer, model),
134134
call.teardown(trainer, model, 'test'),
@@ -163,7 +163,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
163163
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
164164
call.on_validation_batch_start(trainer, model, ANY, 1, 0),
165165
call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0),
166-
call.on_validation_epoch_end(trainer, model, ANY),
166+
call.on_validation_epoch_end(trainer, model),
167167
call.on_epoch_end(trainer, model),
168168
call.on_validation_end(trainer, model),
169169
call.teardown(trainer, model, 'validate'),

tests/core/test_hooks.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

tests/deprecated_api/test_remove_1-5.py

Lines changed: 1 addition & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler
2828
from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache
2929
from tests.deprecated_api import no_deprecated_call
30-
from tests.helpers import BoringModel, BoringDataModule
30+
from tests.helpers import BoringDataModule, BoringModel
3131
from tests.helpers.utils import no_warning_call
3232

3333

@@ -213,96 +213,6 @@ def test_v1_5_0_model_checkpoint_period(tmpdir):
213213
ModelCheckpoint(dirpath=tmpdir, period=1)
214214

215215

216-
def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
217-
callback_warning_cache.clear()
218-
219-
class OldSignature(Callback):
220-
221-
def on_validation_epoch_end(self, trainer, pl_module): # noqa
222-
...
223-
224-
model = BoringModel()
225-
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
226-
227-
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
228-
trainer.fit(model)
229-
230-
class OldSignatureModel(BoringModel):
231-
232-
def on_validation_epoch_end(self): # noqa
233-
...
234-
235-
model = OldSignatureModel()
236-
237-
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
238-
trainer.fit(model)
239-
240-
callback_warning_cache.clear()
241-
242-
class NewSignature(Callback):
243-
244-
def on_validation_epoch_end(self, trainer, pl_module, outputs):
245-
...
246-
247-
trainer.callbacks = [NewSignature()]
248-
with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."):
249-
trainer.fit(model)
250-
251-
class NewSignatureModel(BoringModel):
252-
253-
def on_validation_epoch_end(self, outputs):
254-
...
255-
256-
model = NewSignatureModel()
257-
with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."):
258-
trainer.fit(model)
259-
260-
261-
def test_v1_5_0_old_on_test_epoch_end(tmpdir):
262-
callback_warning_cache.clear()
263-
264-
class OldSignature(Callback):
265-
266-
def on_test_epoch_end(self, trainer, pl_module): # noqa
267-
...
268-
269-
model = BoringModel()
270-
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
271-
272-
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
273-
trainer.test(model)
274-
275-
class OldSignatureModel(BoringModel):
276-
277-
def on_test_epoch_end(self): # noqa
278-
...
279-
280-
model = OldSignatureModel()
281-
282-
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
283-
trainer.test(model)
284-
285-
callback_warning_cache.clear()
286-
287-
class NewSignature(Callback):
288-
289-
def on_test_epoch_end(self, trainer, pl_module, outputs):
290-
...
291-
292-
trainer.callbacks = [NewSignature()]
293-
with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."):
294-
trainer.test(model)
295-
296-
class NewSignatureModel(BoringModel):
297-
298-
def on_test_epoch_end(self, outputs):
299-
...
300-
301-
model = NewSignatureModel()
302-
with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."):
303-
trainer.test(model)
304-
305-
306216
@pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
307217
def test_v1_5_0_profiler_output_filename(tmpdir, cls):
308218
filepath = str(tmpdir / "test.txt")

tests/trainer/logging_/test_logger_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,10 +681,10 @@ def _assert_epoch_end(self, stage):
681681
def on_train_epoch_end(self, outputs):
682682
self._assert_epoch_end('train')
683683

684-
def on_validation_epoch_end(self, outputs):
684+
def on_validation_epoch_end(self):
685685
self._assert_epoch_end('val')
686686

687-
def on_test_epoch_end(self, outputs):
687+
def on_test_epoch_end(self):
688688
self._assert_epoch_end('test')
689689

690690
def _assert_called(model, stage):

0 commit comments

Comments
 (0)