
Commit d8e5e7f

Jungwon-Lee authored
Fix mypy typing errors in pytorch_lightning/strategies/tpu_spawn.py (#13813)
Co-authored-by: awaelchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: otaj <[email protected]>
1 parent 0fbfbf9 commit d8e5e7f

File tree

4 files changed: +32 -21 lines


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -66,7 +66,6 @@ module = [
     "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.strategies.sharded_spawn",
-    "pytorch_lightning.strategies.tpu_spawn",
     "pytorch_lightning.trainer.callback_hook",
     "pytorch_lightning.trainer.connectors.callback_connector",
     "pytorch_lightning.trainer.connectors.data_connector",

src/pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 30 additions & 18 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 import io
 import os
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

 import torch
 from torch import Tensor
@@ -29,15 +29,17 @@
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.strategies.launchers.xla import _XLALauncher
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
-from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS

 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
@@ -58,7 +60,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
     def __init__(
         self,
         accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
-        parallel_devices: Optional[List[int]] = None,
+        parallel_devices: Optional[List[torch.device]] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         debug: bool = False,
@@ -72,6 +74,7 @@ def __init__(
             precision_plugin=precision_plugin,
             start_method="fork",
         )
+        self._checkpoint_io: Optional[CheckpointIO]
         self.debug = debug
         self._launched = False

@@ -95,17 +98,16 @@ def root_device(self) -> torch.device:
         return xm.xla_device()

     @staticmethod
-    def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
-        if not isinstance(dataloaders, list):
-            dataloaders = [dataloaders]
-
-        for dataloader in dataloaders:
+    def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
+        def check_has_len(dataloader: DataLoader) -> None:
             if not has_len(dataloader):
                 raise MisconfigurationException(
                     "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
                     " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
                 )

+        apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)
+
     @staticmethod
     def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
         """Validate and fail fast if the dataloaders were passed directly to fit."""
@@ -118,32 +120,37 @@ def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
         )
         for source in sources:
             if not source.is_module():
+                assert source.instance is not None
+                assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule))
                 TPUSpawnStrategy._validate_dataloader(source.instance)

-    def connect(self, model: "pl.LightningModule") -> None:
+    def connect(self, model: "pl.LightningModule") -> None:  # type: ignore
         TPUSpawnStrategy._validate_patched_dataloaders(model)
         self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
         return super().connect(model)

-    def _configure_launcher(self):
+    def _configure_launcher(self) -> None:
         self._launcher = _XLALauncher(self)

     def setup(self, trainer: "pl.Trainer") -> None:
+        assert self.accelerator
         self.accelerator.setup(trainer)

         if self.debug:
             os.environ["PT_XLA_DEBUG"] = "1"

+        assert self.model
         shared_params = find_shared_parameters(self.model)
         self.model_to_device()
+        assert isinstance(self.model.module, Module)
         set_shared_parameters(self.model.module, shared_params)
         self.setup_precision_plugin()

         if trainer.state.fn == TrainerFn.FITTING:
             self.setup_optimizers(trainer)
             optimizers_to_device(self.optimizers, self.root_device)

-    def _setup_model(self, model: Module) -> Module:
+    def _setup_model(self, model: Module) -> Module:  # type: ignore
         return model

     @property
@@ -168,11 +175,11 @@ def configure_ddp(self) -> None:
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)

-    def barrier(self, name: Optional[str] = None) -> None:
+    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
         if self.is_distributed:
             rendezvous(name)

-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if not self.is_distributed:
             return obj
         buffer = io.BytesIO()
@@ -184,7 +191,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = torch.load(buffer)
         return obj

-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+    def reduce(
+        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    ) -> Tensor:
         if not isinstance(output, Tensor):
             output = torch.tensor(output, device=self.root_device)

@@ -203,20 +212,23 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

         return output

-    def _worker_setup(self, process_idx: int):
+    def _worker_setup(self, process_idx: int) -> None:
         self._launched = True
         self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank

-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+        assert self.model is not None
         with self.precision_plugin.val_step_context():
             return self.model(*args, **kwargs)

-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+        assert self.model is not None
         with self.precision_plugin.test_step_context():
             return self.model(*args, **kwargs)

-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        assert self.model is not None
         with self.precision_plugin.predict_step_context():
             return self.model(*args, **kwargs)
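
A note on the `_validate_dataloader` rework above: instead of assuming a flat list, the check is now applied through `apply_to_collection`, which recurses into `Sequence` and `Mapping` containers (the shapes allowed by `TRAIN_DATALOADERS`/`EVAL_DATALOADERS`) and calls `check_has_len` on every leaf. A minimal, self-contained sketch of that traversal idea, using a hypothetical `_check_leaves` helper rather than the real `apply_to_collection` implementation:

from collections.abc import Mapping, Sequence
from typing import Any, Callable

def _check_leaves(collection: Any, check: Callable[[Any], None]) -> None:
    # Hypothetical helper mirroring how apply_to_collection is used above:
    # recurse into Mappings and Sequences, apply `check` to every other leaf.
    if isinstance(collection, Mapping):
        for value in collection.values():
            _check_leaves(value, check)
    elif isinstance(collection, Sequence) and not isinstance(collection, (str, bytes)):
        for item in collection:
            _check_leaves(item, check)
    else:
        check(collection)

# Usage sketch with stand-in leaves; in the strategy the leaves would be
# DataLoaders and `check` would be `check_has_len`.
nested = {"train": ["loader_a", "loader_b"], "val": "loader_c"}
_check_leaves(nested, lambda leaf: print("checked", leaf))

With `dtype=object` and `wrong_dtype=(Sequence, Mapping)`, the real utility behaves analogously: containers are traversed, and everything else is treated as a leaf and validated once.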

src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 1 addition & 1 deletion
@@ -516,7 +516,7 @@ def is_defined(self) -> bool:
         return not self.is_module() or is_overridden(self.name, self.instance)

     def is_module(self) -> bool:
-        """Returns whether the the DataLoader source is a LightningModule or a LightningDataModule.
+        """Returns whether the DataLoader source is a LightningModule or a LightningDataModule.

         It does not check whether ``*_dataloader`` methods are actually overridden.
         """

src/pytorch_lightning/utilities/apply_func.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def apply_to_collection(
     dtype: Union[type, Any, Tuple[Union[type, Any]]],
     function: Callable,
     *args: Any,
-    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
+    wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
     include_none: bool = True,
     **kwargs: Any,
 ) -> Any:
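
Why the `wrong_dtype` annotation changed: `Tuple[type]` means a tuple of exactly one type, so the new call site passing `wrong_dtype=(Sequence, Mapping)` would be rejected by mypy, whereas `Tuple[type, ...]` accepts a tuple of any length. A small illustration with hypothetical function names (not part of the library):

from typing import Tuple

def old_style(wrong_dtype: Tuple[type]) -> None:  # exactly one type allowed
    ...

def new_style(wrong_dtype: Tuple[type, ...]) -> None:  # any number of types
    ...

# Both calls run fine at runtime; mypy flags only the first one.
old_style((list, dict))
new_style((list, dict))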
