Skip to content

Commit bbd364a

Browse files
authored
Simplify fetching's loader types (#13111)
1 parent d786985 commit bbd364a

File tree

3 files changed

+13
-29
lines changed

3 files changed

+13
-29
lines changed

src/pytorch_lightning/utilities/fetching.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,23 +144,18 @@ def _store_dataloader_iter_state(
144144
dataloader_iter.state.update(iter_name, state)
145145

146146
@property
147-
def loaders(self) -> List[DataLoader]:
147+
def loaders(self) -> Any:
148148
if isinstance(self.dataloader, CombinedLoader):
149-
loaders = self.dataloader.loaders
150-
else:
151-
loaders = [self.dataloader]
152-
return loaders
149+
return self.dataloader.loaders
150+
return self.dataloader
153151

154152
@property
155-
def loader_iters(self) -> List[Iterator]:
153+
def loader_iters(self) -> Any:
156154
if self.dataloader_iter is None:
157155
raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.")
158-
159156
if isinstance(self.dataloader, CombinedLoader):
160-
loader_iters = self.dataloader_iter.loader_iters
161-
else:
162-
loader_iters = [self.dataloader_iter]
163-
return loader_iters
157+
return self.dataloader_iter.loader_iters
158+
return self.dataloader_iter
164159

165160
@property
166161
def state(self) -> List[MergedIteratorState]:

tests/tests_pytorch/utilities/test_auto_restart.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -700,16 +700,8 @@ def __len__(self):
700700
return self.len
701701

702702

703-
# TODO: test with `RandomGeneratorGetItemDataset`
704703
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
705-
@pytest.mark.parametrize(
706-
"dataset_class",
707-
[
708-
SequentialGetItemDataset,
709-
RandomGetItemDataset,
710-
# RandomGeneratorGetItemDataset,
711-
],
712-
)
704+
@pytest.mark.parametrize("dataset_class", [SequentialGetItemDataset, RandomGetItemDataset])
713705
@pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))])
714706
@pytest.mark.parametrize("batch_size", [1, 2, 3])
715707
def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size):
@@ -732,12 +724,11 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched):
732724
_ = next(prefetch_iter)
733725

734726
state: List[MergedIteratorState] = fetcher.state
735-
assert len(state) == 1
736-
assert isinstance(state[0], MergedIteratorState)
727+
assert isinstance(state, MergedIteratorState)
737728

738729
assert len(fetcher.dataloader_iter.cache_states) == 1
739730
if num_workers == 0:
740-
assert state[0].state[0].num_batches_fetched == num_batches_fetched
731+
assert state.state[0].num_batches_fetched == num_batches_fetched
741732
return state
742733

743734
dataset, random_sampler = create_dataset_sampler()
@@ -754,7 +745,7 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched):
754745

755746
# (A) capture the state after fetching 4 batches
756747
state = fetch(fetcher, prefetch_iter, 4)
757-
state = deepcopy(state[0])
748+
state = deepcopy(state)
758749

759750
# (B) simulate 2 additional batches
760751
batch05 = next(prefetch_iter)

tests/tests_pytorch/utilities/test_fetching.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,14 @@ def test_empty_prefetch_iterator(dataset_cls, prefetch_batches):
101101

102102

103103
def test_misconfiguration_error():
104-
105104
fetcher = DataFetcher()
105+
loader = DataLoader(range(10))
106+
fetcher.setup(loader)
107+
assert fetcher.loaders == loader
106108
with pytest.raises(
107109
MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context."
108110
):
109-
loader = DataLoader(range(10))
110-
fetcher.setup(loader)
111-
assert fetcher.loaders[0] == loader
112111
fetcher.loader_iters
113-
114112
iter(fetcher)
115113
assert fetcher.loader_iters
116114

0 commit comments

Comments
 (0)