diff --git a/pyproject.toml b/pyproject.toml
index 15f0293bb1c8a..2df0142e9af4c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,7 +63,6 @@ module = [
     "pytorch_lightning.profilers.simple",
     "pytorch_lightning.strategies.ddp",
     "pytorch_lightning.strategies.fully_sharded",
-    "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.trainer.callback_hook",
diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py
index 82ba4ad227f7c..0b5d8e835ad1d 100644
--- a/src/pytorch_lightning/strategies/ipu.py
+++ b/src/pytorch_lightning/strategies/ipu.py
@@ -13,11 +13,11 @@
 # limitations under the License.
 import json
 import os
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import FloatTensor, Tensor
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Sampler
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
@@ -25,6 +25,7 @@
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -112,12 +113,12 @@ def __init__(
         self.device_iterations = device_iterations
         self.autoreport = autoreport
         self.autoreport_dir = autoreport_dir
-        self.poptorch_models = {}
+        self.poptorch_models: Dict[RunningStage, "poptorch.PoplarExecutor"] = {}
         self._training_opts = training_opts
         self._inference_opts = inference_opts
 
         if self.autoreport:
-            options = {"autoReport.all": self.autoreport}
+            options: Dict[str, Any] = {"autoReport.all": self.autoreport}
             if self.autoreport_dir:
                 self._fs = get_filesystem(str(self.autoreport_dir))
                 self._fs.makedirs(self.autoreport_dir, exist_ok=True)
@@ -139,6 +140,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
 
         super().setup(trainer)
 
+        assert self.lightning_module is not None
+
         # disable the `optimizer_zero_grad` function by setting it to `None`.
         # this is because the IPU zeros the gradients internally
         self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad
@@ -192,12 +195,14 @@ def replication_factor(self) -> int:
             if self._inference_opts:
                 return self._inference_opts.replication_factor
+            assert self.parallel_devices
             return len(self.parallel_devices)
 
-
         stage = self.lightning_module.trainer.state.stage
+        assert stage is not None
         return self.poptorch_models[stage]._options.toDict()["replication_factor"]
 
     def _create_opts(self, training: bool) -> "poptorch.Options":
+        assert self.lightning_module is not None
         opts = poptorch.Options()
         opts.deviceIterations(self.device_iterations)
         opts.replicationFactor(self.replication_factor)
@@ -221,14 +226,14 @@ def inference_opts(self) -> "poptorch.Options":
         return self._inference_opts
 
     def _convert_to_poptorch_loader(
-        self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
+        self, dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
     ) -> "poptorch.DataLoader":
         if isinstance(dataloader, poptorch.DataLoader):
             # the user is returning the `poptorch.DataLoader` directly, don't change anything.
             return dataloader
 
         dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
-            dataloader, sampler, mode, self.replication_factor > 1
+            dataloader, sampler, mode, self.replication_factor > 1  # type: ignore[arg-type]
         )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
@@ -240,6 +245,7 @@ def _handle_gradient_accumulation_steps(self) -> None:
 
         ``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation internally.
         """
+        assert self.lightning_module is not None
         accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
 
         if accumulation_scheduler.epochs != [0]:
@@ -251,18 +257,19 @@ def _handle_gradient_accumulation_steps(self) -> None:
         accumulation_scheduler.scheduling.update({0: 1})
 
     @property
-    def _n_replicate(self):
+    def _n_replicate(self) -> int:
+        assert self.lightning_module is not None
         opts = self.training_opts if self.lightning_module.training else self.inference_opts
         accumulate_grad_batches = opts.Training.gradient_accumulation
         device_iterations = opts.device_iterations
         replication_factor = opts.replication_factor
         return replication_factor * device_iterations * accumulate_grad_batches
 
-    def _prepare_input(self, args: Any):
-        def to_tuple(x):
+    def _prepare_input(self, args: Any) -> Any:
+        def to_tuple(x: Any) -> Tuple:
             return tuple(x)
 
-        def to_tensor(x):
+        def to_tensor(x: Any) -> Tensor:
             return torch.tensor(x).unsqueeze(0).repeat(self._n_replicate)
 
         args = apply_to_collection(args, dtype=list, function=to_tuple)
@@ -281,6 +288,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat
 
     def _disable_zero_grad(self) -> None:
         lightning_module = self.lightning_module
+        assert lightning_module is not None
         if is_overridden("optimizer_zero_grad", lightning_module):
             assert lightning_module is not None  # `is_overridden` returns False otherwise
             rank_zero_warn(
@@ -289,27 +297,28 @@ def _disable_zero_grad(self) -> None:
             )
         lightning_module.optimizer_zero_grad = None  # type: ignore[assignment]
 
-    def _step(self, stage: RunningStage, *args: Any, **kwargs: Any):
+    def _step(self, stage: RunningStage, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         args = self._prepare_input(args)
+        assert self.lightning_module is not None
         poptorch_model = self.poptorch_models[stage]
         self.lightning_module._running_torchscript = True
         out = poptorch_model(*args, **kwargs)
         self.lightning_module._running_torchscript = False
         return out
 
-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.train_step_context():
             return self._step(RunningStage.TRAINING, *args, **kwargs)
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.val_step_context():
             return self._step(RunningStage.VALIDATING, *args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.test_step_context():
             return self._step(RunningStage.TESTING, *args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.predict_step_context():
             return self._step(RunningStage.PREDICTING, *args, **kwargs)
 
@@ -318,26 +327,27 @@ def teardown(self) -> None:
         # undo dataloader patching
         pl.trainer.connectors.data_connector._update_dataloader = self._update_dataloader_original
 
+        assert self.lightning_module is not None
         if self._optimizer_zero_grad_original is not None:
             # re-enable `optimizer_zero_grad`
-            self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original
+            self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original  # type: ignore[assignment]
 
         for model in self.poptorch_models.values():
             model.destroy()
 
         super().teardown()
 
-    def _compiled(self, model: Any):
+    def _compiled(self, model: Any) -> bool:
         # Required to ensure we only attach compiled models, as they are compiled lazily.
         return model._executable is not None
 
-    def _detach_models(self):
+    def _detach_models(self) -> None:
         """Detaches all stage specific models from IPU devices."""
         for k, model in self.poptorch_models.items():
             if self._compiled(model) and model.isAttachedToDevice():
                 model.detachFromDevice()
 
-    def _load_model(self, stage: str):
+    def _load_model(self, stage: RunningStage) -> None:
         """Loads the stage specific accelerator model onto device if compiled and not attached to IPU devices.
 
         Args:
@@ -348,28 +358,28 @@ def _load_model(self, stage: str):
         if self._compiled(model) and not model.isAttachedToDevice():
             model.attachToDevice()
 
-    def on_train_start(self):
+    def on_train_start(self) -> None:
         self._load_model(RunningStage.TRAINING)
 
-    def on_validation_start(self):
+    def on_validation_start(self) -> None:
         self._load_model(RunningStage.VALIDATING)
 
-    def on_test_start(self):
+    def on_test_start(self) -> None:
         self._load_model(RunningStage.TESTING)
 
-    def on_predict_start(self):
+    def on_predict_start(self) -> None:
         self._load_model(RunningStage.PREDICTING)
 
-    def on_train_end(self):
+    def on_train_end(self) -> None:
         self._detach_models()
 
-    def on_validation_end(self):
+    def on_validation_end(self) -> None:
         self._detach_models()
 
-    def on_test_end(self):
+    def on_test_end(self) -> None:
         self._detach_models()
 
-    def on_predict_end(self):
+    def on_predict_end(self) -> None:
         self._detach_models()
 
     def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
@@ -397,7 +407,7 @@ def barrier(self, name: Optional[str] = None) -> None:
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         return tensor
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return obj
 
     @classmethod