 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections.abc import Sized
 from dataclasses import asdict, dataclass, field
 from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Union

@@ -53,23 +54,24 @@ class TensorRunningAccum:

     def __init__(self, window_length: int):
         self.window_length = window_length
-        self.memory = None
-        self.current_idx: int = 0
-        self.last_idx: Optional[int] = None
-        self.rotated: bool = False
+        self.reset(window_length)

     def reset(self, window_length: Optional[int] = None) -> None:
         """Empty the accumulator."""
-        if window_length is None:
-            window_length = self.window_length
-        self.__init__(window_length)
+        if window_length is not None:
+            self.window_length = window_length
+        self.memory: Optional[torch.Tensor] = None
+        self.current_idx: int = 0
+        self.last_idx: Optional[int] = None
+        self.rotated: bool = False

-    def last(self):
+    def last(self) -> Optional[torch.Tensor]:
         """Get the last added element."""
         if self.last_idx is not None:
+            assert isinstance(self.memory, torch.Tensor)
             return self.memory[self.last_idx].float()

-    def append(self, x):
+    def append(self, x: torch.Tensor) -> None:
         """Add an element to the accumulator."""
         if self.memory is None:
             # tradeoff memory for speed by keeping the memory on device
@@ -88,20 +90,21 @@ def append(self, x):
         if self.current_idx == 0:
             self.rotated = True

-    def mean(self):
+    def mean(self) -> Optional[torch.Tensor]:
         """Get mean value from stored elements."""
         return self._agg_memory("mean")

-    def max(self):
+    def max(self) -> Optional[torch.Tensor]:
         """Get maximal value from stored elements."""
         return self._agg_memory("max")

-    def min(self):
+    def min(self) -> Optional[torch.Tensor]:
         """Get minimal value from stored elements."""
         return self._agg_memory("min")

-    def _agg_memory(self, how: str):
+    def _agg_memory(self, how: str) -> Optional[torch.Tensor]:
         if self.last_idx is not None:
+            assert isinstance(self.memory, torch.Tensor)
             if self.rotated:
                 return getattr(self.memory.float(), how)()
             return getattr(self.memory[: self.current_idx].float(), how)()
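
For context on the hunks above: `__init__` now delegates to `reset`, so construction and reset share one initialization path for the ring buffer. A minimal usage sketch, assuming the class lives at `pytorch_lightning.trainer.supporters` (the module this diff appears to touch):

import torch

from pytorch_lightning.trainer.supporters import TensorRunningAccum

acc = TensorRunningAccum(window_length=3)
for v in range(5):
    acc.append(torch.tensor(float(v)))  # ring buffer: oldest entries get overwritten

assert acc.last() == 4.0  # most recently appended element
assert acc.mean() == 3.0  # aggregation over the rotated window [2, 3, 4]

acc.reset(window_length=2)  # reinitializes state without re-running __init__
assert acc.last() is None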
@@ -139,7 +142,7 @@ def done(self) -> bool:
 class CycleIterator:
     """Iterator for restarting a dataloader if it runs out of samples."""

-    def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycleIteratorState = None):
+    def __init__(self, loader: Any, length: Optional[Union[int, float]] = None, state: SharedCycleIteratorState = None):
         """
         Args:
             loader: the loader to restart for cyclic (and optionally infinite) sampling
@@ -184,6 +187,8 @@ def __next__(self) -> Any:
         Raises:
             StopIteration: if more then :attr:`length` batches have been returned
         """
+        assert isinstance(self._loader_iter, Iterator)
+
         # Note: if self.length is `inf`, then the iterator will never stop
         if self.counter >= self.__len__() or self.state.done:
             raise StopIteration
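
The widened `length: Optional[Union[int, float]]` makes the note above precise: `float("inf")` is a legal length, and the iterator then never raises `StopIteration` on its own. A behavioral sketch with itertools stand-ins (not the class's real internals):

from itertools import cycle, islice

loader = [1, 2, 3]  # stand-in for a finite DataLoader

# length=float("inf") behaves like cycle(): the loader restarts indefinitely
print(list(islice(cycle(loader), 7)))  # [1, 2, 3, 1, 2, 3, 1]

# a finite length larger than len(loader) restarts the loader as needed
print(list(islice(cycle(loader), 5)))  # [1, 2, 3, 1, 2]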
@@ -257,13 +262,13 @@ def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union
         Returns:
             length: the length of `CombinedDataset`
         """
-        if mode not in CombinedDataset.COMPUTE_FUNCS.keys():
+        if mode not in self.COMPUTE_FUNCS.keys():
             raise MisconfigurationException(f"Invalid Mode: {mode}")

         # extract the lengths
         all_lengths = self._get_len_recursive(datasets)

-        compute_func = CombinedDataset.COMPUTE_FUNCS[mode]
+        compute_func = self.COMPUTE_FUNCS[mode]

         if isinstance(all_lengths, (int, float)):
             length = all_lengths
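
Looking up `self.COMPUTE_FUNCS` instead of `CombinedDataset.COMPUTE_FUNCS` means a subclass that shadows the table is now honored. A minimal sketch of the difference, using hypothetical classes (not part of this diff):

class Base:
    COMPUTE_FUNCS = {"min_size": min}

    def calc(self, lengths, mode):
        return self.COMPUTE_FUNCS[mode](lengths)  # instance lookup sees subclass overrides


class WithSum(Base):
    COMPUTE_FUNCS = {**Base.COMPUTE_FUNCS, "sum": sum}


print(WithSum().calc([3, 5], "sum"))  # 8; Base.COMPUTE_FUNCS[mode] would raise KeyError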
@@ -272,8 +277,9 @@ def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union

         return length

-    def _get_len_recursive(self, data) -> int:
+    def _get_len_recursive(self, data: Any) -> Union[int, float, List, Dict]:
         if isinstance(data, Dataset):
+            assert isinstance(data, Sized)
             return len(data)

         if isinstance(data, (float, int)):
@@ -290,13 +296,13 @@ def _get_len_recursive(self, data) -> int:
         return self._get_len(data)

     @staticmethod
-    def _get_len(dataset) -> int:
+    def _get_len(dataset: Any) -> Union[int, float]:
         try:
             return len(dataset)
         except (TypeError, NotImplementedError):
             return float("inf")

-    def __len__(self) -> int:
+    def __len__(self) -> Union[int, float]:
         """Return the minimum length of the datasets."""
         return self._calc_num_data(self.datasets, self.mode)

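`__len__` returning `Union[int, float]` follows from `_get_len`'s fallback: datasets without a usable `len()` count as `float("inf")`. A self-contained sketch of that rule (the helper name is illustrative):

def _safe_len(dataset):
    # mirrors _get_len above: unsized objects count as infinity
    try:
        return len(dataset)
    except (TypeError, NotImplementedError):
        return float("inf")


def stream():  # an unsized iterable, e.g. an IterableDataset-style source
    while True:
        yield 0


lengths = [_safe_len(d) for d in ([1, 2, 3], stream())]
print(min(lengths))  # 3    -> a "min_size" combination stays finite
print(max(lengths))  # inf  -> a "max_size_cycle" combination can be infinite
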
@@ -348,8 +354,8 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
         if self.mode == "max_size_cycle":
             self._wrap_loaders_max_size_cycle()

-        self._loaders_iter_state_dict = None
-        self._iterator = None  # assigned in __iter__
+        self._loaders_iter_state_dict: Optional[Dict] = None
+        self._iterator: Optional[Iterator] = None  # assigned in __iter__

     @staticmethod
     def _state_dict_fn(iterator: Optional[Iterator], has_completed: int) -> Dict:
@@ -384,7 +390,7 @@ def state_dict(self, has_completed: bool = False) -> Dict:
             has_completed=has_completed,
         )

-    def load_state_dict(self, state_dict) -> None:
+    def load_state_dict(self, state_dict: Dict) -> None:
         # store the samplers state.
         # They would be reloaded once the `CombinedIterator` as been created
         # and the workers are created.
@@ -482,18 +488,18 @@ def __iter__(self) -> Any:

         # prevent `NotImplementedError` from PyTorch:
         # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
-        def __getstate__patch__(*_):
+        def __getstate__patch__(*_: Any) -> Dict:
             return {}

-        _BaseDataLoaderIter.__getstate__ = __getstate__patch__
+        _BaseDataLoaderIter.__getstate__ = __getstate__patch__  # type: ignore[assignment]
         iterator = CombinedLoaderIterator(self.loaders)
         # handle fault tolerant restart logic.
         self.on_restart(iterator)
         self._iterator = iterator
         return iterator

     @staticmethod
-    def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
+    def _calc_num_batches(loaders: Any, mode: str = "min_size") -> Union[int, float]:
         """Compute the length (aka the number of batches) of `CombinedLoader`.

         Args:
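
The `__getstate__` patch in the hunk above works because pickling (and `copy.deepcopy`) consults `__getstate__`; returning an empty dict sidesteps the `NotImplementedError` that PyTorch's iterator raises. A minimal sketch of the mechanism on a toy class, not PyTorch's actual iterator:

import copy


class Unpicklable:
    def __getstate__(self):
        raise NotImplementedError


def __getstate__patch__(*_):
    return {}  # empty state: nothing is serialized, but copying now succeeds


Unpicklable.__getstate__ = __getstate__patch__
copy.deepcopy(Unpicklable())  # no longer raises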
@@ -509,16 +515,16 @@ def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
             return all_lengths
         return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min)

-    def __len__(self) -> int:
+    def __len__(self) -> Union[int, float]:
         return self._calc_num_batches(self.loaders, mode=self.mode)

     @staticmethod
-    def _shutdown_workers_and_reset_iterator(dataloader) -> None:
+    def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
         if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
             dataloader._iterator._shutdown_workers()
         dataloader._iterator = None

-    def reset(self):
+    def reset(self) -> None:
         if self._iterator:
             self._iterator._loader_iters = None
         if self.loaders is not None:
@@ -535,7 +541,7 @@ def __init__(self, loaders: Any):
             loaders: the loaders to sample from. Can be all kind of collection
         """
         self.loaders = loaders
-        self._loader_iters = None
+        self._loader_iters: Any = None

     @property
     def loader_iters(self) -> Any:
@@ -584,7 +590,9 @@ def create_loader_iters(
         return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))


-def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):
+def _nested_calc_num_data(
+    data: Union[Mapping, Sequence], compute_func: Callable[[List[Union[int, float]]], Union[int, float]]
+) -> Union[int, float]:

     if isinstance(data, (float, int)):
         return data
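
The annotated signature pins down the contract: `compute_func` reduces a list of per-branch lengths, applied recursively over nested mappings and sequences. A simplified sketch of that recursion (omitting the real function's error handling for unsupported types):

from typing import Callable, List, Mapping, Union


def nested_calc(data, compute_func: Callable[[List[Union[int, float]]], Union[int, float]]):
    if isinstance(data, (float, int)):
        return data
    values = data.values() if isinstance(data, Mapping) else data
    return compute_func([nested_calc(v, compute_func) for v in values])


print(nested_calc({"train": [3, 5], "extra": 7}, min))  # 3
print(nested_calc({"train": [3, 5], "extra": 7}, max))  # 7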