Skip to content

Commit 8e9780b

Browse files
nandwalritik and otaj authored
fix mypy typing errors in pytorch_lightning.utilities.data.py (#13901)
Co-authored-by: otaj <[email protected]>
1 parent 9b01a0f commit 8e9780b

File tree

5 files changed

+36
-49
lines changed

5 files changed

+36
-49
lines changed

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -53,8 +53,6 @@ warn_no_return = "False"
5353
module = [
5454
"pytorch_lightning.callbacks.progress.rich_progress",
5555
"pytorch_lightning.trainer.trainer",
56-
"pytorch_lightning.tuner.batch_size_scaling",
57-
"pytorch_lightning.utilities.data",
58-
"lightning_lite.utilities.data",
56+
"pytorch_lightning.tuner.batch_size_scaling"
5957
]
6058
ignore_errors = "True"

src/lightning_lite/utilities/data.py

Lines changed: 17 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -21,7 +21,7 @@
2121
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union
2222

2323
from lightning_utilities.core.inheritance import get_all_subclasses
24-
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
24+
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler
2525

2626
from lightning_lite.utilities.enums import LightningEnum
2727
from lightning_lite.utilities.exceptions import MisconfigurationException
@@ -33,7 +33,8 @@ class _WrapAttrTag(LightningEnum):
3333
SET = "set"
3434
DEL = "del"
3535

36-
def __call__(self, *args):
36+
def __call__(self, *args: Any) -> None:
37+
fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]]
3738
if self == self.SET:
3839
fn = setattr
3940
else:
@@ -45,20 +46,20 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
4546
return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)
4647

4748

48-
def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
49+
def has_len(dataloader: Union[DataLoader, Iterable, Dataset]) -> bool:
4950
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
5051
infinite dataloader."""
5152
try:
5253
# try getting the length
53-
if len(dataloader) == 0:
54+
if len(dataloader) == 0: # type: ignore [arg-type]
5455
rank_zero_warn(
5556
f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
5657
)
5758
has_len = True
5859
except (TypeError, NotImplementedError):
5960
has_len = False
6061

61-
if has_len and has_iterable_dataset(dataloader):
62+
if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader):
6263
rank_zero_warn(
6364
"Your `IterableDataset` has `__len__` defined."
6465
" In combination with multi-process data loading (when num_workers > 1),"
@@ -76,7 +77,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]
7677

7778
def _get_dataloader_init_args_and_kwargs(
7879
dataloader: DataLoader,
79-
sampler: Optional[Sampler],
80+
sampler: Union[Sampler, Iterable],
8081
disallow_batch_sampler: bool = False,
8182
) -> Tuple[Tuple[Any], Dict[str, Any]]:
8283
if not isinstance(dataloader, DataLoader):
@@ -99,7 +100,7 @@ def _get_dataloader_init_args_and_kwargs(
99100
arg_names = ()
100101

101102
# get the dataloader instance `__init__` parameters
102-
params = dict(inspect.signature(dataloader.__init__).parameters)
103+
params = dict(inspect.signature(dataloader.__init__).parameters) # type: ignore[misc]
103104
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
104105
if has_variadic_kwargs:
105106
# if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
@@ -141,36 +142,36 @@ def _get_dataloader_init_args_and_kwargs(
141142
}
142143
# the dataloader has required args which we could not extract from the existing attributes
143144
if required_args:
144-
required_args = sorted(required_args)
145+
sorted_required_args = sorted(required_args)
145146
dataloader_cls_name = dataloader.__class__.__name__
146-
missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
147+
missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
147148
raise MisconfigurationException(
148149
f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
149150
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
150-
f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
151-
"`*_dataloader` hook of your module, we will do this for you."
151+
f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
152+
"inside a `*_dataloader` hook of your module, we will do this for you."
152153
f" Otherwise, define {missing_args_message} inside your `__init__`."
153154
)
154155

155156
if not has_variadic_kwargs:
156157
# the dataloader signature does not allow keyword arguments that need to be passed
157158
missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
158159
if missing_kwargs:
159-
missing_kwargs = sorted(missing_kwargs)
160+
sorted_missing_kwargs = sorted(missing_kwargs)
160161
dataloader_cls_name = dataloader.__class__.__name__
161162
raise TypeError(
162163
f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
163164
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
164-
f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
165-
"add the `__init__` arguments or allow passing `**kwargs`"
165+
f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
166+
"class, add the `__init__` arguments or allow passing `**kwargs`"
166167
)
167168

168169
return dl_args, dl_kwargs
169170

170171

171172
def _dataloader_init_kwargs_resolve_sampler(
172173
dataloader: DataLoader,
173-
sampler: Optional[Sampler],
174+
sampler: Union[Sampler, Iterable],
174175
disallow_batch_sampler: bool = False,
175176
) -> Dict[str, Any]:
176177
"""This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
@@ -334,7 +335,7 @@ def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
334335
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
335336

336337
@functools.wraps(method)
337-
def wrapper(obj: Any, *args: Any):
338+
def wrapper(obj: Any, *args: Any) -> None:
338339
# First, let's find out if we're the first in inheritance chain calling the patched method.
339340
name, *_ = args
340341
prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))

src/pytorch_lightning/strategies/ipu.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -245,7 +245,7 @@ def _convert_to_poptorch_loader(
245245
return dataloader
246246

247247
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
248-
dataloader, sampler, mode, self.replication_factor > 1 # type: ignore[arg-type]
248+
dataloader, sampler, mode, self.replication_factor > 1
249249
)
250250
opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
251251
dataloader = _reinstantiate_wrapped_cls(

src/pytorch_lightning/utilities/auto_restart.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -62,7 +62,7 @@ class FastForwardSampler(Sampler):
6262
samples seen in the last iterations (for the current worker).
6363
"""
6464

65-
def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
65+
def __init__(self, sampler: Union[Sampler, Iterable], attr_name: Optional[str] = None) -> None:
6666
super().__init__(data_source=None)
6767
self._sampler = sampler
6868
self.restarting: bool = False

src/pytorch_lightning/utilities/data.py

Lines changed: 16 additions & 28 deletions
Original file line number · Diff line number · Diff line change
@@ -30,7 +30,6 @@
3030
)
3131

3232
import pytorch_lightning as pl
33-
from lightning_lite.utilities import LightningEnum
3433
from lightning_lite.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args
3534
from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterable_dataset
3635
from lightning_lite.utilities.data import has_len as new_has_len
@@ -41,24 +40,13 @@
4140
from pytorch_lightning.utilities.exceptions import MisconfigurationException
4241
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
4342

44-
BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]
43+
# might be supported in later releases, see https://github.com/python/mypy/pull/13297
44+
BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] # type: ignore[misc]
4545

4646
warning_cache = WarningCache()
4747

4848

49-
class _WrapAttrTag(LightningEnum):
50-
SET = "set"
51-
DEL = "del"
52-
53-
def __call__(self, *args):
54-
if self == self.SET:
55-
fn = setattr
56-
else:
57-
fn = delattr
58-
return fn(*args)
59-
60-
61-
def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
49+
def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]:
6250
if isinstance(batch, Tensor):
6351
if batch.ndim == 0:
6452
yield 1
@@ -109,7 +97,7 @@ def extract_batch_size(batch: BType) -> int:
10997

11098
def has_len_all_ranks(
11199
dataloader: DataLoader,
112-
strategy: "pl.Strategy",
100+
strategy: "pl.strategies.Strategy",
113101
model: Union["pl.LightningModule", "pl.LightningDataModule"],
114102
) -> bool:
115103
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
@@ -151,14 +139,14 @@ def has_len_all_ranks(
151139
return has_len
152140

153141

154-
def get_len(dataloader: DataLoader) -> Union[int, float]:
142+
def get_len(dataloader: Union[DataLoader, Dataset]) -> Union[int, float]:
155143
"""Return the length of the given DataLoader.
156144
157145
If ``__len__`` method is not implemented, return float('inf').
158146
"""
159147

160148
if new_has_len(dataloader):
161-
return len(dataloader)
149+
return len(dataloader) # type: ignore [arg-type]
162150

163151
return float("inf")
164152

@@ -173,7 +161,7 @@ def _update_dataloader(
173161

174162
def _get_dataloader_init_args_and_kwargs(
175163
dataloader: DataLoader,
176-
sampler: Optional[Sampler],
164+
sampler: Union[Sampler, Iterable],
177165
mode: Optional[RunningStage] = None,
178166
disallow_batch_sampler: bool = False,
179167
) -> Tuple[Tuple[Any], Dict[str, Any]]:
@@ -197,7 +185,7 @@ def _get_dataloader_init_args_and_kwargs(
197185
arg_names = ()
198186

199187
# get the dataloader instance `__init__` parameters
200-
params = dict(inspect.signature(dataloader.__init__).parameters)
188+
params = dict(inspect.signature(dataloader.__init__).parameters) # type: ignore[misc]
201189
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
202190
if has_variadic_kwargs:
203191
# if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
@@ -239,28 +227,28 @@ def _get_dataloader_init_args_and_kwargs(
239227
}
240228
# the dataloader has required args which we could not extract from the existing attributes
241229
if required_args:
242-
required_args = sorted(required_args)
230+
sorted_required_args = sorted(required_args)
243231
dataloader_cls_name = dataloader.__class__.__name__
244-
missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
232+
missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
245233
raise MisconfigurationException(
246234
f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
247235
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
248-
f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
249-
"`*_dataloader` hook of your module, we will do this for you."
236+
f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
237+
"inside a `*_dataloader` hook of your module, we will do this for you."
250238
f" Otherwise, define {missing_args_message} inside your `__init__`."
251239
)
252240

253241
if not has_variadic_kwargs:
254242
# the dataloader signature does not allow keyword arguments that need to be passed
255243
missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
256244
if missing_kwargs:
257-
missing_kwargs = sorted(missing_kwargs)
245+
sorted_missing_kwargs = sorted(missing_kwargs)
258246
dataloader_cls_name = dataloader.__class__.__name__
259247
raise MisconfigurationException(
260248
f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
261249
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
262-
f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
263-
"add the `__init__` arguments or allow passing `**kwargs`"
250+
f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
251+
"class, add the `__init__` arguments or allow passing `**kwargs`"
264252
)
265253

266254
if _FaultTolerantMode.detect_current_mode().is_automatic:
@@ -273,7 +261,7 @@ def _get_dataloader_init_args_and_kwargs(
273261

274262
def _dataloader_init_kwargs_resolve_sampler(
275263
dataloader: DataLoader,
276-
sampler: Optional[Sampler],
264+
sampler: Union[Sampler, Iterable],
277265
mode: Optional[RunningStage] = None,
278266
disallow_batch_sampler: bool = False,
279267
) -> Dict[str, Any]:

0 commit comments

Comments (0)