 # limitations under the License.
 import io
 import os
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
 
 import torch
 from torch import Tensor
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.strategies.launchers.xla import _XLALauncher
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
-from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
@@ -58,7 +60,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
     def __init__(
         self,
         accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
-        parallel_devices: Optional[List[int]] = None,
+        parallel_devices: Optional[List[torch.device]] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         debug: bool = False,
@@ -72,6 +74,7 @@ def __init__(
             precision_plugin=precision_plugin,
             start_method="fork",
         )
+        self._checkpoint_io: Optional[CheckpointIO]
         self.debug = debug
         self._launched = False
 
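The annotation-only statement `self._checkpoint_io: Optional[CheckpointIO]` added above assigns nothing at runtime; it just re-declares, for mypy, the type of an attribute that is presumably assigned in the base strategy's `__init__`. A minimal sketch of the pattern, with invented class names:

from typing import Optional


class Base:
    def __init__(self) -> None:
        self.label: Optional[str] = None


class Child(Base):
    def __init__(self) -> None:
        super().__init__()
        # No value is assigned here: the bare annotation only tells the type
        # checker what type `self.label` has on this subclass.
        self.label: Optional[str]
        self.debug = False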
@@ -95,17 +98,16 @@ def root_device(self) -> torch.device:
         return xm.xla_device()
 
     @staticmethod
-    def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
-        if not isinstance(dataloaders, list):
-            dataloaders = [dataloaders]
-
-        for dataloader in dataloaders:
+    def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
+        def check_has_len(dataloader: DataLoader) -> None:
             if not has_len(dataloader):
                 raise MisconfigurationException(
                     "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
                     " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
                 )
 
+        apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)
+
     @staticmethod
     def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
         """Validate and fail fast if the dataloaders were passed directly to fit."""
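The rewritten `_validate_dataloader` no longer branches on whether it received a single `DataLoader` or a list: `apply_to_collection` walks any nesting of `Sequence`/`Mapping` containers (excluded from being treated as leaves via `wrong_dtype`) and calls `check_has_len` on every dataloader it reaches. A rough standalone sketch of that traversal; the toy dataset and the nested dict of loaders are invented for illustration:

from collections.abc import Mapping, Sequence

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection


def check_has_len(dataloader: DataLoader) -> None:
    # DataLoaders over an IterableDataset have no usable __len__, so len() raises TypeError.
    try:
        len(dataloader)
    except TypeError:
        raise ValueError("IterableDataset objects are not supported here") from None


sized = DataLoader(TensorDataset(torch.zeros(4, 1)))
loaders = {"train": sized, "extra": [sized, sized]}

# dtype=object matches every leaf; wrong_dtype keeps the dict/list containers
# from being passed to check_has_len themselves, so only the loaders are checked.
apply_to_collection(loaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)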
@@ -118,32 +120,37 @@ def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
         )
         for source in sources:
             if not source.is_module():
+                assert source.instance is not None
+                assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule))
                 TPUSpawnStrategy._validate_dataloader(source.instance)
 
-    def connect(self, model: "pl.LightningModule") -> None:
+    def connect(self, model: "pl.LightningModule") -> None:  # type: ignore
        TPUSpawnStrategy._validate_patched_dataloaders(model)
         self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
         return super().connect(model)
 
-    def _configure_launcher(self):
+    def _configure_launcher(self) -> None:
         self._launcher = _XLALauncher(self)
 
     def setup(self, trainer: "pl.Trainer") -> None:
+        assert self.accelerator
         self.accelerator.setup(trainer)
 
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = "1"
 
+        assert self.model
         shared_params = find_shared_parameters(self.model)
         self.model_to_device()
+        assert isinstance(self.model.module, Module)
         set_shared_parameters(self.model.module, shared_params)
         self.setup_precision_plugin()
 
         if trainer.state.fn == TrainerFn.FITTING:
             self.setup_optimizers(trainer)
             optimizers_to_device(self.optimizers, self.root_device)
 
-    def _setup_model(self, model: Module) -> Module:
+    def _setup_model(self, model: Module) -> Module:  # type: ignore
         return model
 
     @property
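Several of the new `assert` statements above (`assert self.accelerator`, `assert self.model`, `assert isinstance(self.model.module, Module)`) are there for the type checker: the attributes are declared as `Optional` (or typed broadly) on the base class, and an assert narrows the type for the lines that follow. A minimal sketch of that narrowing, with an invented class:

from typing import Optional

from torch import nn


class Holder:
    model: Optional[nn.Module] = None

    def setup(self) -> None:
        # Without the assert, mypy rejects `self.model.train()` because
        # `self.model` may be None; the assert narrows Optional[nn.Module]
        # to nn.Module for the remainder of the method.
        assert self.model is not None
        self.model.train()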
@@ -168,11 +175,11 @@ def configure_ddp(self) -> None:
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
 
-    def barrier(self, name: Optional[str] = None) -> None:
+    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
         if self.is_distributed:
             rendezvous(name)
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if not self.is_distributed:
             return obj
         buffer = io.BytesIO()
@@ -184,7 +191,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = torch.load(buffer)
         return obj
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+    def reduce(
+        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    ) -> Tensor:
         if not isinstance(output, Tensor):
             output = torch.tensor(output, device=self.root_device)
 
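`TBroadcast` is a `TypeVar` imported from `strategies/strategy.py`, so the new `broadcast` signature records that the method hands back an object of the same type it was given. The body (only partly visible in these hunks) packs the object into an in-memory buffer, presumably via `torch.save`, since it is read back with `torch.load(buffer)`; `reduce` likewise now declares that it always returns a `Tensor`, converting non-tensor inputs up front. A rough sketch of just the serialization round trip, independent of any XLA collective:

import io
from typing import TypeVar

import torch

TBroadcast = TypeVar("TBroadcast")


def roundtrip(obj: TBroadcast) -> TBroadcast:
    # Serialize an arbitrary picklable object into raw bytes ...
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    payload = buffer.getvalue()

    # ... and rebuild it; in the strategy, a byte payload like this is what
    # actually travels between processes before being deserialized on each rank.
    return torch.load(io.BytesIO(payload))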
@@ -203,20 +212,23 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
 
         return output
 
-    def _worker_setup(self, process_idx: int):
+    def _worker_setup(self, process_idx: int) -> None:
         self._launched = True
         self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+        assert self.model is not None
         with self.precision_plugin.val_step_context():
             return self.model(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+        assert self.model is not None
         with self.precision_plugin.test_step_context():
             return self.model(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        assert self.model is not None
         with self.precision_plugin.predict_step_context():
             return self.model(*args, **kwargs)
 