@@ -262,8 +262,7 @@ def max_len(self) -> Union[int, float]:
262262 def min_len (self ) -> Union [int , float ]:
263263 return self ._calc_num_data (self .datasets , 'min_size' )
264264
265- @staticmethod
266- def _calc_num_data (datasets : Union [Sequence , Mapping ], mode : str ) -> Union [int , float ]:
265+ def _calc_num_data (self , datasets : Union [Sequence , Mapping ], mode : str ) -> Union [int , float ]:
267266 """
268267 Compute the length of `CombinedDataset` according to the `mode`.
269268
@@ -281,9 +280,7 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,
281280 raise MisconfigurationException (f"Invalid Mode: { mode } " )
282281
283282 # extract the lengths
284- all_lengths = apply_to_collection (
285- datasets , (Dataset , Iterable , type (None )), get_len , wrong_dtype = (Sequence , Mapping )
286- )
283+ all_lengths = self ._get_len_recursive (datasets )
287284
288285 compute_func = CombinedDataset .COMPUTE_FUNCS [mode ]
289286
@@ -294,6 +291,30 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,
294291
295292 return length
296293
294+ def _get_len_recursive (self , data ) -> int :
295+ if isinstance (data , Dataset ):
296+ return len (data )
297+
298+ elif isinstance (data , (float , int )):
299+ return data
300+
301+ elif isinstance (data , Mapping ):
302+ if any (isinstance (v , (Mapping , Sequence , Dataset , Iterable )) for v in data .values ()):
303+ return {k : self ._get_len_recursive (v ) for k , v in data .items ()}
304+ elif isinstance (data , Sequence ):
305+ data = list (data )
306+ if any (isinstance (v , (Mapping , Sequence , Dataset , Iterable )) for v in data ):
307+ return [self ._get_len_recursive (v ) for v in data ]
308+
309+ return self ._get_len (data )
310+
311+ @staticmethod
312+ def _get_len (dataset ) -> int :
313+ try :
314+ return len (dataset )
315+ except (TypeError , NotImplementedError ):
316+ return float ('inf' )
317+
297318 def __len__ (self ) -> int :
298319 """Return the minimum length of the datasets."""
299320 return self ._calc_num_data (self .datasets , self .mode )
@@ -335,6 +356,9 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
335356 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones.
336357
337358 """
359+ if mode not in self .SUPPORTED_MODES :
360+ raise MisconfigurationException (f"Invalid Mode: { mode } " )
361+
338362 self .loaders = loaders
339363
340364 datasets = apply_to_collection (
@@ -343,9 +367,6 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
343367 # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
344368 self .dataset = CombinedDataset (datasets , mode )
345369
346- if mode not in self .SUPPORTED_MODES :
347- raise MisconfigurationException (f"Invalid Mode: { mode } " )
348-
349370 self .mode = mode
350371
351372 if self .mode == 'max_size_cycle' :
@@ -366,27 +387,13 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
366387 """
367388 all_lengths = apply_to_collection (self .loaders , Iterable , get_len , wrong_dtype = (Sequence , Mapping ))
368389
369- if isinstance (all_lengths , (int , float )):
370- length = all_lengths
371-
372- elif isinstance (all_lengths , Mapping ):
373- length = max (all_lengths .values ())
390+ length = _nested_calc_num_data (all_lengths , max )
374391
375- elif isinstance (all_lengths , Sequence ):
376- length = max (all_lengths )
377-
378- if isinstance (self .loaders , Mapping ):
379- self .loaders = type (self .loaders )({k : CycleIterator (v , length = length ) for k , v in self .loaders .items ()})
380-
381- elif isinstance (self .loaders , Sequence ):
382- self .loaders = type (self .loaders )([CycleIterator (v , length = length ) for v in self .loaders ])
383-
384- # dataloaders are iterable but not sequence
385- elif isinstance (self .loaders , Iterable ):
386- # only one dataloader, just keep it the same.
387- pass
388- else :
389- raise ValueError (f'Invalid Datatype for loaders: { type (self .loaders ).__name__ } ' )
392+ # multiple loaders
393+ if isinstance (self .loaders , (Sequence , Mapping )):
394+ self .loaders = apply_to_collection (
395+ self .loaders , Iterable , CycleIterator , length = length , wrong_dtype = (Sequence , Mapping )
396+ )
390397
391398 def __iter__ (self ) -> Any :
392399 """
@@ -490,7 +497,7 @@ def create_loader_iters(
490497
491498def _nested_calc_num_data (data : Union [Mapping , Sequence ], compute_func : Callable ):
492499
493- if isinstance (data , int ):
500+ if isinstance (data , ( float , int ) ):
494501 return data
495502
496503 if isinstance (data , Mapping ):
0 commit comments