Commit 5ad5ba5

Refactor fetching function (#11516)
1 parent 075b880 commit 5ad5ba5

2 files changed: +20 -34 lines

pytorch_lightning/utilities/fetching.py

Lines changed: 18 additions & 29 deletions
@@ -228,44 +228,33 @@ def wait(self) -> None:
         """Hook to override to indicate the `DataFetcher` to wait for an event."""
 
     def prefetching(self) -> None:
+        iterator = self.dataloader_iter
         for _ in range(self.prefetch_batches):
             try:
-                self._fetch_next_batch()
+                self._fetch_next_batch(iterator)
             except StopIteration:
                 break
 
-    def fetching_function(self) -> Optional[Tuple[Any, bool]]:
-        if self.done:
-            while self.batches:
-                return self._get_queued_batch()
-            raise StopIteration
+    def fetching_function(self) -> Tuple[Any, bool]:
+        if self.batches:
+            batch = self.batches.pop(0)
         else:
+            # empty iterator, no prefetching done
+            raise StopIteration
+        if not self.done:
+            assert self.dataloader_iter is not None
             try:
-                yield_batch = self.batches.pop(0)
-                self._fetch_next_batch()
-                # wait for batch to be available.
-                self.wait()
-                # yield last and has next
-                return self.move_to_device(yield_batch), False
+                self._fetch_next_batch(self.dataloader_iter)
             except StopIteration:
-                self.batches.insert(0, yield_batch)
                 self.done = True
-                return self._get_queued_batch()
-
-            except IndexError:
-                raise StopIteration
+        self.wait()
+        return self.move_to_device(batch), len(self.batches) == 0
 
-    def _fetch_next_batch(self):
-        data = self.on_fetch_start()
-        batch = next(self.dataloader_iter)
+    def _fetch_next_batch(self, iterator: Iterator) -> None:
+        start_output = self.on_fetch_start()
+        batch = next(iterator)
         self.fetched += 1
-        self.on_fetch_end(batch, data)
-
-    def _get_queued_batch(self) -> Tuple[Any, bool]:
-        batch = self.batches.pop(0)
-        is_last = len(self.batches) == 0
-        self.wait()
-        return self.move_to_device(batch), is_last
+        self.on_fetch_end(batch, start_output)
 
     def move_to_device(self, batch: Any) -> Any:
         if self.store_on_device and self.batch_to_device is not None:
@@ -367,7 +356,7 @@ def __init__(self):
     def prefetching(self) -> None:
         self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))
 
-    def fetching_function(self):
-        while not self.done:
+    def fetching_function(self) -> Tuple[int, Tuple[Iterator, bool]]:
+        if not self.done:
             return self.fetched, (self.iterator, self.done)
         raise StopIteration
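
A minimal, self-contained sketch of the control flow this refactor settles on. The `_SketchFetcher` class and its `setup()` helper are simplified stand-ins, not the actual `DataFetcher` API: `prefetching()` fills the `batches` queue, and `fetching_function()` pops a batch, eagerly fetches its replacement so `done` is known up front, and returns `(batch, is_last)`.

from typing import Any, Iterator, List, Optional, Tuple


class _SketchFetcher:
    """Simplified stand-in showing the refactored prefetch/fetch flow."""

    def __init__(self, prefetch_batches: int = 1) -> None:
        self.prefetch_batches = prefetch_batches
        self.batches: List[Any] = []
        self.fetched = 0
        self.done = False
        self.dataloader_iter: Optional[Iterator] = None

    def setup(self, iterable: Any) -> None:
        self.dataloader_iter = iter(iterable)
        # prefetching(): eagerly pull up to `prefetch_batches` items into the queue
        for _ in range(self.prefetch_batches):
            try:
                self._fetch_next_batch(self.dataloader_iter)
            except StopIteration:
                break

    def _fetch_next_batch(self, iterator: Iterator) -> None:
        self.batches.append(next(iterator))
        self.fetched += 1

    def fetching_function(self) -> Tuple[Any, bool]:
        # pop the batch that was prefetched on an earlier step
        if self.batches:
            batch = self.batches.pop(0)
        else:
            # empty iterator, no prefetching done
            raise StopIteration
        # immediately try to fetch its replacement, so `done` (and hence
        # the `is_last` flag) is already correct when this batch is returned
        if not self.done:
            assert self.dataloader_iter is not None
            try:
                self._fetch_next_batch(self.dataloader_iter)
            except StopIteration:
                self.done = True
        return batch, len(self.batches) == 0


fetcher = _SketchFetcher(prefetch_batches=1)
fetcher.setup(range(3))
while True:
    try:
        batch, is_last = fetcher.fetching_function()
    except StopIteration:
        break
    print(batch, is_last)  # 0 False, then 1 False, then 2 True

Compared to the old path, `_get_queued_batch` and the `IndexError` handling are gone; the `is_last` flag simply reflects whether the queue is empty after the eager fetch.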

tests/utilities/test_fetching.py

Lines changed: 2 additions & 5 deletions
@@ -57,11 +57,8 @@ def __iter__(self):
 
     def generate():
         generated = []
-        for idx, data in enumerate(iterator, 1):
-            if iterator.done:
-                assert iterator.fetched == 3
-            else:
-                assert iterator.fetched == (idx + prefetch_batches)
+        for idx, data in enumerate(iterator, prefetch_batches + 1):
+            assert iterator.fetched == 3 if iterator.done else idx
             generated.append(data)
         return generated
 
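For intuition on the simplified assertion: with `prefetch_batches` batches pulled ahead, the fetcher has already fetched one batch more than it has yielded until the underlying iterator runs dry, so starting `enumerate` at `prefetch_batches + 1` makes `fetched == idx` on every step before `done` flips. Below is a small simulation of that bookkeeping; the `trace_fetch_counts` helper is hypothetical, not part of the test suite, and it assumes a 3-batch dataset with `prefetch_batches >= 1`.

def trace_fetch_counts(n_batches: int, prefetch_batches: int) -> None:
    # state after prefetching(): `prefetch_batches` items already fetched
    fetched = min(prefetch_batches, n_batches)
    done = False  # only set once a fetch attempt raises StopIteration
    for idx in range(prefetch_batches + 1, prefetch_batches + n_batches + 1):
        if not done:
            if fetched < n_batches:
                fetched += 1  # fetching_function() fetched the next batch
            else:
                done = True   # next(iterator) raised StopIteration
        # the test's invariant, with the intended grouping made explicit
        assert fetched == (n_batches if done else idx)


for prefetch in (1, 2, 3):
    trace_fetch_counts(n_batches=3, prefetch_batches=prefetch)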
