Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -276,6 +276,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bug where data-loading functions were not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))


- Fixed intra-epoch evaluation outputs staying in memory when the respective `*_epoch_end` hook wasn't overridden ([#9261](https://github.com/PyTorchLightning/pytorch-lightning/pull/9261))


- Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))


15 changes: 2 additions & 13 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Sequence, Union

from deprecate.utils import void
@@ -29,7 +28,7 @@ class EvaluationLoop(DataLoaderLoop):

def __init__(self):
super().__init__()
self.outputs = []
self.outputs: List[EPOCH_OUTPUT] = []
self.epoch_loop = EvaluationEpochLoop()

self._results = ResultCollection(training=False)
@@ -107,8 +106,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)

# store batch level output per dataloader
if self.should_track_batch_outputs_for_epoch_end:
self.outputs.append(dl_outputs)
self.outputs.append(dl_outputs)

if not self.trainer.sanity_checking:
# indicate the loop has run
@@ -165,8 +163,6 @@ def reload_evaluation_dataloaders(self) -> None:

def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_start`` hooks"""
self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end()

assert self._results is not None
self._results.to(device=self.trainer.lightning_module.device)

@@ -210,13 +206,6 @@ def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
else:
self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)

def _should_track_batch_outputs_for_epoch_end(self) -> bool:
"""Whether the batch outputs should be stored for later usage"""
model = self.trainer.lightning_module
if self.trainer.testing:
return is_overridden("test_epoch_end", model)
return is_overridden("validation_epoch_end", model)

def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
"""Runs ``{validation/test}_epoch_end``"""
# inform logger the batch loop has finished
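Taken together, this file's change means the dataloader-level loop no longer decides whether outputs are kept: it appends whatever the epoch loop returns, and the epoch loop (next file) hands back an empty list when the relevant *_epoch_end hook is not overridden. A minimal sketch of that division of labor, using toy stand-in classes rather than the real Lightning loops:

from typing import Any, List


class EpochLoopStub:
    """Stand-in for EvaluationEpochLoop: decides whether batch outputs are kept."""

    def __init__(self, track_outputs: bool) -> None:
        self.track_outputs = track_outputs  # i.e. "is *_epoch_end overridden?"

    def run(self, dataloader: List[Any]) -> List[Any]:
        outputs: List[Any] = []
        for batch in dataloader:
            step_output = {"loss": batch}  # pretend step output
            if self.track_outputs:
                outputs.append(step_output)
        return outputs  # empty list when nothing should be tracked


class DataLoaderLoopStub:
    """Stand-in for EvaluationLoop: appends unconditionally, as in this diff."""

    def __init__(self, epoch_loop: EpochLoopStub) -> None:
        self.epoch_loop = epoch_loop
        self.outputs: List[List[Any]] = []

    def advance(self, dataloader: List[Any]) -> None:
        self.outputs.append(self.epoch_loop.run(dataloader))


loop = DataLoaderLoopStub(EpochLoopStub(track_outputs=False))
loop.advance([1, 2, 3])
assert loop.outputs == [[]]  # nothing accumulates across the epoch
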
37 changes: 21 additions & 16 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -13,17 +13,18 @@
# limitations under the License.

from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Union
from functools import lru_cache
from typing import Any, Dict, Iterator, Optional, Union

from deprecate import void
from torch import Tensor

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT


class EvaluationEpochLoop(Loop):
Expand All @@ -37,7 +38,7 @@ def __init__(self) -> None:
self.dataloader: Optional[Iterator] = None
self._dl_max_batches: Optional[int] = None
self._num_dataloaders: Optional[int] = None
self.outputs: List[STEP_OUTPUT] = []
self.outputs: EPOCH_OUTPUT = []
self.batch_progress = Progress()
self.dataloader_iter: Optional[Iterator] = None

@@ -123,9 +124,12 @@ def advance(
self.trainer.logger_connector.update_eval_step_metrics()

# track epoch level outputs
self.outputs = self._track_output_for_epoch_end(self.outputs, output)
if self._should_track_batch_outputs_for_epoch_end():
output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)

@jshin49 commented on Nov 2, 2021:

I found an error here. Not sure if it is intended, but when isinstance(output, defaultdict) == True, i.e. the step output is a defaultdict, this causes a reproducible error with the following stack trace:

  File "scripts/train.py", line 103, in fine_tune
    trainer.fit(model, dataloaders["train"], dataloaders["dev"])
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 552, in fit
    self._run(model)
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 922, in _run
    self._dispatch()
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _dispatch
    self.accelerator.start_training(self)
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1000, in run_stage
    return self._run_train()
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1035, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1122, in _run_sanity_check
    self._evaluation_loop.run()
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
    dl_outputs = self.epoch_loop.run(
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 126, in advance
    output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/utilities/memory.py", line 44, in recursive_detach
    return apply_to_collection(in_dict, torch.Tensor, detach_and_move, to_cpu=to_cpu)
  File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/utilities/apply_func.py", line 109, in apply_to_collection
    return elem_type(OrderedDict(out))
TypeError: first argument must be callable or None


A follow-up comment:

Before, this didn't cause an error because the _track_output_for_epoch_end function that got removed in this PR seems not to call recursive_detach for defaultdict.


@jshin49 commented on Nov 2, 2021:

Ultimately this error is caused by Line 105 in the apply_to_collection function in utilities/apply_func.py
https://github.com/PyTorchLightning/pytorch-lightning/blob/1686aab5506ed6f4fee8683ce6cca711e62b5ef0/pytorch_lightning/utilities/apply_func.py#L94-L105

as basically this line is doing
defaultdict(OrderedDict([...]))

which would cause the same error:

TypeError: first argument must be callable or None

I'll make a separate issue out of this so that others can search the same.

Issue left at #10308
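
For context, a minimal, Lightning-free sketch of the failure mode discussed in this thread: apply_to_collection rebuilds its result as elem_type(OrderedDict(out)), and a defaultdict cannot be constructed from a mapping passed positionally. The toy payload below is illustrative, not taken from the PR.

from collections import OrderedDict, defaultdict

out = [("loss", 1.0)]

# Plain dicts accept a mapping as their first positional argument, so the
# reconstruction step works fine for them.
dict(OrderedDict(out))

# defaultdict expects a callable default_factory as its first argument,
# so rebuilding it the same way raises the error quoted above.
try:
    defaultdict(OrderedDict(out))
except TypeError as err:
    print(err)  # first argument must be callable or None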

if output is not None:
self.outputs.append(output)

def on_run_end(self) -> List[STEP_OUTPUT]:
def on_run_end(self) -> EPOCH_OUTPUT:
"""Returns the outputs of the whole run"""
outputs = self.outputs
# free memory
@@ -222,13 +226,14 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict

return step_kwargs

def _track_output_for_epoch_end(
self, outputs: List[STEP_OUTPUT], output: Optional[STEP_OUTPUT]
) -> List[STEP_OUTPUT]:
if output is not None:
if isinstance(output, dict):
output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu:
output = output.cpu()
outputs.append(output)
return outputs
@lru_cache(1)
def _should_track_batch_outputs_for_epoch_end(self) -> bool:
"""Whether the batch outputs should be stored for later usage"""
model = self.trainer.lightning_module
if self.trainer.testing:
return is_overridden("test_epoch_end", model)
return is_overridden("validation_epoch_end", model)

def teardown(self) -> None:
# in case the model changes
self._should_track_batch_outputs_for_epoch_end.cache_clear()
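
A minimal sketch of the caching pattern introduced here — functools.lru_cache on a method plus an explicit cache_clear() in teardown — with a toy class and a type-dict lookup standing in for the real is_overridden helper:

from functools import lru_cache


class EpochLoopSketch:
    def __init__(self, model: object) -> None:
        self.model = model

    @lru_cache(1)
    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
        # Stand-in for is_overridden("validation_epoch_end", model). The answer
        # only changes if the attached model changes, so a single cached entry
        # (keyed on self) is enough per run.
        return type(self.model).__dict__.get("validation_epoch_end") is not None

    def teardown(self) -> None:
        # The cache lives on the class-level function object; clear it so a
        # different model attached later is re-inspected.
        self._should_track_batch_outputs_for_epoch_end.cache_clear()


loop = EpochLoopSketch(model=object())
assert loop._should_track_batch_outputs_for_epoch_end() is False  # computed once, then cached
loop.teardown()  # the next call recomputes
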
@@ -17,7 +17,6 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import torch
from torch.functional import Tensor
from torchmetrics import Metric
from typing_extensions import TypedDict

@@ -27,6 +26,7 @@
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.warnings import WarningCache

@@ -449,11 +449,7 @@ def log(
"""See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
# no metrics should be logged with graphs
if not enable_graph:

def detach_fn(tensor: Tensor) -> Tensor:
return tensor.detach()

value = apply_to_collection(value, Tensor, detach_fn)
value = recursive_detach(value)

# move metrics to cpu on TPU.
if isinstance(value, torch.Tensor) and value.device.type == "xla":
26 changes: 11 additions & 15 deletions pytorch_lightning/utilities/memory.py
@@ -17,17 +17,15 @@
import shutil
import subprocess
import uuid
from typing import Any, Dict, Union
from typing import Any, Dict

import torch
from torch.nn import Module

_RECURSIVE_DICT_WITH_TENSORS = Union[Dict[str, torch.Tensor], Dict[Any, Any]]
from pytorch_lightning.utilities.apply_func import apply_to_collection


def recursive_detach(
in_dict: _RECURSIVE_DICT_WITH_TENSORS, to_cpu: bool = False
) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor], Any]]:
def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
"""Detach all tensors in `in_dict`.

May operate recursively if some of the values in `in_dict` are dictionaries
@@ -41,16 +39,14 @@ def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
Return:
out_dict: Dictionary with detached tensors
"""
out_dict = {}
for k, v in in_dict.items():
if isinstance(v, dict):
v = recursive_detach(v, to_cpu=to_cpu)
elif callable(getattr(v, "detach", None)):
v = v.detach()
if to_cpu:
v = v.cpu()
out_dict[k] = v
return out_dict

def detach_and_move(t: torch.Tensor, to_cpu: bool) -> torch.Tensor:
t = t.detach()
if to_cpu:
t = t.cpu()
return t

return apply_to_collection(in_dict, torch.Tensor, detach_and_move, to_cpu=to_cpu)


def is_oom_error(exception: BaseException) -> bool:
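A brief usage sketch of the rewritten helper, assuming only what the diff shows (every tensor in an arbitrarily nested collection is detached and optionally moved to CPU); the sample payload is made up. As the review thread above notes, a defaultdict payload still trips over apply_to_collection's reconstruction step (#10308).

import torch

from pytorch_lightning.utilities.memory import recursive_detach

# A nested, step-output-like payload whose tensors carry gradients.
payload = {
    "loss": torch.ones(1, requires_grad=True),
    "metrics": {"acc": torch.tensor(0.9, requires_grad=True)},
}

detached = recursive_detach(payload, to_cpu=True)
assert not detached["loss"].requires_grad
assert not detached["metrics"]["acc"].requires_grad
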
30 changes: 29 additions & 1 deletion tests/trainer/loops/test_evaluation_loop.py
@@ -16,7 +16,8 @@
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loops import EvaluationEpochLoop
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf

@@ -101,3 +102,30 @@ def validation_step(self, batch, batch_idx):
torch.cuda.empty_cache()
trainer = Trainer(gpus=1, default_root_dir=tmpdir, fast_dev_run=2, move_metrics_to_cpu=True, weights_summary=None)
trainer.fit(BoringLargeBatchModel())


def test_evaluation_loop_doesnt_store_outputs_if_epoch_end_not_overridden(tmpdir):
did_assert = False

class TestModel(BoringModel):
def on_test_batch_end(self, outputs, *_):
# check `test_step` returns something
assert outputs is not None

class TestLoop(EvaluationEpochLoop):
def on_advance_end(self):
# should be empty
assert not self.outputs
# sanity check
nonlocal did_assert
did_assert = True
super().on_advance_end()

model = TestModel()
# make sure this hook is not overridden
model.test_epoch_end = LightningModule.test_epoch_end

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3)
trainer.test_loop.connect(TestLoop())
trainer.test(model)
assert did_assert