@@ -228,44 +228,33 @@ def wait(self) -> None:
228228 """Hook to override to indicate the `DataFetcher` to wait for an event."""
229229
230230 def prefetching (self ) -> None :
231+ iterator = self .dataloader_iter
231232 for _ in range (self .prefetch_batches ):
232233 try :
233- self ._fetch_next_batch ()
234+ self ._fetch_next_batch (iterator )
234235 except StopIteration :
235236 break
236237
237- def fetching_function (self ) -> Optional [Tuple [Any , bool ]]:
238- if self .done :
239- while self .batches :
240- return self ._get_queued_batch ()
241- raise StopIteration
238+ def fetching_function (self ) -> Tuple [Any , bool ]:
239+ if self .batches :
240+ batch = self .batches .pop (0 )
242241 else :
242+ # empty iterator, no prefetching done
243+ raise StopIteration
244+ if not self .done :
245+ assert self .dataloader_iter is not None
243246 try :
244- yield_batch = self .batches .pop (0 )
245- self ._fetch_next_batch ()
246- # wait for batch to be available.
247- self .wait ()
248- # yield last and has next
249- return self .move_to_device (yield_batch ), False
247+ self ._fetch_next_batch (self .dataloader_iter )
250248 except StopIteration :
251- self .batches .insert (0 , yield_batch )
252249 self .done = True
253- return self ._get_queued_batch ()
254-
255- except IndexError :
256- raise StopIteration
250+ self .wait ()
251+ return self .move_to_device (batch ), len (self .batches ) == 0
257252
258- def _fetch_next_batch (self ) :
259- data = self .on_fetch_start ()
260- batch = next (self . dataloader_iter )
253+ def _fetch_next_batch (self , iterator : Iterator ) -> None :
254+ start_output = self .on_fetch_start ()
255+ batch = next (iterator )
261256 self .fetched += 1
262- self .on_fetch_end (batch , data )
263-
264- def _get_queued_batch (self ) -> Tuple [Any , bool ]:
265- batch = self .batches .pop (0 )
266- is_last = len (self .batches ) == 0
267- self .wait ()
268- return self .move_to_device (batch ), is_last
257+ self .on_fetch_end (batch , start_output )
269258
270259 def move_to_device (self , batch : Any ) -> Any :
271260 if self .store_on_device and self .batch_to_device is not None :
@@ -367,7 +356,7 @@ def __init__(self):
367356 def prefetching (self ) -> None :
368357 self .iterator = iter (StepFuncDataLoaderIter (self .dataloader_iter , self ))
369358
370- def fetching_function (self ):
371- while not self .done :
359+ def fetching_function (self ) -> Tuple [ int , Tuple [ Iterator , bool ]] :
360+ if not self .done :
372361 return self .fetched , (self .iterator , self .done )
373362 raise StopIteration
0 commit comments