@@ -15,7 +15,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Iterator
 from copy import deepcopy
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Sized, Tuple
 
 import torch
 from torch.utils.data.dataloader import DataLoader
@@ -30,6 +30,7 @@
     MergedIteratorState,
     patch_dataloader_iterator,
 )
+from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 
@@ -79,6 +80,8 @@ def __init__(self, prefetch_batches: int = 0) -> None:
     def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
         self._add_capture_metadata_collate(dataloader)
         self._dataloader = dataloader
+        _patch_dataloader_get_iterators()
+        self._attach_data_fetcher()
 
     @property
     def dataloader(self) -> Iterable:
@@ -172,8 +175,6 @@ def _attach_data_fetcher_fn(loader: DataLoader) -> None:
 
     def __iter__(self) -> "AbstractDataFetcher":
         self.reset()
-        self._attach_data_fetcher()
-        _patch_dataloader_get_iterators()
         self.dataloader_iter = iter(self.dataloader)
         self._apply_patch()
         self.prefetching()
@@ -205,7 +206,7 @@ class DataFetcher(AbstractDataFetcher):
 
     Args:
         prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
-            whether a batch is the last one (available with :attr:`self.done`).
+            whether a batch is the last one (available with :attr:`self.done`) under any training setup.
         store_on_device: Whether to store the pre-fetched batches on device.
     """
 
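The docstring above captures the look-ahead trick: by keeping at least one batch in flight, the fetcher learns that a batch is the last one at the moment the *next* fetch fails. Below is a small standalone sketch of that idea (written for this note, not Lightning code; the class and names are illustrative only). It also shows the corner case handled further down in this diff, where the prefetch window is larger than the number of available batches.

# Standalone sketch, not the Lightning implementation: a look-ahead buffer of
# `prefetch_batches` elements lets `done` become True while the last real batch
# is still being yielded.
from typing import Any, Iterator, List


class LookaheadIterator:
    def __init__(self, iterable, prefetch_batches: int = 1) -> None:
        self.iterable = iterable
        self.prefetch_batches = prefetch_batches
        self.done = False
        self._buffer: List[Any] = []

    def __iter__(self) -> "LookaheadIterator":
        self._iterator: Iterator = iter(self.iterable)
        self.done = False
        self._buffer = []
        for _ in range(self.prefetch_batches):
            try:
                self._buffer.append(next(self._iterator))
            except StopIteration:
                # fewer batches than the prefetch window: we already know we are done
                self.done = True
                break
        return self

    def __next__(self) -> Any:
        if self._buffer:
            batch = self._buffer.pop(0)
            try:
                # refill the look-ahead slot; if this fails, `batch` was the last one
                self._buffer.append(next(self._iterator))
            except StopIteration:
                self.done = not self._buffer
            return batch
        raise StopIteration


fetcher = LookaheadIterator(range(3), prefetch_batches=1)
for batch in fetcher:
    print(batch, fetcher.done)  # 0 False, 1 False, 2 True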
@@ -214,11 +215,13 @@ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
         self.store_on_device = store_on_device
         self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
         self.batches: List[Any] = []
+        self._has_len = False
 
     def setup(  # type: ignore[override]
         self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None
     ) -> None:
         super().setup(dataloader)
+        self._has_len = has_len(dataloader)
         if batch_to_device is not None:
             self.batch_to_device = batch_to_device
 
@@ -233,6 +236,9 @@ def prefetching(self) -> None:
             try:
                 self._fetch_next_batch(iterator)
             except StopIteration:
+                # this would only happen when prefetch_batches > the number of batches available and makes
+                # `fetching_function` jump directly to the empty iterator case without trying to fetch again
+                self.done = True
                 break
 
     def fetching_function(self) -> Any:
@@ -266,6 +272,11 @@ def _fetch_next_batch(self, iterator: Iterator) -> None:
         start_output = self.on_fetch_start()
         batch = next(iterator)
         self.fetched += 1
+        if not self.prefetch_batches and self._has_len:
+            # when we don't prefetch but the dataloader is sized, we use the length for `done`
+            dataloader = self.dataloader
+            assert isinstance(dataloader, Sized)  # `_has_len` is True
+            self.done = self.fetched >= len(dataloader)
         self.on_fetch_end(batch, start_output)
 
     def move_to_device(self, batch: Any) -> Any:
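The sized-loader branch added above boils down to comparing a running fetch counter with `len(dataloader)`. The same idea in isolation (an illustrative helper written for this note, not part of Lightning; the only assumption is that the loader implements `__len__`):

from typing import Iterable, Iterator, Sized, Tuple


def iterate_with_done(dataloader: Iterable) -> Iterator[Tuple[object, bool]]:
    # Yield (batch, done) pairs: without prefetching, `done` can only be derived
    # from the loader's length, so a sized loader is required.
    assert isinstance(dataloader, Sized)
    total = len(dataloader)
    fetched = 0
    for batch in dataloader:
        fetched += 1
        yield batch, fetched >= total


for batch, done in iterate_with_done([10, 20, 30]):
    print(batch, done)  # 10 False, 20 False, 30 True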
@@ -360,7 +371,8 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
             ...
     """
 
-    def __init__(self) -> None:
+    def __init__(self, prefetch_batches: int = 0) -> None:
+        # prefetch batches is not used for this class
        super().__init__()
         self.store_on_device = False
 
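Taken together, the two branches let the fetcher report `done` both with and without prefetching. A usage sketch built from the names visible in this diff (`DataFetcher`, `prefetch_batches`, `setup`, `done`); the module path and the exact iteration pattern are assumptions rather than something this diff guarantees:

import torch
from torch.utils.data import DataLoader, TensorDataset

# module path inferred from the file being changed in this diff
from pytorch_lightning.utilities.fetching import DataFetcher

dataset = TensorDataset(torch.arange(8.0).unsqueeze(1))
loader = DataLoader(dataset, batch_size=4)  # 2 batches; DataLoader implements __len__

fetcher = DataFetcher(prefetch_batches=0)  # no prefetching: `done` comes from len(loader)
fetcher.setup(loader)
for batch in fetcher:
    print(fetcher.done)  # expected: False for the first batch, True for the last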