Commit 1b8336e

Move tracking epoch end outputs logic to the EvaluationEpochLoop (#9261)
1 parent: 6d381a3 · commit: 1b8336e

File tree

6 files changed (+72 −56 lines)

CHANGELOG.md
pytorch_lightning/loops/dataloader/evaluation_loop.py
pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
pytorch_lightning/trainer/connectors/logger_connector/result.py
pytorch_lightning/utilities/memory.py
tests/trainer/loops/test_evaluation_loop.py


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bug where data-loading functions where not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
 
 
+- Fixed intra-epoch evaluation outputs staying in memory when the respective `*_epoch_end` hook wasn't overridden ([#9261](https://github.com/PyTorchLightning/pytorch-lightning/pull/9261))
+
+
 - Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))
 
 

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 2 additions & 13 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from typing import Any, List, Optional, Sequence, Union
 
 from deprecate.utils import void
@@ -30,7 +29,7 @@ class EvaluationLoop(DataLoaderLoop):
 
     def __init__(self):
         super().__init__()
-        self.outputs = []
+        self.outputs: List[EPOCH_OUTPUT] = []
         self.epoch_loop = EvaluationEpochLoop()
 
         self._results = ResultCollection(training=False)
@@ -112,8 +111,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         )
 
         # store batch level output per dataloader
-        if self.should_track_batch_outputs_for_epoch_end:
-            self.outputs.append(dl_outputs)
+        self.outputs.append(dl_outputs)
 
         if not self.trainer.sanity_checking:
             # indicate the loop has run
@@ -174,8 +172,6 @@ def reload_evaluation_dataloaders(self) -> None:
 
     def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_start`` hooks"""
-        self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end()
-
         assert self._results is not None
         self._results.to(device=self.trainer.lightning_module.device)
 
@@ -224,13 +220,6 @@ def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         else:
             self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)
 
-    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
-        """Whether the batch outputs should be stored for later usage"""
-        model = self.trainer.lightning_module
-        if self.trainer.testing:
-            return is_overridden("test_epoch_end", model)
-        return is_overridden("validation_epoch_end", model)
-
     def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
         """Runs ``{validation/test}_epoch_end``"""
         # inform logger the batch loop has finished
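
For orientation, here is a schematic sketch (simplified stand-in classes, not the real `EvaluationLoop`/`EvaluationEpochLoop`) of how the responsibility is split after this change: the dataloader-level loop appends whatever the epoch loop hands back unconditionally, while the epoch loop now owns the decision of whether per-batch outputs are kept at all.

```python
from typing import Any, List


class EpochLoopSketch:
    def __init__(self, epoch_end_overridden: bool) -> None:
        self.epoch_end_overridden = epoch_end_overridden

    def run(self, batches: List[Any]) -> List[Any]:
        outputs: List[Any] = []
        for batch in batches:
            step_output = {"loss": batch}      # stand-in for the *_step result
            if self.epoch_end_overridden:      # the gating now lives in the epoch loop
                outputs.append(step_output)
        return outputs                         # [] when the hook is not overridden


class DataLoaderLoopSketch:
    def __init__(self, epoch_loop: EpochLoopSketch) -> None:
        self.epoch_loop = epoch_loop
        self.outputs: List[List[Any]] = []

    def advance(self, dataloader: List[Any]) -> None:
        dl_outputs = self.epoch_loop.run(dataloader)
        self.outputs.append(dl_outputs)        # unconditional append, as in the diff above


loop = DataLoaderLoopSketch(EpochLoopSketch(epoch_end_overridden=False))
loop.advance([1, 2, 3])
assert loop.outputs == [[]]                    # nothing accumulates when the hook is absent
```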

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 21 additions & 22 deletions
@@ -13,17 +13,18 @@
 # limitations under the License.
 
 from collections import OrderedDict
-from typing import Any, Dict, Iterator, List, Optional, Union
+from functools import lru_cache
+from typing import Any, Dict, Iterator, Optional, Union
 
 from deprecate import void
-from torch import Tensor
 
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import PredictionCollection
 from pytorch_lightning.utilities.memory import recursive_detach
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
 
 
 class EvaluationEpochLoop(Loop):
@@ -38,7 +39,7 @@ def __init__(self) -> None:
         self.dataloader: Optional[Iterator] = None
         self._dl_max_batches: Optional[int] = None
         self._num_dataloaders: Optional[int] = None
-        self.outputs: List[STEP_OUTPUT] = []
+        self.outputs: EPOCH_OUTPUT = []
         self.batch_progress = Progress()
 
     @property
@@ -121,9 +122,12 @@ def advance(
         self.trainer.logger_connector.update_eval_step_metrics()
 
         # track epoch level outputs
-        self.outputs = self._track_output_for_epoch_end(self.outputs, output)
+        if self._should_track_batch_outputs_for_epoch_end():
+            output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
+            if output is not None:
+                self.outputs.append(output)
 
-    def on_run_end(self) -> List[STEP_OUTPUT]:
+    def on_run_end(self) -> EPOCH_OUTPUT:
         """Returns the outputs of the whole run"""
         outputs = self.outputs
         # free memory
@@ -239,19 +243,14 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
 
         return step_kwargs
 
-    def _track_output_for_epoch_end(
-        self,
-        outputs: List[Union[ResultCollection, Dict, Tensor]],
-        output: Optional[Union[ResultCollection, Dict, Tensor]],
-    ) -> List[Union[ResultCollection, Dict, Tensor]]:
-        if output is not None:
-            if isinstance(output, ResultCollection):
-                output = output.detach()
-                if self.trainer.move_metrics_to_cpu:
-                    output = output.cpu()
-            elif isinstance(output, dict):
-                output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
-            elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu:
-                output = output.cpu()
-            outputs.append(output)
-        return outputs
+    @lru_cache(1)
+    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
+        """Whether the batch outputs should be stored for later usage"""
+        model = self.trainer.lightning_module
+        if self.trainer.testing:
+            return is_overridden("test_epoch_end", model)
+        return is_overridden("validation_epoch_end", model)
+
+    def teardown(self) -> None:
+        # in case the model changes
+        self._should_track_batch_outputs_for_epoch_end.cache_clear()
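
The `@lru_cache(1)` + `teardown()` pairing above means the `is_overridden` check runs at most once per loop instance and can be reset when the attached model changes. A minimal, self-contained sketch of that caching pattern (illustrative names, not Lightning code):

```python
from functools import lru_cache


class LoopWithCachedCheck:
    def __init__(self, model_has_epoch_end: bool) -> None:
        self.model_has_epoch_end = model_has_epoch_end
        self.calls = 0

    @lru_cache(1)
    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
        # the body only runs on a cache miss
        self.calls += 1
        return self.model_has_epoch_end

    def teardown(self) -> None:
        # in case the model changes, drop the cached answer
        self._should_track_batch_outputs_for_epoch_end.cache_clear()


loop = LoopWithCachedCheck(model_has_epoch_end=False)
assert loop._should_track_batch_outputs_for_epoch_end() is False
assert loop._should_track_batch_outputs_for_epoch_end() is False
assert loop.calls == 1  # second call was served from the cache

loop.model_has_epoch_end = True
loop.teardown()  # stale cached value is cleared
assert loop._should_track_batch_outputs_for_epoch_end() is True
```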

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 2 additions & 6 deletions
@@ -17,7 +17,6 @@
 from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
 import torch
-from torch.functional import Tensor
 from torchmetrics import Metric
 
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
@@ -26,6 +25,7 @@
 from pytorch_lightning.utilities.data import extract_batch_size
 from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -436,11 +436,7 @@ def log(
         """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
         # no metrics should be logged with graphs
         if not enable_graph:
-
-            def detach_fn(tensor: Tensor) -> Tensor:
-                return tensor.detach()
-
-            value = apply_to_collection(value, Tensor, detach_fn)
+            value = recursive_detach(value)
 
         # move metrics to cpu on TPU.
         if isinstance(value, torch.Tensor) and value.device.type == "xla":
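
The `log()` change above swaps a locally defined `detach_fn` + `apply_to_collection` call for the shared `recursive_detach` utility. A small sketch of why the two are interchangeable for this purpose (`recursive_detach` defaults to `to_cpu=False`, so devices are untouched):

```python
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.memory import recursive_detach

value = {"loss": torch.tensor(2.0, requires_grad=True) * 3}  # tensor with a grad_fn

# old inline form: detach every tensor in the (possibly nested) collection
old = apply_to_collection(value, torch.Tensor, lambda t: t.detach())
# new form: the shared utility performs the same walk
new = recursive_detach(value)

assert old["loss"].grad_fn is None and new["loss"].grad_fn is None
assert torch.equal(old["loss"], new["loss"])
```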

pytorch_lightning/utilities/memory.py

Lines changed: 15 additions & 14 deletions
@@ -13,11 +13,14 @@
 # limitations under the License.
 
 import gc
+from typing import Any
 
 import torch
 
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 
-def recursive_detach(in_dict: dict, to_cpu: bool = False) -> dict:
+
+def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
     """Detach all tensors in `in_dict`.
 
     May operate recursively if some of the values in `in_dict` are dictionaries
@@ -31,19 +34,17 @@ def recursive_detach(in_dict: dict, to_cpu: bool = False) -> dict:
     Return:
         out_dict: Dictionary with detached tensors
     """
-    out_dict = {}
-    for k, v in in_dict.items():
-        if isinstance(v, dict):
-            v = recursive_detach(v, to_cpu=to_cpu)
-        elif callable(getattr(v, "detach", None)):
-            v = v.detach()
-            if to_cpu:
-                v = v.cpu()
-        out_dict[k] = v
-    return out_dict
-
-
-def is_oom_error(exception):
+
+    def detach_and_move(t: torch.Tensor, to_cpu: bool) -> torch.Tensor:
+        t = t.detach()
+        if to_cpu:
+            t = t.cpu()
+        return t
+
+    return apply_to_collection(in_dict, torch.Tensor, detach_and_move, to_cpu=to_cpu)
+
+
+def is_oom_error(exception: BaseException) -> bool:
     return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception)
 
 
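
Because the rewritten `recursive_detach` is now built on `apply_to_collection`, it accepts any nested collection of tensors rather than only dictionaries. A rough usage sketch under that assumption (plain CPU tensors, so `to_cpu=True` is effectively a no-op here):

```python
import torch

from pytorch_lightning.utilities.memory import recursive_detach

outputs = {
    "loss": torch.tensor(1.0, requires_grad=True) * 2,                        # tensor with a grad_fn
    "extras": [torch.ones(2, requires_grad=True) * 2, {"acc": torch.tensor(0.9)}],
    "count": 3,                                                                # non-tensors pass through untouched
}

detached = recursive_detach(outputs, to_cpu=True)

assert detached["loss"].grad_fn is None       # graph reference dropped
assert detached["extras"][0].grad_fn is None  # lists and nested dicts are handled too
assert detached["count"] == 3
```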

tests/trainer/loops/test_evaluation_loop.py

Lines changed: 29 additions & 1 deletion
@@ -16,7 +16,8 @@
 import torch
 from torch.utils.data import DataLoader
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.loops import EvaluationEpochLoop
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
@@ -101,3 +102,30 @@ def validation_step(self, batch, batch_idx):
     torch.cuda.empty_cache()
     trainer = Trainer(gpus=1, default_root_dir=tmpdir, fast_dev_run=2, move_metrics_to_cpu=True, weights_summary=None)
     trainer.fit(BoringLargeBatchModel())
+
+
+def test_evaluation_loop_doesnt_store_outputs_if_epoch_end_not_overridden(tmpdir):
+    did_assert = False
+
+    class TestModel(BoringModel):
+        def on_test_batch_end(self, outputs, *_):
+            # check `test_step` returns something
+            assert outputs is not None
+
+    class TestLoop(EvaluationEpochLoop):
+        def on_advance_end(self):
+            # should be empty
+            assert not self.outputs
+            # sanity check
+            nonlocal did_assert
+            did_assert = True
+            super().on_advance_end()
+
+    model = TestModel()
+    # make sure this hook is not overridden
+    model.test_epoch_end = LightningModule.test_epoch_end
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3)
+    trainer.test_loop.connect(TestLoop())
+    trainer.test(model)
+    assert did_assert
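
As a complementary sketch (not part of this commit), the opposite case still behaves as before: when `test_epoch_end` is overridden, the epoch loop keeps the per-batch outputs and passes them to the hook.

```python
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel


class TrackingModel(BoringModel):
    def test_epoch_end(self, outputs):
        # the hook is overridden, so one entry per test batch is retained
        assert len(outputs) > 0


def test_outputs_are_kept_when_epoch_end_overridden(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3)
    trainer.test(TrackingModel())
```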
