diff --git a/src/pytorch_lightning/utilities/fetching.py b/src/pytorch_lightning/utilities/fetching.py
index ff7e6080bad7b..a4518e147da02 100644
--- a/src/pytorch_lightning/utilities/fetching.py
+++ b/src/pytorch_lightning/utilities/fetching.py
@@ -144,22 +144,18 @@ def _store_dataloader_iter_state(
         dataloader_iter.state.update(iter_name, state)
 
     @property
-    def loaders(self) -> List[DataLoader]:
+    def loaders(self) -> Any:
         if isinstance(self.dataloader, CombinedLoader):
-            loaders = self.dataloader.loaders
-        else:
-            loaders = [self.dataloader]
-        return loaders
+            return self.dataloader.loaders
+        return self.dataloader
 
     @property
-    def loader_iters(self) -> List[Iterator]:
+    def loader_iters(self) -> Any:
         if self.dataloader_iter is None:
             raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.")
         if isinstance(self.dataloader, CombinedLoader):
-            loader_iters = self.dataloader_iter.loader_iters
-        else:
-            loader_iters = [self.dataloader_iter]
-        return loader_iters
+            return self.dataloader_iter.loader_iters
+        return self.dataloader_iter
 
     @property
     def state(self) -> List[MergedIteratorState]:
diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py
index 47051d4efd098..5a5982ad009f9 100644
--- a/tests/tests_pytorch/utilities/test_auto_restart.py
+++ b/tests/tests_pytorch/utilities/test_auto_restart.py
@@ -700,16 +700,8 @@ def __len__(self):
         return self.len
 
 
-# TODO: test with `RandomGeneratorGetItemDataset`
 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
-@pytest.mark.parametrize(
-    "dataset_class",
-    [
-        SequentialGetItemDataset,
-        RandomGetItemDataset,
-        # RandomGeneratorGetItemDataset,
-    ],
-)
+@pytest.mark.parametrize("dataset_class", [SequentialGetItemDataset, RandomGetItemDataset])
 @pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))])
 @pytest.mark.parametrize("batch_size", [1, 2, 3])
 def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size):
@@ -732,12 +724,11 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched):
             _ = next(prefetch_iter)
 
         state: List[MergedIteratorState] = fetcher.state
-        assert len(state) == 1
-        assert isinstance(state[0], MergedIteratorState)
+        assert isinstance(state, MergedIteratorState)
 
         assert len(fetcher.dataloader_iter.cache_states) == 1
         if num_workers == 0:
-            assert state[0].state[0].num_batches_fetched == num_batches_fetched
+            assert state.state[0].num_batches_fetched == num_batches_fetched
         return state
 
     dataset, random_sampler = create_dataset_sampler()
@@ -754,7 +745,7 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched):
 
     # (A) capture the state after fetching 4 batches
     state = fetch(fetcher, prefetch_iter, 4)
-    state = deepcopy(state[0])
+    state = deepcopy(state)
 
     # (B) simulate 2 additional batches
     batch05 = next(prefetch_iter)
diff --git a/tests/tests_pytorch/utilities/test_fetching.py b/tests/tests_pytorch/utilities/test_fetching.py
index e9ab01387f7f6..2d5e3954c7061 100644
--- a/tests/tests_pytorch/utilities/test_fetching.py
+++ b/tests/tests_pytorch/utilities/test_fetching.py
@@ -101,13 +101,13 @@ def test_empty_prefetch_iterator(dataset_cls, prefetch_batches):
 
 
 def test_misconfiguration_error():
     fetcher = DataFetcher()
+    loader = DataLoader(range(10))
+    fetcher.setup(loader)
+    assert fetcher.loaders == loader
     with pytest.raises(
         MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context."
     ):
-        loader = DataLoader(range(10))
-        fetcher.setup(loader)
-        assert fetcher.loaders[0] == loader
         fetcher.loader_iters
     iter(fetcher)
     assert fetcher.loader_iters