2121from typing import Any , Callable , Dict , Generator , Iterable , Optional , Tuple , Type , Union
2222
2323from lightning_utilities .core .inheritance import get_all_subclasses
24- from torch .utils .data import BatchSampler , DataLoader , IterableDataset , Sampler
24+ from torch .utils .data import BatchSampler , DataLoader , Dataset , IterableDataset , Sampler
2525
2626from lightning_lite .utilities .enums import LightningEnum
2727from lightning_lite .utilities .exceptions import MisconfigurationException
@@ -34,6 +34,7 @@ class _WrapAttrTag(LightningEnum):
3434 DEL = "del"
3535
3636 def __call__ (self , * args : Any ) -> None :
37+ fn : Union [Callable [[object , str ], None ], Callable [[object , str , Any ], None ]]
3738 if self == self .SET :
3839 fn = setattr
3940 else :
@@ -45,20 +46,20 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
4546 return hasattr (dataloader , "dataset" ) and isinstance (dataloader .dataset , IterableDataset )
4647
4748
48- def has_len (dataloader : Union [DataLoader , Iterable ]) -> bool :
49+ def has_len (dataloader : Union [DataLoader , Iterable , Dataset ]) -> bool :
4950 """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
5051 infinite dataloader."""
5152 try :
5253 # try getting the length
53- if len (dataloader ) == 0 :
54+ if len (dataloader ) == 0 : # type: ignore [arg-type]
5455 rank_zero_warn (
5556 f"`{ dataloader .__class__ .__name__ } ` returned 0 length. Please make sure this was your intention."
5657 )
5758 has_len = True
5859 except (TypeError , NotImplementedError ):
5960 has_len = False
6061
61- if has_len and has_iterable_dataset (dataloader ):
62+ if has_len and isinstance ( dataloader , DataLoader ) and has_iterable_dataset (dataloader ):
6263 rank_zero_warn (
6364 "Your `IterableDataset` has `__len__` defined."
6465 " In combination with multi-process data loading (when num_workers > 1),"
@@ -76,7 +77,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]
7677
7778def _get_dataloader_init_args_and_kwargs (
7879 dataloader : DataLoader ,
79- sampler : Optional [Sampler ],
80+ sampler : Union [Sampler , Iterable ],
8081 disallow_batch_sampler : bool = False ,
8182) -> Tuple [Tuple [Any ], Dict [str , Any ]]:
8283 if not isinstance (dataloader , DataLoader ):
@@ -170,7 +171,7 @@ def _get_dataloader_init_args_and_kwargs(
170171
171172def _dataloader_init_kwargs_resolve_sampler (
172173 dataloader : DataLoader ,
173- sampler : Optional [Sampler ],
174+ sampler : Union [Sampler , Iterable ],
174175 disallow_batch_sampler : bool = False ,
175176) -> Dict [str , Any ]:
176177 """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
0 commit comments