Commit f745aa9

Move tracking epoch end outputs logic to the EvaluationEpochLoop (#9261)
1 parent b91747e commit f745aa9

File tree

CHANGELOG.md
pytorch_lightning/loops/dataloader/evaluation_loop.py
pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
pytorch_lightning/trainer/connectors/logger_connector/result.py
pytorch_lightning/utilities/memory.py
tests/trainer/loops/test_evaluation_loop.py

6 files changed: +68 / -51 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -277,6 +277,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bug where data-loading functions where not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
 
 
+- Fixed intra-epoch evaluation outputs staying in memory when the respective `*_epoch_end` hook wasn't overridden ([#9261](https://github.com/PyTorchLightning/pytorch-lightning/pull/9261))
+
+
 - Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))
 
 
pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 2 additions & 13 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from typing import Any, List, Optional, Sequence, Union
 
 from deprecate.utils import void
@@ -29,7 +28,7 @@ class EvaluationLoop(DataLoaderLoop):
 
     def __init__(self):
         super().__init__()
-        self.outputs = []
+        self.outputs: List[EPOCH_OUTPUT] = []
         self.epoch_loop = EvaluationEpochLoop()
 
         self._results = ResultCollection(training=False)
@@ -107,8 +106,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
 
         # store batch level output per dataloader
-        if self.should_track_batch_outputs_for_epoch_end:
-            self.outputs.append(dl_outputs)
+        self.outputs.append(dl_outputs)
 
         if not self.trainer.sanity_checking:
             # indicate the loop has run
@@ -165,8 +163,6 @@ def reload_evaluation_dataloaders(self) -> None:
 
     def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_start`` hooks"""
-        self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end()
-
         assert self._results is not None
         self._results.to(device=self.trainer.lightning_module.device)
 
@@ -210,13 +206,6 @@ def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         else:
            self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)
 
-    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
-        """Whether the batch outputs should be stored for later usage"""
-        model = self.trainer.lightning_module
-        if self.trainer.testing:
-            return is_overridden("test_epoch_end", model)
-        return is_overridden("validation_epoch_end", model)
-
    def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        """Runs ``{validation/test}_epoch_end``"""
        # inform logger the batch loop has finished
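
The net effect: the dataloader-level loop no longer decides whether outputs are worth keeping; it stores whatever the epoch-level loop hands back, and the epoch loop returns an empty list when no `*_epoch_end` hook is overridden. A minimal standalone sketch of that split (the `InnerEpochLoop`/`OuterDataLoaderLoop` names and the `epoch_end_overridden` flag are illustrative, not Lightning API):

from typing import Any, List


class InnerEpochLoop:
    """Runs the batches of one dataloader; only keeps outputs when they will be consumed."""

    def __init__(self, epoch_end_overridden: bool) -> None:
        self.epoch_end_overridden = epoch_end_overridden
        self.outputs: List[Any] = []

    def run(self, batches: List[Any]) -> List[Any]:
        for batch in batches:
            output = {"loss": batch}  # stand-in for a `*_step` result
            if self.epoch_end_overridden:  # the check now lives at this level
                self.outputs.append(output)
        outputs, self.outputs = self.outputs, []  # hand back and free, like `on_run_end`
        return outputs


class OuterDataLoaderLoop:
    """Iterates the dataloaders and stores whatever the inner loop returns."""

    def __init__(self, inner: InnerEpochLoop) -> None:
        self.inner = inner
        self.outputs: List[List[Any]] = []

    def run(self, dataloaders: List[List[Any]]) -> None:
        for dl in dataloaders:
            self.outputs.append(self.inner.run(dl))  # unconditional, as in `advance` above


loop = OuterDataLoaderLoop(InnerEpochLoop(epoch_end_overridden=False))
loop.run([[1, 2], [3, 4]])
assert loop.outputs == [[], []]  # nothing accumulates when the hook is not overridden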

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 21 additions & 16 deletions
@@ -13,17 +13,18 @@
 # limitations under the License.
 
 from collections import OrderedDict
-from typing import Any, Dict, Iterator, List, Optional, Union
+from functools import lru_cache
+from typing import Any, Dict, Iterator, Optional, Union
 
 from deprecate import void
-from torch import Tensor
 
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
 from pytorch_lightning.utilities.memory import recursive_detach
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
 
 
 class EvaluationEpochLoop(Loop):
@@ -37,7 +38,7 @@ def __init__(self) -> None:
         self.dataloader: Optional[Iterator] = None
         self._dl_max_batches: Optional[int] = None
         self._num_dataloaders: Optional[int] = None
-        self.outputs: List[STEP_OUTPUT] = []
+        self.outputs: EPOCH_OUTPUT = []
         self.batch_progress = Progress()
         self.dataloader_iter: Optional[Iterator] = None
 
@@ -123,9 +124,12 @@ def advance(
         self.trainer.logger_connector.update_eval_step_metrics()
 
         # track epoch level outputs
-        self.outputs = self._track_output_for_epoch_end(self.outputs, output)
+        if self._should_track_batch_outputs_for_epoch_end():
+            output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
+            if output is not None:
+                self.outputs.append(output)
 
-    def on_run_end(self) -> List[STEP_OUTPUT]:
+    def on_run_end(self) -> EPOCH_OUTPUT:
         """Returns the outputs of the whole run"""
         outputs = self.outputs
         # free memory
@@ -222,13 +226,14 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
 
         return step_kwargs
 
-    def _track_output_for_epoch_end(
-        self, outputs: List[STEP_OUTPUT], output: Optional[STEP_OUTPUT]
-    ) -> List[STEP_OUTPUT]:
-        if output is not None:
-            if isinstance(output, dict):
-                output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
-            elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu:
-                output = output.cpu()
-            outputs.append(output)
-        return outputs
+    @lru_cache(1)
+    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
+        """Whether the batch outputs should be stored for later usage"""
+        model = self.trainer.lightning_module
+        if self.trainer.testing:
+            return is_overridden("test_epoch_end", model)
+        return is_overridden("validation_epoch_end", model)
+
+    def teardown(self) -> None:
+        # in case the model changes
+        self._should_track_batch_outputs_for_epoch_end.cache_clear()
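
The `@lru_cache(1)` decorator memoizes the `is_overridden` lookup so it runs once per loop instead of once per batch, and `teardown` clears the cache so a newly attached model is re-checked. A small self-contained sketch of that caching pattern (the `OutputTracker` class is illustrative, not Lightning code):

from functools import lru_cache


class OutputTracker:
    """Caches a per-instance boolean check and resets it when the context may change."""

    def __init__(self, hook_overridden: bool) -> None:
        self.hook_overridden = hook_overridden
        self.calls = 0

    @lru_cache(1)
    def should_track(self) -> bool:
        # executed at most once until the cache is cleared
        self.calls += 1
        return self.hook_overridden

    def teardown(self) -> None:
        # in case the model (and therefore the answer) changes
        self.should_track.cache_clear()


tracker = OutputTracker(hook_overridden=False)
assert tracker.should_track() is False
assert tracker.should_track() is False
assert tracker.calls == 1  # the second call hit the cache

tracker.hook_overridden = True
tracker.teardown()  # clear before re-checking
assert tracker.should_track() is True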

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 2 additions & 6 deletions
@@ -17,7 +17,6 @@
 from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
 import torch
-from torch.functional import Tensor
 from torchmetrics import Metric
 from typing_extensions import TypedDict
 
@@ -27,6 +26,7 @@
 from pytorch_lightning.utilities.data import extract_batch_size
 from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -449,11 +449,7 @@ def log(
         """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
         # no metrics should be logged with graphs
         if not enable_graph:
-
-            def detach_fn(tensor: Tensor) -> Tensor:
-                return tensor.detach()
-
-            value = apply_to_collection(value, Tensor, detach_fn)
+            value = recursive_detach(value)
 
         # move metrics to cpu on TPU.
         if isinstance(value, torch.Tensor) and value.device.type == "xla":
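
With `to_cpu` left at its default, the new `recursive_detach(value)` call should behave like the removed inline `detach_fn`; a quick sanity sketch, assuming a Lightning checkout that contains this commit:

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.memory import recursive_detach

value = {"metric": torch.ones(1, requires_grad=True)}

old_style = apply_to_collection(value, torch.Tensor, lambda t: t.detach())  # removed helper
new_style = recursive_detach(value)  # new call; `to_cpu` defaults to False

assert not old_style["metric"].requires_grad
assert not new_style["metric"].requires_grad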

pytorch_lightning/utilities/memory.py

Lines changed: 11 additions & 15 deletions
@@ -17,17 +17,15 @@
 import shutil
 import subprocess
 import uuid
-from typing import Any, Dict, Union
+from typing import Any, Dict
 
 import torch
 from torch.nn import Module
 
-_RECURSIVE_DICT_WITH_TENSORS = Union[Dict[str, torch.Tensor], Dict[Any, Any]]
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 
-def recursive_detach(
-    in_dict: _RECURSIVE_DICT_WITH_TENSORS, to_cpu: bool = False
-) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor], Any]]:
+def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
     """Detach all tensors in `in_dict`.
 
     May operate recursively if some of the values in `in_dict` are dictionaries
@@ -41,16 +39,14 @@ def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
     Return:
         out_dict: Dictionary with detached tensors
     """
-    out_dict = {}
-    for k, v in in_dict.items():
-        if isinstance(v, dict):
-            v = recursive_detach(v, to_cpu=to_cpu)
-        elif callable(getattr(v, "detach", None)):
-            v = v.detach()
-            if to_cpu:
-                v = v.cpu()
-        out_dict[k] = v
-    return out_dict
+
+    def detach_and_move(t: torch.Tensor, to_cpu: bool) -> torch.Tensor:
+        t = t.detach()
+        if to_cpu:
+            t = t.cpu()
+        return t
+
+    return apply_to_collection(in_dict, torch.Tensor, detach_and_move, to_cpu=to_cpu)
 
 
 def is_oom_error(exception: BaseException) -> bool:
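
Because the rewrite delegates to `apply_to_collection`, `recursive_detach` should now walk any supported container (nested dicts, lists, tuples), not only dictionaries. A short usage sketch, assuming a Lightning checkout that contains this commit:

import torch

from pytorch_lightning.utilities.memory import recursive_detach

outputs = {
    "loss": torch.ones(1, requires_grad=True),
    "extras": [torch.zeros(2, requires_grad=True), {"acc": torch.tensor(0.5)}],
}

detached = recursive_detach(outputs, to_cpu=True)

assert not detached["loss"].requires_grad
assert not detached["extras"][0].requires_grad
assert detached["loss"].device.type == "cpu"  # tensors here are already CPU; `to_cpu` matters on GPU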

tests/trainer/loops/test_evaluation_loop.py

Lines changed: 29 additions & 1 deletion
@@ -16,7 +16,8 @@
 import torch
 from torch.utils.data import DataLoader
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.loops import EvaluationEpochLoop
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
@@ -101,3 +102,30 @@ def validation_step(self, batch, batch_idx):
     torch.cuda.empty_cache()
     trainer = Trainer(gpus=1, default_root_dir=tmpdir, fast_dev_run=2, move_metrics_to_cpu=True, weights_summary=None)
     trainer.fit(BoringLargeBatchModel())
+
+
+def test_evaluation_loop_doesnt_store_outputs_if_epoch_end_not_overridden(tmpdir):
+    did_assert = False
+
+    class TestModel(BoringModel):
+        def on_test_batch_end(self, outputs, *_):
+            # check `test_step` returns something
+            assert outputs is not None
+
+    class TestLoop(EvaluationEpochLoop):
+        def on_advance_end(self):
+            # should be empty
+            assert not self.outputs
+            # sanity check
+            nonlocal did_assert
+            did_assert = True
+            super().on_advance_end()
+
+    model = TestModel()
+    # make sure this hook is not overridden
+    model.test_epoch_end = LightningModule.test_epoch_end
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3)
+    trainer.test_loop.connect(TestLoop())
+    trainer.test(model)
+    assert did_assert
