From 4459d25791212329cf0c1cb8b5a8cfadbcec81ff Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 2 Sep 2021 00:32:42 +0200 Subject: [PATCH 1/6] Move tracking epoch end outputs logic to the `EvaluationEpochLoop` --- .../loops/dataloader/evaluation_loop.py | 15 ++------- .../loops/epoch/evaluation_epoch_loop.py | 33 ++++++++++--------- .../connectors/logger_connector/result.py | 8 ++--- pytorch_lightning/utilities/memory.py | 24 ++++++-------- 4 files changed, 31 insertions(+), 49 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 68b75b68eb91b..65babb7bfd448 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any, List, Optional, Sequence, Union from deprecate.utils import void @@ -29,7 +28,7 @@ class EvaluationLoop(DataLoaderLoop): def __init__(self): super().__init__() - self.outputs = [] + self.outputs: List[EPOCH_OUTPUT] = [] self.epoch_loop = EvaluationEpochLoop() self._results = ResultCollection(training=False) @@ -107,8 +106,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders) # store batch level output per dataloader - if self.should_track_batch_outputs_for_epoch_end: - self.outputs.append(dl_outputs) + self.outputs.append(dl_outputs) if not self.trainer.sanity_checking: # indicate the loop has run @@ -165,8 +163,6 @@ def reload_evaluation_dataloaders(self) -> None: def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: """Runs ``on_{validation/test}_start`` hooks""" - self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() - assert self._results is not None self._results.to(device=self.trainer.lightning_module.device) @@ -210,13 +206,6 @@ def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: else: self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs) - def _should_track_batch_outputs_for_epoch_end(self) -> bool: - """Whether the batch outputs should be stored for later usage""" - model = self.trainer.lightning_module - if self.trainer.testing: - return is_overridden("test_epoch_end", model) - return is_overridden("validation_epoch_end", model) - def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: """Runs ``{validation/test}_epoch_end``""" # inform logger the batch loop has finished diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 8b38d0f898aaf..21be86a6410fa 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -13,17 +13,18 @@ # limitations under the License. from collections import OrderedDict -from typing import Any, Dict, Iterator, List, Optional, Union +from functools import lru_cache +from typing import Any, Dict, Iterator, Optional, Union from deprecate import void -from torch import Tensor from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.fetching import AbstractDataFetcher from pytorch_lightning.utilities.memory import recursive_detach -from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT class EvaluationEpochLoop(Loop): @@ -37,7 +38,7 @@ def __init__(self) -> None: self.dataloader: Optional[Iterator] = None self._dl_max_batches: Optional[int] = None self._num_dataloaders: Optional[int] = None - self.outputs: List[STEP_OUTPUT] = [] + self.outputs: EPOCH_OUTPUT = [] self.batch_progress = Progress() self.dataloader_iter: Optional[Iterator] = None @@ -123,9 +124,12 @@ def advance( self.trainer.logger_connector.update_eval_step_metrics() # track epoch level outputs - self.outputs = self._track_output_for_epoch_end(self.outputs, output) + if self._should_track_batch_outputs_for_epoch_end(): + output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu) + if output is not None: + self.outputs.append(output) - def on_run_end(self) -> List[STEP_OUTPUT]: + def on_run_end(self) -> EPOCH_OUTPUT: """Returns the outputs of the whole run""" outputs = self.outputs # free memory @@ -222,13 +226,10 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict return step_kwargs - def _track_output_for_epoch_end( - self, outputs: List[STEP_OUTPUT], output: Optional[STEP_OUTPUT] - ) -> List[STEP_OUTPUT]: - if output is not None: - if isinstance(output, dict): - output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu) - elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu: - output = output.cpu() - outputs.append(output) - return outputs + @lru_cache(1) + def _should_track_batch_outputs_for_epoch_end(self) -> bool: + """Whether the batch outputs should be stored for later usage""" + model = self.trainer.lightning_module + if self.trainer.testing: + return is_overridden("test_epoch_end", model) + return is_overridden("validation_epoch_end", model) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 7b3a048314f09..9443ac6670bde 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -17,7 +17,6 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import torch -from torch.functional import Tensor from torchmetrics import Metric from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin @@ -26,6 +25,7 @@ from pytorch_lightning.utilities.data import extract_batch_size from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.warnings import WarningCache @@ -437,11 +437,7 @@ def log( """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs if not enable_graph: - - def detach_fn(tensor: Tensor) -> Tensor: - return tensor.detach() - - value = apply_to_collection(value, Tensor, detach_fn) + value = recursive_detach(value) # move metrics to cpu on TPU. if isinstance(value, torch.Tensor) and value.device.type == "xla": diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index ad56c95619928..a5319a522098b 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -22,12 +22,10 @@ import torch from torch.nn import Module -_RECURSIVE_DICT_WITH_TENSORS = Union[Dict[str, torch.Tensor], Dict[Any, Any]] +from pytorch_lightning.utilities.apply_func import apply_to_collection -def recursive_detach( - in_dict: _RECURSIVE_DICT_WITH_TENSORS, to_cpu: bool = False -) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor], Any]]: +def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any: """Detach all tensors in `in_dict`. May operate recursively if some of the values in `in_dict` are dictionaries @@ -41,16 +39,14 @@ def recursive_detach( Return: out_dict: Dictionary with detached tensors """ - out_dict = {} - for k, v in in_dict.items(): - if isinstance(v, dict): - v = recursive_detach(v, to_cpu=to_cpu) - elif callable(getattr(v, "detach", None)): - v = v.detach() - if to_cpu: - v = v.cpu() - out_dict[k] = v - return out_dict + + def detach_and_move(t: torch.Tensor, to_cpu: bool) -> torch.Tensor: + t = t.detach() + if to_cpu: + t = t.cpu() + return t + + return apply_to_collection(in_dict, torch.Tensor, detach_and_move, to_cpu=to_cpu) def is_oom_error(exception: BaseException) -> bool: From f32685bc7464e18a5699d08d69766b825757fe46 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 2 Sep 2021 00:56:41 +0200 Subject: [PATCH 2/6] cache clear --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 21be86a6410fa..9f8f77e806b87 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -233,3 +233,7 @@ def _should_track_batch_outputs_for_epoch_end(self) -> bool: if self.trainer.testing: return is_overridden("test_epoch_end", model) return is_overridden("validation_epoch_end", model) + + def teardown(self) -> None: + # in case the model changes + self._should_track_batch_outputs_for_epoch_end.cache_clear() From 6d2488c67458a0705ba9ded74e0ecfe0fc837654 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 2 Sep 2021 01:11:36 +0200 Subject: [PATCH 3/6] Add test --- tests/trainer/loops/test_evaluation_loop.py | 30 ++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index d7acd7e65727e..8dc87b7e75591 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -16,7 +16,8 @@ import torch from torch.utils.data import DataLoader -from pytorch_lightning import Trainer +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.loops import EvaluationEpochLoop from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -101,3 +102,30 @@ def validation_step(self, batch, batch_idx): torch.cuda.empty_cache() trainer = Trainer(gpus=1, default_root_dir=tmpdir, fast_dev_run=2, move_metrics_to_cpu=True, weights_summary=None) trainer.fit(BoringLargeBatchModel()) + + +def test_evaluation_loop_doesnt_store_outputs_if_epoch_end_not_overridden(tmpdir): + did_assert = False + + class TestModel(BoringModel): + def on_test_batch_end(self, outputs, *_): + # check `test_step` returns something + assert outputs is not None + + class TestLoop(EvaluationEpochLoop): + def on_advance_end(self): + super().on_advance_end() + # should be empty + assert not self.outputs + # sanity check + nonlocal did_assert + did_assert = True + + model = TestModel() + # make sure this hook is not overridden + model.test_epoch_end = LightningModule.test_epoch_end + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3) + trainer.test_loop.connect(TestLoop()) + trainer.test(model) + assert did_assert From 077c1b7f601a2fbd3c75d1e4f8fd8b3a518ae57c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 2 Sep 2021 01:18:43 +0200 Subject: [PATCH 4/6] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bda9ff766ad9..7ea829e2d8a1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -271,6 +271,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed bug where data-loading functions where not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858)) +- Fixed intra-epoch evaluation outputs staying in memory when the respective `*_epoch_end` hook wasn't overridden ([#9261](https://github.com/PyTorchLightning/pytorch-lightning/pull/9261)) + + ## [1.4.5] - 2021-08-31 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142)) From b49420699467426b1a7fddb7b30ccc49bb197c42 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 2 Sep 2021 02:00:48 +0200 Subject: [PATCH 5/6] move on_advance_end call just in case --- tests/trainer/loops/test_evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 8dc87b7e75591..9d4af2f393aee 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -114,12 +114,12 @@ def on_test_batch_end(self, outputs, *_): class TestLoop(EvaluationEpochLoop): def on_advance_end(self): - super().on_advance_end() # should be empty assert not self.outputs # sanity check nonlocal did_assert did_assert = True + super().on_advance_end() model = TestModel() # make sure this hook is not overridden From e78f6bdd75a430d5201ab587e61fe91fafc2d6a4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 3 Sep 2021 02:08:34 +0200 Subject: [PATCH 6/6] Fix pep8 --- pytorch_lightning/utilities/memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index a29dd281c3144..02636ba54e37d 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -17,7 +17,7 @@ import shutil import subprocess import uuid -from typing import Any, Dict, Union +from typing import Any, Dict import torch from torch.nn import Module