 # limitations under the License.
 import json
 import os
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import FloatTensor, Tensor
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Sampler
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -112,12 +113,12 @@ def __init__(
         self.device_iterations = device_iterations
         self.autoreport = autoreport
         self.autoreport_dir = autoreport_dir
-        self.poptorch_models = {}
+        self.poptorch_models: Dict[RunningStage, "poptorch.PoplarExecutor"] = {}
         self._training_opts = training_opts
         self._inference_opts = inference_opts
 
         if self.autoreport:
-            options = {"autoReport.all": self.autoreport}
+            options: Dict[str, Any] = {"autoReport.all": self.autoreport}
             if self.autoreport_dir:
                 self._fs = get_filesystem(str(self.autoreport_dir))
                 self._fs.makedirs(self.autoreport_dir, exist_ok=True)
@@ -139,6 +140,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
 
         super().setup(trainer)
 
+        assert self.lightning_module is not None
+
         # disable the `optimizer_zero_grad` function by setting it to `None`.
         # this is because the IPU zeros the gradients internally
         self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad
@@ -192,12 +195,14 @@ def replication_factor(self) -> int:
             if self._inference_opts:
                 return self._inference_opts.replication_factor
 
+            assert self.parallel_devices
             return len(self.parallel_devices)
-
         stage = self.lightning_module.trainer.state.stage
+        assert stage is not None
         return self.poptorch_models[stage]._options.toDict()["replication_factor"]
 
     def _create_opts(self, training: bool) -> "poptorch.Options":
+        assert self.lightning_module is not None
         opts = poptorch.Options()
         opts.deviceIterations(self.device_iterations)
         opts.replicationFactor(self.replication_factor)
@@ -221,14 +226,14 @@ def inference_opts(self) -> "poptorch.Options":
         return self._inference_opts
 
     def _convert_to_poptorch_loader(
-        self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
+        self, dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
     ) -> "poptorch.DataLoader":
         if isinstance(dataloader, poptorch.DataLoader):
             # the user is returning the `poptorch.DataLoader` directly, don't change anything.
             return dataloader
 
         dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
-            dataloader, sampler, mode, self.replication_factor > 1
+            dataloader, sampler, mode, self.replication_factor > 1  # type: ignore[arg-type]
        )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
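
The `# type: ignore[arg-type]` added above is the usual suppression for a call whose argument type was widened past what the callee declares. A minimal standalone sketch of that situation, using invented helper names (`expects_sampler`, `caller`) rather than Lightning's actual `_get_dataloader_init_args_and_kwargs` signature:

from typing import Iterable, Union

from torch.utils.data import Sampler


def expects_sampler(sampler: Sampler) -> None:
    """Stand-in for a helper whose signature only accepts a Sampler."""


def caller(sampler: Union[Sampler, Iterable]) -> None:
    # mypy reports an [arg-type] error here, since an arbitrary Iterable is not
    # necessarily a Sampler; the targeted ignore silences only that error code.
    expects_sampler(sampler)  # type: ignore[arg-type]
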
@@ -240,6 +245,7 @@ def _handle_gradient_accumulation_steps(self) -> None:
 
         ``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation internally.
         """
+        assert self.lightning_module is not None
         accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
 
         if accumulation_scheduler.epochs != [0]:
@@ -251,18 +257,19 @@ def _handle_gradient_accumulation_steps(self) -> None:
         accumulation_scheduler.scheduling.update({0: 1})
 
     @property
-    def _n_replicate(self):
+    def _n_replicate(self) -> int:
+        assert self.lightning_module is not None
         opts = self.training_opts if self.lightning_module.training else self.inference_opts
         accumulate_grad_batches = opts.Training.gradient_accumulation
         device_iterations = opts.device_iterations
         replication_factor = opts.replication_factor
         return replication_factor * device_iterations * accumulate_grad_batches
 
-    def _prepare_input(self, args: Any):
-        def to_tuple(x):
+    def _prepare_input(self, args: Any) -> Any:
+        def to_tuple(x: Any) -> Tuple:
             return tuple(x)
 
-        def to_tensor(x):
+        def to_tensor(x: Any) -> Tensor:
             return torch.tensor(x).unsqueeze(0).repeat(self._n_replicate)
 
         args = apply_to_collection(args, dtype=list, function=to_tuple)
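
`_n_replicate` only gains a return annotation here, but its role is easier to follow with concrete numbers. A worked sketch under assumed option values (the integers below are made up for illustration; the real values come from the `poptorch.Options` objects):

import torch

# Assumed values, purely illustrative.
replication_factor = 4       # model replicas across IPUs
device_iterations = 2        # device iterations per host step
accumulate_grad_batches = 8  # gradient accumulation handled on the IPU

# _n_replicate: how many copies of each scalar input the host must provide.
n_replicate = replication_factor * device_iterations * accumulate_grad_batches
assert n_replicate == 64

# to_tensor in _prepare_input repeats a scalar accordingly, so every
# replica / device iteration / accumulation step receives one copy.
repeated = torch.tensor(3.0).unsqueeze(0).repeat(n_replicate)
assert repeated.shape == (64,)
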
@@ -281,6 +288,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat
 
     def _disable_zero_grad(self) -> None:
         lightning_module = self.lightning_module
+        assert lightning_module is not None
         if is_overridden("optimizer_zero_grad", lightning_module):
             assert lightning_module is not None  # `is_overridden` returns False otherwise
             rank_zero_warn(
@@ -289,27 +297,28 @@ def _disable_zero_grad(self) -> None:
             )
             lightning_module.optimizer_zero_grad = None  # type: ignore[assignment]
 
-    def _step(self, stage: RunningStage, *args: Any, **kwargs: Any):
+    def _step(self, stage: RunningStage, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         args = self._prepare_input(args)
+        assert self.lightning_module is not None
         poptorch_model = self.poptorch_models[stage]
         self.lightning_module._running_torchscript = True
         out = poptorch_model(*args, **kwargs)
         self.lightning_module._running_torchscript = False
         return out
 
-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.train_step_context():
             return self._step(RunningStage.TRAINING, *args, **kwargs)
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.val_step_context():
             return self._step(RunningStage.VALIDATING, *args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.test_step_context():
             return self._step(RunningStage.TESTING, *args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.predict_step_context():
             return self._step(RunningStage.PREDICTING, *args, **kwargs)
 
@@ -318,26 +327,27 @@ def teardown(self) -> None:
         # undo dataloader patching
         pl.trainer.connectors.data_connector._update_dataloader = self._update_dataloader_original
 
+        assert self.lightning_module is not None
         if self._optimizer_zero_grad_original is not None:
             # re-enable `optimizer_zero_grad`
-            self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original
+            self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original  # type: ignore[assignment]
 
         for model in self.poptorch_models.values():
             model.destroy()
 
         super().teardown()
 
-    def _compiled(self, model: Any):
+    def _compiled(self, model: Any) -> bool:
         # Required to ensure we only attach compiled models, as they are compiled lazily.
         return model._executable is not None
 
-    def _detach_models(self):
+    def _detach_models(self) -> None:
         """Detaches all stage specific models from IPU devices."""
         for k, model in self.poptorch_models.items():
             if self._compiled(model) and model.isAttachedToDevice():
                 model.detachFromDevice()
 
-    def _load_model(self, stage: str):
+    def _load_model(self, stage: RunningStage) -> None:
         """Loads the stage specific accelerator model onto device if compiled and not attached to IPU devices.
 
         Args:
@@ -348,28 +358,28 @@ def _load_model(self, stage: str):
         if self._compiled(model) and not model.isAttachedToDevice():
             model.attachToDevice()
 
-    def on_train_start(self):
+    def on_train_start(self) -> None:
         self._load_model(RunningStage.TRAINING)
 
-    def on_validation_start(self):
+    def on_validation_start(self) -> None:
         self._load_model(RunningStage.VALIDATING)
 
-    def on_test_start(self):
+    def on_test_start(self) -> None:
         self._load_model(RunningStage.TESTING)
 
-    def on_predict_start(self):
+    def on_predict_start(self) -> None:
         self._load_model(RunningStage.PREDICTING)
 
-    def on_train_end(self):
+    def on_train_end(self) -> None:
         self._detach_models()
 
-    def on_validation_end(self):
+    def on_validation_end(self) -> None:
         self._detach_models()
 
-    def on_test_end(self):
+    def on_test_end(self) -> None:
         self._detach_models()
 
-    def on_predict_end(self):
+    def on_predict_end(self) -> None:
         self._detach_models()
 
     def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
@@ -397,7 +407,7 @@ def barrier(self, name: Optional[str] = None) -> None:
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         return tensor
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return obj
 
     @classmethod
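
Most of the new lines in this diff are the same idiom repeated: attributes such as `Strategy.lightning_module` are typed `Optional`, so an explicit `assert ... is not None` is inserted before attribute access to satisfy mypy. A minimal self-contained sketch of that narrowing pattern, using invented class names rather than the Lightning ones:

from typing import Optional


class Module:
    def zero_grad(self) -> None:
        """Stand-in for a LightningModule method."""


class Strategy:
    def __init__(self, module: Optional[Module] = None) -> None:
        self.module = module

    def setup(self) -> None:
        # Without the assert, mypy complains that None has no attribute
        # `zero_grad`; the assert narrows Optional[Module] to Module.
        assert self.module is not None
        self.module.zero_grad()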