
Commit e293749

HalestormAI authored (with otaj, awaelchli, and rohitgr7)

Fix typing annotations for the ipu strategy (#13786)

Co-authored-by: otaj <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent d748dae commit e293749

File tree

2 files changed: +39 -30 lines
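
Most of the added lines in the ipu.py diff below apply one pattern: attributes such as self.lightning_module are declared Optional, so once mypy checks the module it rejects direct attribute access on them, and an `assert x is not None` narrows the type at the point of use. A minimal sketch of that pattern, with hypothetical names rather than Lightning's actual classes:

    from typing import Optional


    class Module:
        def optimizer_zero_grad(self) -> None:
            ...


    class Strategy:
        def __init__(self) -> None:
            # Not connected to a module yet, hence Optional.
            self.module: Optional[Module] = None

        def setup(self, module: Module) -> None:
            self.module = module

        def run(self) -> None:
            # Without the assert, mypy reports roughly:
            #   Item "None" of "Optional[Module]" has no attribute "optimizer_zero_grad"
            assert self.module is not None
            self.module.optimizer_zero_grad()

At runtime the assert is a cheap check that fails loudly if the strategy is used before setup; for the type checker it has the same narrowing effect as a cast, without silencing other errors.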

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -63,7 +63,6 @@ module = [
     "pytorch_lightning.profilers.simple",
     "pytorch_lightning.strategies.ddp",
     "pytorch_lightning.strategies.fully_sharded",
-    "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.trainer.callback_hook",

src/pytorch_lightning/strategies/ipu.py

Lines changed: 39 additions & 29 deletions
@@ -13,18 +13,19 @@
 # limitations under the License.
 import json
 import os
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import FloatTensor, Tensor
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Sampler
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -112,12 +113,12 @@ def __init__(
         self.device_iterations = device_iterations
         self.autoreport = autoreport
         self.autoreport_dir = autoreport_dir
-        self.poptorch_models = {}
+        self.poptorch_models: Dict[RunningStage, "poptorch.PoplarExecutor"] = {}
         self._training_opts = training_opts
         self._inference_opts = inference_opts
 
         if self.autoreport:
-            options = {"autoReport.all": self.autoreport}
+            options: Dict[str, Any] = {"autoReport.all": self.autoreport}
             if self.autoreport_dir:
                 self._fs = get_filesystem(str(self.autoreport_dir))
                 self._fs.makedirs(self.autoreport_dir, exist_ok=True)
@@ -139,6 +140,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
 
         super().setup(trainer)
 
+        assert self.lightning_module is not None
+
         # disable the `optimizer_zero_grad` function by setting it to `None`.
         # this is because the IPU zeros the gradients internally
         self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad
@@ -192,12 +195,14 @@ def replication_factor(self) -> int:
             if self._inference_opts:
                 return self._inference_opts.replication_factor
 
+            assert self.parallel_devices
             return len(self.parallel_devices)
-
         stage = self.lightning_module.trainer.state.stage
+        assert stage is not None
         return self.poptorch_models[stage]._options.toDict()["replication_factor"]
 
     def _create_opts(self, training: bool) -> "poptorch.Options":
+        assert self.lightning_module is not None
         opts = poptorch.Options()
         opts.deviceIterations(self.device_iterations)
         opts.replicationFactor(self.replication_factor)
@@ -221,14 +226,14 @@ def inference_opts(self) -> "poptorch.Options":
         return self._inference_opts
 
     def _convert_to_poptorch_loader(
-        self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
+        self, dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
     ) -> "poptorch.DataLoader":
         if isinstance(dataloader, poptorch.DataLoader):
             # the user is returning the `poptorch.DataLoader` directly, don't change anything.
             return dataloader
 
         dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
-            dataloader, sampler, mode, self.replication_factor > 1
+            dataloader, sampler, mode, self.replication_factor > 1  # type: ignore[arg-type]
         )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
@@ -240,6 +245,7 @@ def _handle_gradient_accumulation_steps(self) -> None:
 
         ``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation internally.
         """
+        assert self.lightning_module is not None
         accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
 
         if accumulation_scheduler.epochs != [0]:
@@ -251,18 +257,19 @@ def _handle_gradient_accumulation_steps(self) -> None:
         accumulation_scheduler.scheduling.update({0: 1})
 
     @property
-    def _n_replicate(self):
+    def _n_replicate(self) -> int:
+        assert self.lightning_module is not None
         opts = self.training_opts if self.lightning_module.training else self.inference_opts
         accumulate_grad_batches = opts.Training.gradient_accumulation
         device_iterations = opts.device_iterations
         replication_factor = opts.replication_factor
         return replication_factor * device_iterations * accumulate_grad_batches
 
-    def _prepare_input(self, args: Any):
-        def to_tuple(x):
+    def _prepare_input(self, args: Any) -> Any:
+        def to_tuple(x: Any) -> Tuple:
             return tuple(x)
 
-        def to_tensor(x):
+        def to_tensor(x: Any) -> Tensor:
             return torch.tensor(x).unsqueeze(0).repeat(self._n_replicate)
 
         args = apply_to_collection(args, dtype=list, function=to_tuple)
@@ -281,6 +288,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat
 
     def _disable_zero_grad(self) -> None:
         lightning_module = self.lightning_module
+        assert lightning_module is not None
         if is_overridden("optimizer_zero_grad", lightning_module):
             assert lightning_module is not None  # `is_overridden` returns False otherwise
             rank_zero_warn(
@@ -289,27 +297,28 @@ def _disable_zero_grad(self) -> None:
             )
             lightning_module.optimizer_zero_grad = None  # type: ignore[assignment]
 
-    def _step(self, stage: RunningStage, *args: Any, **kwargs: Any):
+    def _step(self, stage: RunningStage, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         args = self._prepare_input(args)
+        assert self.lightning_module is not None
         poptorch_model = self.poptorch_models[stage]
         self.lightning_module._running_torchscript = True
         out = poptorch_model(*args, **kwargs)
         self.lightning_module._running_torchscript = False
         return out
 
-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.train_step_context():
             return self._step(RunningStage.TRAINING, *args, **kwargs)
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.val_step_context():
             return self._step(RunningStage.VALIDATING, *args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.test_step_context():
             return self._step(RunningStage.TESTING, *args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.predict_step_context():
             return self._step(RunningStage.PREDICTING, *args, **kwargs)
 
@@ -318,26 +327,27 @@ def teardown(self) -> None:
         # undo dataloader patching
         pl.trainer.connectors.data_connector._update_dataloader = self._update_dataloader_original
 
+        assert self.lightning_module is not None
         if self._optimizer_zero_grad_original is not None:
             # re-enable `optimizer_zero_grad`
-            self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original
+            self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original  # type: ignore[assignment]
 
         for model in self.poptorch_models.values():
             model.destroy()
 
         super().teardown()
 
-    def _compiled(self, model: Any):
+    def _compiled(self, model: Any) -> bool:
         # Required to ensure we only attach compiled models, as they are compiled lazily.
         return model._executable is not None
 
-    def _detach_models(self):
+    def _detach_models(self) -> None:
         """Detaches all stage specific models from IPU devices."""
         for k, model in self.poptorch_models.items():
             if self._compiled(model) and model.isAttachedToDevice():
                 model.detachFromDevice()
 
-    def _load_model(self, stage: str):
+    def _load_model(self, stage: RunningStage) -> None:
         """Loads the stage specific accelerator model onto device if compiled and not attached to IPU devices.
 
         Args:
@@ -348,28 +358,28 @@ def _load_model(self, stage: str):
         if self._compiled(model) and not model.isAttachedToDevice():
             model.attachToDevice()
 
-    def on_train_start(self):
+    def on_train_start(self) -> None:
         self._load_model(RunningStage.TRAINING)
 
-    def on_validation_start(self):
+    def on_validation_start(self) -> None:
         self._load_model(RunningStage.VALIDATING)
 
-    def on_test_start(self):
+    def on_test_start(self) -> None:
         self._load_model(RunningStage.TESTING)
 
-    def on_predict_start(self):
+    def on_predict_start(self) -> None:
         self._load_model(RunningStage.PREDICTING)
 
-    def on_train_end(self):
+    def on_train_end(self) -> None:
         self._detach_models()
 
-    def on_validation_end(self):
+    def on_validation_end(self) -> None:
         self._detach_models()
 
-    def on_test_end(self):
+    def on_test_end(self) -> None:
         self._detach_models()
 
-    def on_predict_end(self):
+    def on_predict_end(self) -> None:
         self._detach_models()
 
     def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
@@ -397,7 +407,7 @@ def barrier(self, name: Optional[str] = None) -> None:
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         return tensor
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return obj
 
     @classmethod
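
One change worth calling out: `broadcast` now takes and returns the TBroadcast type variable imported from pytorch_lightning.strategies.strategy instead of plain `object`, so the declared return type tracks whatever the caller passes in. A minimal sketch of the difference, using a locally defined TypeVar rather than Lightning's:

    from typing import TypeVar

    T = TypeVar("T")


    def broadcast_as_object(obj: object, src: int = 0) -> object:
        # The concrete type is erased: callers get back plain `object`
        # and need a cast before calling list- or tensor-specific methods.
        return obj


    def broadcast_typed(obj: T, src: int = 0) -> T:
        # The input type flows through: mypy sees a list[int] argument
        # come back as a list[int].
        return obj

In this strategy `broadcast` simply returns its input, so the TypeVar only changes what type checkers and callers see; runtime behaviour is unchanged.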

0 commit comments
