From 376753a40fba6624e1698c6a218347712e103750 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 19 May 2022 21:24:13 +0200
Subject: [PATCH 1/5] Manually reset sized DALI iterators

---
 pytorch_lightning/utilities/fetching.py | 26 ++++++------
 pytorch_lightning/utilities/imports.py  |  1 +
 tests/utilities/test_fetching.py        | 53 +++++++++++++++++++++++++
 3 files changed, 67 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index ff7e6080bad7b..6e08e389e5e5e 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -32,7 +32,7 @@
 )
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_training
+from pytorch_lightning.utilities.imports import _fault_tolerant_training, _NVIDIA_DALI_AVAILABLE


 def _profile_nothing() -> None:
@@ -144,23 +144,18 @@ def _store_dataloader_iter_state(
         dataloader_iter.state.update(iter_name, state)

     @property
-    def loaders(self) -> List[DataLoader]:
+    def loaders(self) -> Any:
         if isinstance(self.dataloader, CombinedLoader):
-            loaders = self.dataloader.loaders
-        else:
-            loaders = [self.dataloader]
-        return loaders
+            return self.dataloader.loaders
+        return self.dataloader

     @property
-    def loader_iters(self) -> List[Iterator]:
+    def loader_iters(self) -> Any:
         if self.dataloader_iter is None:
             raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.")
-
         if isinstance(self.dataloader, CombinedLoader):
-            loader_iters = self.dataloader_iter.loader_iters
-        else:
-            loader_iters = [self.dataloader_iter]
-        return loader_iters
+            return self.dataloader_iter.loader_iters
+        return self.dataloader_iter

     @property
     def state(self) -> List[MergedIteratorState]:
@@ -286,7 +281,12 @@ def _fetch_next_batch(self, iterator: Iterator) -> None:
             # when we don't prefetch but the dataloader is sized, we use the length for `done`
             dataloader = self.dataloader
             assert isinstance(dataloader, Sized)  # `_has_len` is True
-            self.done = self.fetched >= len(dataloader)
+            if self.fetched >= len(dataloader):
+                self.done = True
+                if _NVIDIA_DALI_AVAILABLE:
+                    from nvidia.dali.plugin.pytorch import DALIGenericIterator
+
+                    apply_to_collection(self.loaders, DALIGenericIterator, DALIGenericIterator.reset)
         self.on_fetch_end(batch, start_output)

     def move_to_device(self, batch: Any) -> Any:
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index 83850f8e7415a..68311ea33d1d2 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -149,6 +149,7 @@ def __repr__(self) -> str:
 _KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available()
 _NEPTUNE_AVAILABLE = _package_available("neptune")
 _NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0")
+_NVIDIA_DALI_AVAILABLE = _RequirementAvailable("nvidia-dali-cuda110")
 _OMEGACONF_AVAILABLE = _package_available("omegaconf")
 _POPTORCH_AVAILABLE = _package_available("poptorch")
 _PSUTIL_AVAILABLE = _package_available("psutil")
diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py
index b536a61036f7b..84a0eb4808a28 100644
--- a/tests/utilities/test_fetching.py
+++ b/tests/utilities/test_fetching.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # 
limitations under the License. import os +import sys from time import time from typing import Any, Iterator from unittest import mock +from unittest.mock import Mock import pytest import torch @@ -572,3 +574,54 @@ def training_step(self, dataloader_iter): durations = profiler.recorded_durations[key] assert len(durations) == 2 # 2 polls in training_step assert all(d > 0 for d in durations) + + +@pytest.mark.parametrize("prefetch_batches", (0, 1)) +def test_dali_like_iterator(monkeypatch, prefetch_batches): + class MockDALIGenericIterator(Iterator): + def __init__(self, size: int): + self.counter = 0 + self.size = size + + def __next__(self): + if self.counter >= self.size: + # the DALI iterator needs one extra fetch call to reset the state and stop + self.reset() + raise StopIteration + self.counter += 1 + return self.counter # can be anything really + + def reset(self): + self.counter = 0 + + def __len__(self): + return self.size + + import pytorch_lightning.utilities.fetching as fetching + + # the batches without any of our fetching logic + expected = [(e, b) for e in range(3) for b in MockDALIGenericIterator(3)] + + nvidia_dali_module = "nvidia.dali.plugin.pytorch" + # need different mocking logic based on whether the library is installed + if fetching._NVIDIA_DALI_AVAILABLE: + patch = mock.patch(f"{nvidia_dali_module}.DALIGenericIterator", new=MockDALIGenericIterator) + patch.__enter__() + cleanup_modules = False + else: + monkeypatch.setattr(fetching, "_NVIDIA_DALI_AVAILABLE", True) + nvidia_mock = Mock() + nvidia_mock.DALIGenericIterator = MockDALIGenericIterator + sys.modules[nvidia_dali_module] = nvidia_mock + cleanup_modules = True + + fetcher = DataFetcher(prefetch_batches) + fetcher.setup(MockDALIGenericIterator(3)) + actual = [(e, b) for e in range(3) for b in fetcher] + + assert actual == expected + + if cleanup_modules: + del sys.modules[nvidia_dali_module] + else: + patch.__exit__(None, None, None) From 8b532b17a09bd4228e8b99eedd9739640cdb5f50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 19 May 2022 21:30:39 +0200 Subject: [PATCH 2/5] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a35fdd2a63d1..8b6066c7dc126 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -223,6 +223,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Avoid redundant callback restore warning while tuning ([#13026](https://github.com/PyTorchLightning/pytorch-lightning/pull/13026)) +- Fixed nvidia-DALI iterators skipping every second epoch entirely ([#13111](https://github.com/PyTorchLightning/pytorch-lightning/pull/13111)) + + - Fixed an issue wrt unnecessary usage of habana mixed precision package for fp32 types ([#13028](https://github.com/PyTorchLightning/pytorch-lightning/pull/13028)) From 8c03a5ae57b32d31aa0d29696a67f331567b2180 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 19 May 2022 21:36:38 +0200 Subject: [PATCH 3/5] Fix broken test --- tests/utilities/test_auto_restart.py | 17 ++++------------- tests/utilities/test_fetching.py | 8 +++----- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 2849aee531f18..eb5702422f4f9 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -751,16 +751,8 @@ def __len__(self): return self.len -# TODO: test with `RandomGeneratorGetItemDataset` @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@pytest.mark.parametrize( - "dataset_class", - [ - SequentialGetItemDataset, - RandomGetItemDataset, - # RandomGeneratorGetItemDataset, - ], -) +@pytest.mark.parametrize("dataset_class", [SequentialGetItemDataset, RandomGetItemDataset]) @pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))]) @pytest.mark.parametrize("batch_size", [1, 2, 3]) def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size): @@ -783,12 +775,11 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched): _ = next(prefetch_iter) state: List[MergedIteratorState] = fetcher.state - assert len(state) == 1 - assert isinstance(state[0], MergedIteratorState) + assert isinstance(state, MergedIteratorState) assert len(fetcher.dataloader_iter.cache_states) == 1 if num_workers == 0: - assert state[0].state[0].num_batches_fetched == num_batches_fetched + assert state.state[0].num_batches_fetched == num_batches_fetched return state dataset, random_sampler = create_dataset_sampler() @@ -805,7 +796,7 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched): # (A) capture the state after fetching 4 batches state = fetch(fetcher, prefetch_iter, 4) - state = deepcopy(state[0]) + state = deepcopy(state) # (B) simulate 2 additional batches batch05 = next(prefetch_iter) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 84a0eb4808a28..4919a39db907c 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -103,16 +103,14 @@ def test_empty_prefetch_iterator(dataset_cls, prefetch_batches): def test_misconfiguration_error(): - fetcher = DataFetcher() + loader = DataLoader(range(10)) + fetcher.setup(loader) + assert fetcher.loaders == loader with pytest.raises( MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context." 
): - loader = DataLoader(range(10)) - fetcher.setup(loader) - assert fetcher.loaders[0] == loader fetcher.loader_iters - iter(fetcher) assert fetcher.loader_iters From 15e9e157030060fb4937b4a4fd7c7e3cb6a80137 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 13 Jul 2022 16:42:19 +0200 Subject: [PATCH 4/5] Remove DALI code --- src/pytorch_lightning/utilities/fetching.py | 9 +--- src/pytorch_lightning/utilities/imports.py | 1 - .../tests_pytorch/utilities/test_fetching.py | 51 ------------------- 3 files changed, 2 insertions(+), 59 deletions(-) diff --git a/src/pytorch_lightning/utilities/fetching.py b/src/pytorch_lightning/utilities/fetching.py index 6e08e389e5e5e..a4518e147da02 100644 --- a/src/pytorch_lightning/utilities/fetching.py +++ b/src/pytorch_lightning/utilities/fetching.py @@ -32,7 +32,7 @@ ) from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_training, _NVIDIA_DALI_AVAILABLE +from pytorch_lightning.utilities.imports import _fault_tolerant_training def _profile_nothing() -> None: @@ -281,12 +281,7 @@ def _fetch_next_batch(self, iterator: Iterator) -> None: # when we don't prefetch but the dataloader is sized, we use the length for `done` dataloader = self.dataloader assert isinstance(dataloader, Sized) # `_has_len` is True - if self.fetched >= len(dataloader): - self.done = True - if _NVIDIA_DALI_AVAILABLE: - from nvidia.dali.plugin.pytorch import DALIGenericIterator - - apply_to_collection(self.loaders, DALIGenericIterator, DALIGenericIterator.reset) + self.done = self.fetched >= len(dataloader) self.on_fetch_end(batch, start_output) def move_to_device(self, batch: Any) -> Any: diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py index f39a0965c603f..6fbeda8a7c600 100644 --- a/src/pytorch_lightning/utilities/imports.py +++ b/src/pytorch_lightning/utilities/imports.py @@ -144,7 +144,6 @@ def __repr__(self) -> str: _KINETO_AVAILABLE = torch.profiler.kineto_available() _NEPTUNE_AVAILABLE = _package_available("neptune") _NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0") -_NVIDIA_DALI_AVAILABLE = _RequirementAvailable("nvidia-dali-cuda110") _OMEGACONF_AVAILABLE = _package_available("omegaconf") _POPTORCH_AVAILABLE = _package_available("poptorch") _PSUTIL_AVAILABLE = _package_available("psutil") diff --git a/tests/tests_pytorch/utilities/test_fetching.py b/tests/tests_pytorch/utilities/test_fetching.py index 6e01082d3741d..544bb1f791d62 100644 --- a/tests/tests_pytorch/utilities/test_fetching.py +++ b/tests/tests_pytorch/utilities/test_fetching.py @@ -574,54 +574,3 @@ def training_step(self, dataloader_iter): durations = profiler.recorded_durations[key] assert len(durations) == 2 # 2 polls in training_step assert all(d > 0 for d in durations) - - -@pytest.mark.parametrize("prefetch_batches", (0, 1)) -def test_dali_like_iterator(monkeypatch, prefetch_batches): - class MockDALIGenericIterator(Iterator): - def __init__(self, size: int): - self.counter = 0 - self.size = size - - def __next__(self): - if self.counter >= self.size: - # the DALI iterator needs one extra fetch call to reset the state and stop - self.reset() - raise StopIteration - self.counter += 1 - return self.counter # can be anything really - - def reset(self): - self.counter = 0 - - def __len__(self): - return self.size - - import 
pytorch_lightning.utilities.fetching as fetching - - # the batches without any of our fetching logic - expected = [(e, b) for e in range(3) for b in MockDALIGenericIterator(3)] - - nvidia_dali_module = "nvidia.dali.plugin.pytorch" - # need different mocking logic based on whether the library is installed - if fetching._NVIDIA_DALI_AVAILABLE: - patch = mock.patch(f"{nvidia_dali_module}.DALIGenericIterator", new=MockDALIGenericIterator) - patch.__enter__() - cleanup_modules = False - else: - monkeypatch.setattr(fetching, "_NVIDIA_DALI_AVAILABLE", True) - nvidia_mock = Mock() - nvidia_mock.DALIGenericIterator = MockDALIGenericIterator - sys.modules[nvidia_dali_module] = nvidia_mock - cleanup_modules = True - - fetcher = DataFetcher(prefetch_batches) - fetcher.setup(MockDALIGenericIterator(3)) - actual = [(e, b) for e in range(3) for b in fetcher] - - assert actual == expected - - if cleanup_modules: - del sys.modules[nvidia_dali_module] - else: - patch.__exit__(None, None, None) From 4edb65ff0f6e637c394582e98c656dca843e09b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 14 Jul 2022 13:32:44 +0200 Subject: [PATCH 5/5] Unused imports --- tests/tests_pytorch/utilities/test_fetching.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/tests_pytorch/utilities/test_fetching.py b/tests/tests_pytorch/utilities/test_fetching.py index 544bb1f791d62..2d5e3954c7061 100644 --- a/tests/tests_pytorch/utilities/test_fetching.py +++ b/tests/tests_pytorch/utilities/test_fetching.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import sys from time import time from typing import Any, Iterator from unittest import mock -from unittest.mock import Mock import pytest import torch
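
The behaviour targeted by PATCH 1/5 and described in the CHANGELOG entry can be replayed outside Lightning with the same mock iterator used in the test that patch adds (and PATCH 4/5 removes again). The sketch below is illustrative only: `run_epoch` is a simplified stand-in for the non-prefetching `DataFetcher` loop (its `self.done = self.fetched >= len(dataloader)` check), not an API from this series, and the mock mirrors `nvidia.dali.plugin.pytorch.DALIGenericIterator` only as far as the test assumes it (the iterator resets itself solely on the extra fetch that raises `StopIteration`).

from typing import Iterator, List


class MockDALIGenericIterator(Iterator):
    """Sized iterator that, DALI-style, only resets on the extra fetch that raises StopIteration."""

    def __init__(self, size: int) -> None:
        self.counter = 0
        self.size = size

    def __next__(self) -> int:
        if self.counter >= self.size:
            # the reset happens one fetch *after* the last batch was returned
            self.reset()
            raise StopIteration
        self.counter += 1
        return self.counter

    def reset(self) -> None:
        self.counter = 0

    def __len__(self) -> int:
        return self.size


def run_epoch(it: MockDALIGenericIterator) -> List[int]:
    # simplified non-prefetching fetch loop: stop as soon as len(it) batches were fetched,
    # so the extra next() call that would reset the iterator is never issued
    batches: List[int] = []
    while len(batches) < len(it):
        try:
            batches.append(next(it))
        except StopIteration:
            break
    return batches


it = MockDALIGenericIterator(3)
print(run_epoch(it))  # [1, 2, 3]
print(run_epoch(it))  # []        <- the "skipping every second epoch" from the CHANGELOG entry
print(run_epoch(it))  # [1, 2, 3]

it = MockDALIGenericIterator(3)
print(run_epoch(it))  # [1, 2, 3]
it.reset()  # manual reset between epochs; PATCH 1/5 automated this call, PATCH 4/5 removed it again
print(run_epoch(it))  # [1, 2, 3]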