
Commit 373c32e

Fix yielding from iterator in LiteDataLoader (#10304)
* Fix yielding from iterator
* Update description
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Remove unused code

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f44b8a7 commit 373c32e

File tree

2 files changed, +24 -14 lines


pytorch_lightning/lite/wrappers.py

Lines changed: 3 additions & 3 deletions
@@ -174,9 +174,9 @@ def device(self) -> Optional[torch.device]:
         return self._device
 
     def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
-        dataloader_iter = iter(self._dataloader)
+        iterator = iter(self._dataloader)
         if self._device is None:
-            return dataloader_iter
+            yield from iterator
 
-        for item in dataloader_iter:
+        for item in iterator:
             yield move_data_to_device(item, self._device)
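
Why the one-line change matters: because __iter__ also contains yield statements, Python treats the whole method as a generator function, so the old "return dataloader_iter" never handed the underlying iterator back to the caller; it simply ended the generator, and iterating a _LiteDataLoader with no device produced nothing. Below is a minimal standalone sketch (not part of the commit; the function names are made up for illustration) reproducing the pitfall and the yield-from fix:

# Sketch only: illustrates the generator pitfall fixed in this commit.
def broken_iter(data, device=None):
    iterator = iter(data)
    if device is None:
        # Because this function also contains `yield`, it is a generator
        # function; `return` here just stops the generator immediately,
        # so the caller receives an empty iteration.
        return iterator
    for item in iterator:
        yield item  # device transfer omitted in this sketch

def fixed_iter(data, device=None):
    iterator = iter(data)
    if device is None:
        # Delegate to the underlying iterator instead of returning it.
        yield from iterator
    for item in iterator:
        yield item  # already exhausted here when device is None

assert list(broken_iter([1, 2, 3])) == []        # bug: nothing is yielded
assert list(fixed_iter([1, 2, 3])) == [1, 2, 3]  # fix: items pass through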

tests/lite/test_wrappers.py

Lines changed: 21 additions & 11 deletions
@@ -60,6 +60,27 @@ def check_autocast(forward_input):
     assert out.dtype == torch.get_default_dtype()
 
 
+def test_lite_dataloader_iterator():
+    """Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic
+    device placement)."""
+    dataloader = DataLoader(range(5), batch_size=2)
+    lite_dataloader = _LiteDataLoader(dataloader)
+    assert len(lite_dataloader) == len(dataloader) == 3
+
+    iterator = iter(dataloader)
+    lite_iterator = iter(lite_dataloader)
+
+    assert torch.equal(next(iterator), next(lite_iterator))
+    assert torch.equal(next(iterator), next(lite_iterator))
+    assert torch.equal(next(iterator), next(lite_iterator))
+
+    with pytest.raises(StopIteration):
+        next(iterator)
+
+    with pytest.raises(StopIteration):
+        next(lite_iterator)
+
+
 @pytest.mark.parametrize(
     "src_device, dest_device",
     [

@@ -84,17 +105,6 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
     batch1 = next(iterator)
     assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))
 
-    with pytest.raises(StopIteration):
-        batch1 = next(iterator)
-
-    lite_dataloader = _LiteDataLoader(dataloader=[sample0, sample1, sample2, sample3], device=dest_device)
-    iterator = iter(lite_dataloader)
-
-    batch0 = next(iterator)
-    assert batch0 == 0
-
-    assert len(lite_dataloader) == 4
-
 
 def test_lite_optimizer_wraps():
     """Test that the LiteOptimizer fully wraps the optimizer."""
