 from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union

 from lightning_utilities.core.inheritance import get_all_subclasses
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
+from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler

 from lightning_lite.utilities.enums import LightningEnum
 from lightning_lite.utilities.exceptions import MisconfigurationException
@@ -33,7 +33,8 @@ class _WrapAttrTag(LightningEnum):
     SET = "set"
     DEL = "del"

-    def __call__(self, *args):
+    def __call__(self, *args: Any) -> None:
+        fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]]
         if self == self.SET:
             fn = setattr
         else:
@@ -45,20 +46,20 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
     return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)


-def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
+def has_len(dataloader: Union[DataLoader, Iterable, Dataset]) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
     infinite dataloader."""
     try:
         # try getting the length
-        if len(dataloader) == 0:
+        if len(dataloader) == 0:  # type: ignore [arg-type]
             rank_zero_warn(
                 f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
             )
         has_len = True
     except (TypeError, NotImplementedError):
         has_len = False

-    if has_len and has_iterable_dataset(dataloader):
+    if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader):
         rank_zero_warn(
             "Your `IterableDataset` has `__len__` defined."
             " In combination with multi-process data loading (when num_workers > 1),"
@@ -76,7 +77,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]

 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
@@ -99,7 +100,7 @@ def _get_dataloader_init_args_and_kwargs(
         arg_names = ()

     # get the dataloader instance `__init__` parameters
-    params = dict(inspect.signature(dataloader.__init__).parameters)
+    params = dict(inspect.signature(dataloader.__init__).parameters)  # type: ignore[misc]
     has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
     if has_variadic_kwargs:
         # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
@@ -141,36 +142,36 @@ def _get_dataloader_init_args_and_kwargs(
     }
     # the dataloader has required args which we could not extract from the existing attributes
     if required_args:
-        required_args = sorted(required_args)
+        sorted_required_args = sorted(required_args)
         dataloader_cls_name = dataloader.__class__.__name__
-        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
+        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
         raise MisconfigurationException(
             f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
-            "`*_dataloader` hook of your module, we will do this for you."
+            f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
+            "inside a `*_dataloader` hook of your module, we will do this for you."
             f" Otherwise, define {missing_args_message} inside your `__init__`."
         )

     if not has_variadic_kwargs:
         # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
         if missing_kwargs:
-            missing_kwargs = sorted(missing_kwargs)
+            sorted_missing_kwargs = sorted(missing_kwargs)
             dataloader_cls_name = dataloader.__class__.__name__
             raise TypeError(
                 f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                 "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
-                "add the `__init__` arguments or allow passing `**kwargs`"
+                f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
+                "class, add the `__init__` arguments or allow passing `**kwargs`"
             )

     return dl_args, dl_kwargs


 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
@@ -334,7 +335,7 @@ def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
     :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""

     @functools.wraps(method)
-    def wrapper(obj: Any, *args: Any):
+    def wrapper(obj: Any, *args: Any) -> None:
         # First, let's find out if we're the first in inheritance chain calling the patched method.
         name, *_ = args
         prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
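
A minimal sketch (not part of this commit) illustrating the `has_len` change in the hunks above. It assumes the file shown here is importable as `lightning_lite.utilities.data`, and the dataset classes below are made up for illustration: a bare map-style `Dataset` is now accepted directly, while the multi-process warning only fires when an `IterableDataset` defining `__len__` is actually wrapped in a `DataLoader`.

# Hypothetical usage sketch; class names and the import path are assumptions.
from torch.utils.data import DataLoader, Dataset, IterableDataset

from lightning_lite.utilities.data import has_len  # assumed import path for the file in this diff


class SizedMapDataset(Dataset):
    def __len__(self) -> int:
        return 10

    def __getitem__(self, index: int) -> int:
        return index


class SizedIterableDataset(IterableDataset):
    def __len__(self) -> int:  # defining __len__ on an IterableDataset is legal but risky
        return 10

    def __iter__(self):
        return iter(range(10))


has_len(SizedMapDataset())  # True, no warning: a plain `Dataset` is now a valid argument
has_len(DataLoader(SizedIterableDataset()))  # True, but warns about `__len__` with num_workers > 1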