Merged
Changes from all commits
21 commits
52a1c57  Fixed all mypy typing errors for the IPU strategy. (Jul 19, 2022)
a6d5593  Remove extra typing check on the trainer validation (Jul 19, 2022)
e994905  Fix circular dependency from unquoted type hint (Jul 20, 2022)
36e48be  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 21, 2022)
44bb884  Fix issues that came up after rebasing master (Jul 21, 2022)
1471b5a  Merge branch 'master' into tests/13445_ipu_strategy_typing_annotations (akihironitta, Jul 23, 2022)
34b1fec  Merge branch 'master' into tests/13445_ipu_strategy_typing_annotations (otaj, Jul 27, 2022)
587477d  Still not quite working - down to the last error on unwrapping return… (Jul 29, 2022)
f86a39a  merge master (Aug 2, 2022)
d3568bb  fix mypy errors (Aug 2, 2022)
35a4dc5  surgical asserts (Aug 2, 2022)
6e39f90  Merge branch 'master' into tests/13445_ipu_strategy_typing_annotations (HalestormAI, Aug 2, 2022)
199a69b  Change stage from str to RunningStage and simplify dict definition fo… (Aug 2, 2022)
88a2efd  Update src/pytorch_lightning/strategies/ipu.py (awaelchli, Aug 3, 2022)
ed96cef  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2022)
f1f0061  Remove unnecessary asserts and the lightning_module property as no lo… (Aug 3, 2022)
158b125  Merge branch 'master' into tests/13445_ipu_strategy_typing_annotations (HalestormAI, Aug 3, 2022)
3e90ae1  Remove unused input after taking out lightning_module (Aug 3, 2022)
bc6f314  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2022)
d9b8eef  Switch conditionals for assertions (HalestormAI, Aug 3, 2022)
6cf02f0  Merge branch 'master' into tests/13445_ipu_strategy_typing_annotations (HalestormAI, Aug 3, 2022)
1 change: 0 additions & 1 deletion pyproject.toml
@@ -63,7 +63,6 @@ module = [
"pytorch_lightning.profilers.simple",
"pytorch_lightning.strategies.ddp",
"pytorch_lightning.strategies.fully_sharded",
"pytorch_lightning.strategies.ipu",
"pytorch_lightning.strategies.sharded",
"pytorch_lightning.strategies.sharded_spawn",
"pytorch_lightning.trainer.callback_hook",
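Dropping `"pytorch_lightning.strategies.ipu"` from this `module = [...]` list (the file summary confirms one deletion and no additions) means mypy now applies the project's normal per-module settings to the IPU strategy instead of the relaxed treatment the list grants; the exact semantics of the list are in pyproject.toml, not shown here. As a rough, purely hypothetical way to check the module in isolation, mypy's programmatic API could be used; the path and flag below are assumptions, not part of the PR:

```python
# Hypothetical stand-alone check that the newly included module type-checks.
# The file path and the --strict flag are assumptions; the PR itself relies
# on the pyproject.toml configuration rather than any script like this.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["--strict", "src/pytorch_lightning/strategies/ipu.py"]
)
print(stdout or "no typing errors reported")
assert exit_status == 0, f"mypy reported errors:\n{stdout}{stderr}"
```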
68 changes: 39 additions & 29 deletions src/pytorch_lightning/strategies/ipu.py
@@ -13,18 +13,19 @@
# limitations under the License.
import json
import os
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import FloatTensor, Tensor
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Sampler

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
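Commit e994905 above mentions fixing a circular dependency caused by an unquoted type hint. The pattern visible in this import block is `import pytorch_lightning as pl` combined with string annotations such as `"pl.Trainer"` later in the file, so the attribute lookup is never forced at import time. A minimal sketch of the closely related `TYPE_CHECKING` variant, with illustrative names and not taken verbatim from the PR:

```python
# Minimal sketch of the quoted-annotation pattern: the annotation stays a
# string, so `pl.Trainer` is never resolved at import time, and the import
# can even be confined to type checkers.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pytorch_lightning as pl  # evaluated by mypy only, never at runtime


class MyStrategy:
    def setup(self, trainer: "pl.Trainer") -> None:
        """The quoted hint satisfies mypy without a runtime import cycle."""
```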
@@ -112,12 +113,12 @@ def __init__(
self.device_iterations = device_iterations
self.autoreport = autoreport
self.autoreport_dir = autoreport_dir
self.poptorch_models = {}
self.poptorch_models: Dict[RunningStage, "poptorch.PoplarExecutor"] = {}
self._training_opts = training_opts
self._inference_opts = inference_opts

if self.autoreport:
options = {"autoReport.all": self.autoreport}
options: Dict[str, Any] = {"autoReport.all": self.autoreport}
if self.autoreport_dir:
self._fs = get_filesystem(str(self.autoreport_dir))
self._fs.makedirs(self.autoreport_dir, exist_ok=True)
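Two small typing choices appear in this hunk: `poptorch_models` gets a quoted `"poptorch.PoplarExecutor"` value type (poptorch is an optional dependency that may not be importable where the annotation is evaluated), and `options` is declared `Dict[str, Any]` up front, presumably so later non-bool entries do not clash with the `Dict[str, bool]` mypy would infer from the first assignment. A rough stand-alone sketch of the same idea, with stand-in class names that are not Lightning code:

```python
# Sketch: annotate an attribute whose value type lives in an optional
# dependency without importing that dependency at runtime.
from typing import TYPE_CHECKING, Any, Dict

if TYPE_CHECKING:
    import poptorch  # resolved by the type checker only


class ToyStrategy:
    def __init__(self, autoreport: bool = True) -> None:
        # Quoted value type: no poptorch import is needed when this runs.
        self.models: Dict[str, "poptorch.PoplarExecutor"] = {}
        # Declared Dict[str, Any] so adding str or nested values later does
        # not conflict with an inferred Dict[str, bool].
        self.options: Dict[str, Any] = {"autoReport.all": autoreport}
```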
@@ -139,6 +140,8 @@ def setup(self, trainer: "pl.Trainer") -> None:

super().setup(trainer)

assert self.lightning_module is not None

# disable the `optimizer_zero_grad` function by setting it to `None`.
# this is because the IPU zeros the gradients internally
self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad
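The added `assert self.lightning_module is not None` is the PR's main device for satisfying mypy (see the "surgical asserts" and "Switch conditionals for assertions" commits): `lightning_module` is typed as Optional, and the assert narrows it to a concrete type before the attribute access below. A simplified, self-contained illustration of the narrowing; the classes are stand-ins, not Lightning's real ones:

```python
# Minimal sketch of Optional narrowing via assert, as used throughout ipu.py.
from typing import Optional


class Module:
    def optimizer_zero_grad(self) -> None:
        ...


class Strategy:
    def __init__(self) -> None:
        self.lightning_module: Optional[Module] = None

    def setup(self) -> None:
        # Without this assert, mypy flags the access below with
        # `Item "None" of "Optional[Module]" has no attribute ...`.
        assert self.lightning_module is not None
        zero_grad = self.lightning_module.optimizer_zero_grad
        zero_grad()
```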
@@ -192,12 +195,14 @@ def replication_factor(self) -> int:
if self._inference_opts:
return self._inference_opts.replication_factor

assert self.parallel_devices
return len(self.parallel_devices)

stage = self.lightning_module.trainer.state.stage
assert stage is not None
return self.poptorch_models[stage]._options.toDict()["replication_factor"]

def _create_opts(self, training: bool) -> "poptorch.Options":
assert self.lightning_module is not None
opts = poptorch.Options()
opts.deviceIterations(self.device_iterations)
opts.replicationFactor(self.replication_factor)
@@ -221,14 +226,14 @@ def inference_opts(self) -> "poptorch.Options":
return self._inference_opts

def _convert_to_poptorch_loader(
self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
self, dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
) -> "poptorch.DataLoader":
if isinstance(dataloader, poptorch.DataLoader):
# the user is returning the `poptorch.DataLoader` directly, don't change anything.
return dataloader

dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
dataloader, sampler, mode, self.replication_factor > 1
dataloader, sampler, mode, self.replication_factor > 1 # type: ignore[arg-type]
)
opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
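The `sampler` parameter gains a `Union[Sampler, Iterable]` annotation, and the call into `_get_dataloader_init_args_and_kwargs` is silenced with an error-code-scoped `# type: ignore[arg-type]` rather than a bare `# type: ignore`, so any other error later introduced on that line would still surface. A toy version of the pattern, with a made-up `legacy_helper` standing in for the real helper:

```python
# Sketch of a scoped type-ignore: only the arg-type mismatch on this one
# call is suppressed.
from typing import Iterable, Union

from torch.utils.data import Sampler


def legacy_helper(sampler: Sampler) -> None:
    """Pretend signature that is stricter than what callers can prove."""


def convert(sampler: Union[Sampler, Iterable]) -> None:
    # The union reflects what this function can actually receive; the
    # helper's annotation is narrower, so only this mismatch is ignored.
    legacy_helper(sampler)  # type: ignore[arg-type]
```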
@@ -240,6 +245,7 @@ def _handle_gradient_accumulation_steps(self) -> None:

``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation internally.
"""
assert self.lightning_module is not None
accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

if accumulation_scheduler.epochs != [0]:
@@ -251,18 +257,19 @@ def _handle_gradient_accumulation_steps(self) -> None:
accumulation_scheduler.scheduling.update({0: 1})

@property
def _n_replicate(self):
def _n_replicate(self) -> int:
assert self.lightning_module is not None
opts = self.training_opts if self.lightning_module.training else self.inference_opts
accumulate_grad_batches = opts.Training.gradient_accumulation
device_iterations = opts.device_iterations
replication_factor = opts.replication_factor
return replication_factor * device_iterations * accumulate_grad_batches

def _prepare_input(self, args: Any):
def to_tuple(x):
def _prepare_input(self, args: Any) -> Any:
def to_tuple(x: Any) -> Tuple:
return tuple(x)

def to_tensor(x):
def to_tensor(x: Any) -> Tensor:
return torch.tensor(x).unsqueeze(0).repeat(self._n_replicate)

args = apply_to_collection(args, dtype=list, function=to_tuple)
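`_prepare_input` repeats every scalar in the incoming arguments `_n_replicate` times (replication factor times device iterations times accumulated batches) so each IPU replica receives its own copy, and `apply_to_collection` performs the recursive walk over nested lists and dicts. A rough illustration with a hard-coded replication count of 4; that number is for the example only and not from the PR:

```python
# Rough illustration of how apply_to_collection rewrites nested inputs,
# mirroring the to_tensor branch of _prepare_input above.
import torch
from torch import Tensor

from pytorch_lightning.utilities.apply_func import apply_to_collection

N_REPLICATE = 4  # stands in for self._n_replicate


def to_tensor(x: int) -> Tensor:
    # Each scalar becomes a length-N tensor, one entry per replica.
    return torch.tensor(x).unsqueeze(0).repeat(N_REPLICATE)


batch = {"idx": 7, "lengths": [3, 5]}
prepared = apply_to_collection(batch, dtype=int, function=to_tensor)
# prepared["idx"] -> tensor([7, 7, 7, 7]); the nested list entries are
# converted the same way.
```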
@@ -281,6 +288,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat

def _disable_zero_grad(self) -> None:
lightning_module = self.lightning_module
assert lightning_module is not None
if is_overridden("optimizer_zero_grad", lightning_module):
assert lightning_module is not None # `is_overridden` returns False otherwise
rank_zero_warn(
@@ -289,27 +297,28 @@ def _disable_zero_grad(self) -> None:
)
lightning_module.optimizer_zero_grad = None # type: ignore[assignment]

def _step(self, stage: RunningStage, *args: Any, **kwargs: Any):
def _step(self, stage: RunningStage, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
args = self._prepare_input(args)
assert self.lightning_module is not None
poptorch_model = self.poptorch_models[stage]
self.lightning_module._running_torchscript = True
out = poptorch_model(*args, **kwargs)
self.lightning_module._running_torchscript = False
return out

def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
with self.precision_plugin.train_step_context():
return self._step(RunningStage.TRAINING, *args, **kwargs)

def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
return self._step(RunningStage.VALIDATING, *args, **kwargs)

def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.test_step_context():
return self._step(RunningStage.TESTING, *args, **kwargs)

def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
return self._step(RunningStage.PREDICTING, *args, **kwargs)
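All four `*_step` hooks funnel into `_step`, which looks the compiled poptorch model up in `self.poptorch_models` by stage; commit 199a69b changed those keys from `str` to `RunningStage`, which lets mypy reject a mistyped stage at the call site instead of at runtime. A toy version of the enum-keyed dispatch, with illustrative names:

```python
# Toy stage-keyed dispatch: a typo'd string key would type-check, a wrong
# enum member name would not.
from enum import Enum
from typing import Callable, Dict


class Stage(Enum):  # stands in for pytorch_lightning's RunningStage
    TRAINING = "train"
    VALIDATING = "validate"


models: Dict[Stage, Callable[[int], int]] = {
    Stage.TRAINING: lambda x: x * 2,
    Stage.VALIDATING: lambda x: x + 1,
}


def step(stage: Stage, value: int) -> int:
    # mypy rejects step("train", ...) because str is not Stage.
    return models[stage](value)


print(step(Stage.TRAINING, 3))  # -> 6
```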

@@ -318,26 +327,27 @@ def teardown(self) -> None:
# undo dataloader patching
pl.trainer.connectors.data_connector._update_dataloader = self._update_dataloader_original

assert self.lightning_module is not None
if self._optimizer_zero_grad_original is not None:
# re-enable `optimizer_zero_grad`
self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original
self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original # type: ignore[assignment]

for model in self.poptorch_models.values():
model.destroy()

super().teardown()

def _compiled(self, model: Any):
def _compiled(self, model: Any) -> bool:
# Required to ensure we only attach compiled models, as they are compiled lazily.
return model._executable is not None

def _detach_models(self):
def _detach_models(self) -> None:
"""Detaches all stage specific models from IPU devices."""
for k, model in self.poptorch_models.items():
if self._compiled(model) and model.isAttachedToDevice():
model.detachFromDevice()

def _load_model(self, stage: str):
def _load_model(self, stage: RunningStage) -> None:
"""Loads the stage specific accelerator model onto device if compiled and not attached to IPU devices.

Args:
@@ -348,28 +358,28 @@ def _load_model(self, stage: str):
if self._compiled(model) and not model.isAttachedToDevice():
model.attachToDevice()

def on_train_start(self):
def on_train_start(self) -> None:
self._load_model(RunningStage.TRAINING)

def on_validation_start(self):
def on_validation_start(self) -> None:
self._load_model(RunningStage.VALIDATING)

def on_test_start(self):
def on_test_start(self) -> None:
self._load_model(RunningStage.TESTING)

def on_predict_start(self):
def on_predict_start(self) -> None:
self._load_model(RunningStage.PREDICTING)

def on_train_end(self):
def on_train_end(self) -> None:
self._detach_models()

def on_validation_end(self):
def on_validation_end(self) -> None:
self._detach_models()

def on_test_end(self):
def on_test_end(self) -> None:
self._detach_models()

def on_predict_end(self):
def on_predict_end(self) -> None:
self._detach_models()

def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
@@ -397,7 +407,7 @@ def barrier(self, name: Optional[str] = None) -> None:
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
return tensor

def broadcast(self, obj: object, src: int = 0) -> object:
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
return obj

@classmethod
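`broadcast` now uses `TBroadcast`, the TypeVar imported from `pytorch_lightning.strategies.strategy`, instead of `object`, so callers get back the same static type they passed in; on the single-process IPU strategy the method is simply an identity. A short sketch of why the TypeVar form is preferable; `TBroadcast` is redefined locally here purely for illustration:

```python
# With `obj: object -> object` callers would have to cast the result; with a
# TypeVar the input type is preserved through the call.
from typing import TypeVar

TBroadcast = TypeVar("TBroadcast")


def broadcast(obj: TBroadcast, src: int = 0) -> TBroadcast:
    # Single-process strategy: nothing to send, return the object unchanged.
    return obj


epoch: int = broadcast(5)         # mypy infers int, not object
name: str = broadcast("ipu-run")  # mypy infers str
```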