1515from abc import ABC , abstractmethod
1616from collections .abc import Iterable , Iterator
1717from copy import deepcopy
18- from typing import Any , Callable , Generator , List , Optional , Tuple , Union
18+ from typing import Any , Callable , List , Optional , Tuple
1919
2020import torch
2121from torch .utils .data .dataloader import DataLoader
@@ -58,28 +58,28 @@ def fetching_function(self) -> Any:
5858 def prefetching (self ) -> None :
5959 """Override with your own pre-fetching logic."""
6060
61+ def on_fetch_start (self ) -> Any :
62+ """Hook called before a batch is fetched; override to add custom logic (this base implementation does nothing)."""
63+
64+ def on_fetch_end (self , batch : Any , start_output : Any ) -> None :
65+ """Hook called after a batch is fetched; override or extend to handle ``batch`` (this base implementation does nothing). ``start_output`` is presumably the value returned by ``on_fetch_start`` (call site not shown here; confirm)."""
66+
67+ def wait (self ) -> None :
68+ """Hook to override to make the data fetcher wait for an event (this base implementation does nothing)."""
69+
6170 def __init__ (self , prefetch_batches : int = 0 ) -> None :  # validate the prefetch count and initialise fetcher state
6271 if prefetch_batches < 0 :  # negative counts are rejected as a misconfiguration
6372 raise MisconfigurationException ("`prefetch_batches` should at least be 0." )
6473 self .prefetch_batches = prefetch_batches
65-
66- self .dataloader : Optional [Union [DataLoader , CombinedLoader ]] = None
74+ self .dataloader : Optional [Iterable ] = None  # assigned later by setup()
6775 self .dataloader_iter : Optional [Iterator ] = None
68-
69- self .batch_to_device : Optional [Callable ]
70-
71- self .batches : List
72- self .fetched : int
73- self .done : bool
74-
76+ self .batch_to_device : Optional [Callable ] = None  # assigned later by setup()
7577 self .reset ()  # sets self.batches, self.fetched and self.done to their initial values
7678
7779 def setup (self , dataloader : Iterable , batch_to_device : Optional [Callable ] = None ) -> None :  # register the iterable and the optional `batch_to_device` callable on this fetcher
7880 self ._add_capture_metadata_collate (dataloader )
79-
8081 self .dataloader = dataloader  # __iter__ raises MisconfigurationException while this is still None
8182 self .batch_to_device = batch_to_device
82-
8383 self ._attach_data_fetcher ()  # runs after self.dataloader is assigned
8484
8585 @staticmethod
@@ -92,8 +92,8 @@ def _add_capture_metadata_collate(dataloader: Iterable) -> None:
9292
9393 apply_to_collection (dataloader , DataLoader , _add_capture_metadata_collate )
9494
95- def _apply_patch (self ):
96- def _apply_patch_fn (loader : DataLoader , iterator : Iterator ):
95+ def _apply_patch (self ) -> None :
96+ def _apply_patch_fn (loader : DataLoader , iterator : Iterator ) -> None :
9797 if isinstance (loader , CycleIterator ):
9898 loader = loader .loader
9999 # cycle_iterator = iterator
@@ -158,13 +158,13 @@ def loader_iters(self) -> List[Iterator]:
158158
159159 @property
160160 def state (self ) -> Any :  # gathers the `.state` of the iterators held in self.loader_iters
161- def collect_state (iterator : Iterator ):
161+ def collect_state (iterator : Iterator ) -> Any :
162162 return iterator .state
163163
164164 return apply_to_collection (self .loader_iters , Iterator , collect_state )  # collect_state is handed to apply_to_collection together with the Iterator type
165165
166- def _attach_data_fetcher (self ):
167- def _attach_data_fetcher_fn (loader : DataLoader ):
166+ def _attach_data_fetcher (self ) -> None :
167+ def _attach_data_fetcher_fn (loader : DataLoader ) -> None :
168168 if isinstance (loader , CycleIterator ):
169169 loader = loader .loader
170170
@@ -173,7 +173,7 @@ def _attach_data_fetcher_fn(loader: DataLoader):
173173
174174 apply_to_collection (self .loaders , (DataLoader , CycleIterator ), _attach_data_fetcher_fn )
175175
176- def __iter__ (self ) -> Generator [ Tuple [ Any , bool ], None , None ] :
176+ def __iter__ (self ) -> "AbstractDataFetcher" :
177177 if self .dataloader is None :
178178 raise MisconfigurationException ("The iterate hasn't been provided. HINT: Did you call setup function ?." )
179179 self .reset ()
@@ -184,11 +184,11 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
184184 self .prefetching ()
185185 return self
186186
187- def __next__ (self ):
187+ def __next__ (self ) -> Any :
188188 return self .fetching_function ()  # each next() call is delegated to fetching_function()
189189
190190 def reset (self ) -> None :  # restore the per-iteration state; called from __init__ and from __iter__
191- self .batches : List = []
191+ self .batches : List [ Any ] = []  # fetched batches are appended here by the on_fetch_end overrides
192192 self .fetched : int = 0
193193 self .done : bool = False
194194
@@ -217,18 +217,13 @@ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> N
217217 super ().__init__ (prefetch_batches = prefetch_batches )
218218 self .store_on_device = store_on_device
219219
220- def on_fetch_start (self ) -> None :
221- """Hook to override to handle the logic before fetching a batch."""
222-
223- def on_fetch_end (self , batch , on_fetch_start_output : Optional [Any ] = None ) -> None :
220+ def on_fetch_end (self , batch : Any , start_output : Any ) -> None :
224221 """Hook run after a batch is fetched: appends ``batch`` to ``self.batches``; ``start_output`` is not used here."""
225222 self .batches .append (batch )
226223
227- def wait (self ) -> None :
228- """Hook to override to indicate the `DataFetcher` to wait for an event."""
229-
230224 def prefetching (self ) -> None :
231225 iterator = self .dataloader_iter
226+ assert iterator is not None
232227 for _ in range (self .prefetch_batches ):
233228 try :
234229 self ._fetch_next_batch (iterator )
@@ -282,23 +277,21 @@ class InterBatchParallelDataFetcher(DataFetcher):
282277 batch 2: [HtoD] [forward][backward]
283278 """
284279
285- def __init__ (self , * args , ** kwargs ) -> None :
280+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
286281 super ().__init__ (* args , ** kwargs )  # all arguments are forwarded unchanged to the parent constructor
287282 self .cuda_stream = torch .cuda .Stream ()  # dedicated stream; move_to_device() runs the parent's transfer inside it
288283 self .events : List [torch .cuda .Event ] = []  # on_fetch_end() appends one event per fetched batch
289284
290- def move_to_device (self , batch ) :
285+ def move_to_device (self , batch : Any ) -> Any :
291286 with torch .cuda .stream (self .cuda_stream ):  # make self.cuda_stream the current CUDA stream for the duration of the call
292287 return super ().move_to_device (batch )  # the move itself is done by the parent implementation
293288
294289 def on_fetch_start (self ) -> "torch.cuda.Event" :
295290 # create a fresh CUDA event; on_fetch_end() records it and stores it in self.events.
296291 return torch .cuda .Event ()
297292
298- def on_fetch_end (self , batch , event : torch .cuda .Event ) -> None :
299- super ().on_fetch_end (batch )
300-
301- # record event and store the event
293+ def on_fetch_end (self , batch : Any , event : torch .cuda .Event ) -> None :
294+ self .batches .append (batch )  # store the fetched batch
302295 event .record ()  # NOTE(review): record() without a stream argument uses the current stream, not self.cuda_stream; confirm this is intended
303296 self .events .append (event )  # keep the recorded event; presumably consumed by wait() (its body is only partly visible here)
304297
@@ -308,7 +301,7 @@ def wait(self) -> None:
308301 event .wait ()
309302
310303
311- class StepFuncDataLoaderIter :
304+ class StepFuncDataLoaderIter ( Iterator ) :
312305
313306 """This class is a wrapper to keep track of dataloader iterator fetching event while left entirely to user
314307 control."""
@@ -317,9 +310,6 @@ def __init__(self, iterator: Iterator, data_fetcher: "AbstractDataFetcher"):
317310 self .iterator = iterator
318311 self .data_fetcher = data_fetcher
319312
320- def __iter__ (self ) -> "StepFuncDataLoaderIter" :
321- return self
322-
323313 def __next__ (self ) -> Any :
324314 try :
325315 data = next (self .iterator )
@@ -349,12 +339,14 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
349339 ...
350340 """
351341
352- def __init__ (self ):
342+ def __init__ (self ) -> None :
353343 super ().__init__ ()  # parent constructor is called with its default arguments
354344 self .store_on_device = False  # presumably overrides the value set by the parent constructor (parent class not visible here; confirm)
355345
356346 def prefetching (self ) -> None :
357- self .iterator = iter (StepFuncDataLoaderIter (self .dataloader_iter , self ))
347+ iterator = self .dataloader_iter
348+ assert iterator is not None  # narrows Optional[Iterator] for the type checker; note that asserts are stripped under `python -O`
349+ self .iterator = iter (StepFuncDataLoaderIter (iterator , self ))  # wrap the raw iterator, passing this fetcher to StepFuncDataLoaderIter
358350
359351 def fetching_function (self ) -> Tuple [int , Tuple [Iterator , bool ]]:
360352 if not self .done :
0 commit comments