Merged
17 changes: 6 additions & 11 deletions src/pytorch_lightning/utilities/fetching.py

@@ -144,23 +144,18 @@ def _store_dataloader_iter_state(
         dataloader_iter.state.update(iter_name, state)
 
     @property
-    def loaders(self) -> List[DataLoader]:
+    def loaders(self) -> Any:
         if isinstance(self.dataloader, CombinedLoader):
-            loaders = self.dataloader.loaders
-        else:
-            loaders = [self.dataloader]
-        return loaders
+            return self.dataloader.loaders
+        return self.dataloader
 
     @property
-    def loader_iters(self) -> List[Iterator]:
+    def loader_iters(self) -> Any:
         if self.dataloader_iter is None:
             raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.")
-
         if isinstance(self.dataloader, CombinedLoader):
-            loader_iters = self.dataloader_iter.loader_iters
-        else:
-            loader_iters = [self.dataloader_iter]
-        return loader_iters
+            return self.dataloader_iter.loader_iters
+        return self.dataloader_iter
 
     @property
     def state(self) -> List[MergedIteratorState]:
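Taken together, the two properties stop normalizing everything into a list: a plain dataloader (or its iterator) now comes back as-is, and only a `CombinedLoader` is unwrapped into its inner `.loaders` collection. A minimal sketch of the new access pattern, mirroring the updated `test_misconfiguration_error` below (the import path is an assumption based on this file's location in the source tree):

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.fetching import DataFetcher  # assumed import path

loader = DataLoader(range(10))
fetcher = DataFetcher()
fetcher.setup(loader)

# Previously `loaders` always returned a list, so callers wrote
# `fetcher.loaders[0] == loader`; the single loader now comes back unwrapped.
assert fetcher.loaders == loader

# `loader_iters` follows the same rule once iteration has started;
# before `iter(fetcher)` it raises a MisconfigurationException.
iter(fetcher)
assert fetcher.loader_iters
```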
17 changes: 4 additions & 13 deletions tests/tests_pytorch/utilities/test_auto_restart.py

@@ -700,16 +700,8 @@ def __len__(self):
         return self.len
 
 
-# TODO: test with `RandomGeneratorGetItemDataset`
 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
-@pytest.mark.parametrize(
-    "dataset_class",
-    [
-        SequentialGetItemDataset,
-        RandomGetItemDataset,
-        # RandomGeneratorGetItemDataset,
-    ],
-)
+@pytest.mark.parametrize("dataset_class", [SequentialGetItemDataset, RandomGetItemDataset])
 @pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))])
 @pytest.mark.parametrize("batch_size", [1, 2, 3])
 def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size):
@@ -732,12 +724,11 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched):
             _ = next(prefetch_iter)
 
         state: List[MergedIteratorState] = fetcher.state
-        assert len(state) == 1
-        assert isinstance(state[0], MergedIteratorState)
+        assert isinstance(state, MergedIteratorState)
 
         assert len(fetcher.dataloader_iter.cache_states) == 1
         if num_workers == 0:
-            assert state[0].state[0].num_batches_fetched == num_batches_fetched
+            assert state.state[0].num_batches_fetched == num_batches_fetched
         return state
 
@@ -754,7 +745,7 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched):
 
     # (A) capture the state after fetching 4 batches
     state = fetch(fetcher, prefetch_iter, 4)
-    state = deepcopy(state[0])
+    state = deepcopy(state)
 
     # (B) simulate 2 additional batches
     batch05 = next(prefetch_iter)
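The same unwrapping shows up in fault-tolerant state capture: `fetcher.state` now hands back one `MergedIteratorState` instead of a singleton list, so the `state[0]` indexing disappears at both call sites. A sketch of the before/after contract, using only names that appear in the test above:

```python
# Inside the test's `fetch` helper, after `num_batches_fetched` batches:
state = fetcher.state

# Old contract: a one-element list.
#   assert len(state) == 1
#   assert isinstance(state[0], MergedIteratorState)
# New contract: the state object itself.
assert isinstance(state, MergedIteratorState)

# Worker-level bookkeeping is unchanged: with num_workers == 0 the single
# worker's fetched-batch counter is checked directly on the unwrapped object.
if num_workers == 0:
    assert state.state[0].num_batches_fetched == num_batches_fetched
```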
8 changes: 3 additions & 5 deletions tests/tests_pytorch/utilities/test_fetching.py

@@ -101,16 +101,14 @@ def test_empty_prefetch_iterator(dataset_cls, prefetch_batches):
 
 
 def test_misconfiguration_error():
-
     fetcher = DataFetcher()
+    loader = DataLoader(range(10))
+    fetcher.setup(loader)
+    assert fetcher.loaders == loader
     with pytest.raises(
         MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context."
     ):
-        loader = DataLoader(range(10))
-        fetcher.setup(loader)
-        assert fetcher.loaders[0] == loader
         fetcher.loader_iters
 
     iter(fetcher)
     assert fetcher.loader_iters
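Neither updated test exercises the `CombinedLoader` branch of the new properties. A hedged sketch of what the unwrapping means there, assuming the `CombinedLoader` from `pytorch_lightning.trainer.supporters` (which keeps its inner dataloaders on the `.loaders` attribute the property forwards to):

```python
from torch.utils.data import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader  # assumed import path
from pytorch_lightning.utilities.fetching import DataFetcher  # assumed import path

combined = CombinedLoader({"a": DataLoader(range(10)), "b": DataLoader(range(20))})
fetcher = DataFetcher()
fetcher.setup(combined)

# For a CombinedLoader, `loaders` forwards to `combined.loaders` — the keyed
# collection the CombinedLoader maintains internally — instead of wrapping
# `combined` itself in a one-element list.
assert fetcher.loaders is combined.loaders
assert set(fetcher.loaders) == {"a", "b"}
```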