
Commit 96433d0

Authored by Sean Naren, Kaushik B (kaushikb11), and thomas chaton (tchaton)
IPU Integration 5/5 (#7867)
* Initial changes
* Add broken example for now
* Fix reference
* Fix format
* Code runs
* Fixes
* Clear up files
* Add tests, helpers, fixes
* Small cleanups
* Refactors based on review
* Swap to special tests
* Add special tests
* Add source
* Cleanups
* Add logic to attach/detach model from devices
* Fixes for tests
* Fixes for tests
* Move earlier
* Cleanups
* Add check for nvcc
* Add tests, cleanups
* Fix errors
* fix
* Try condition
* Add missing annotation
* Clearer
* Clearer message
* Fix variable
* Cleanups
* Add comment
* CHANGELOG.md
* Add simple selection test
* Remove special=True to see what happens
* Fix test
* Update tests/accelerators/test_ipu.py (Co-authored-by: Kaushik B <[email protected]>)
* Convert ipu_cores -> ipus
* Add typing, fail earlier
* simplify precision
* Add test, add helper
* fix accum
* Update pytorch_lightning/plugins/training_type/ipu.py (Co-authored-by: thomas chaton <[email protected]>)
* Use stages
* Make sure warning message returned
* throw error
* Add more tests, use fs
* add comment
* Clean
* Address feedback, add IPU tests
* Fixes
* Fix signature
* Add types
* Remove autoround
* Add docstring
* ipu_cores -> ipus
* Add test, remove unnecessary precision set
* Add optimizer test
* Add precision back with test
* Address code review
* Change to probs
* Move some of the asserts earlier

Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
1 parent 42c7f27 commit 96433d0

File tree

15 files changed, +1150 -5 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))


+- Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))
+
+
 - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))


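For orientation, the feature behind this changelog entry is enabled purely through a new `ipus` Trainer flag (renamed from `ipu_cores` during review). A minimal sketch, assuming a Graphcore poptorch environment with this branch installed, and reusing the MNIST example added further down in this commit:

import pytorch_lightning as pl

from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule
from pl_examples.ipu_examples.mnist import LitClassifier  # added by this commit

# `ipus=8` requests eight IPU devices; the rest of the Trainer API is unchanged.
trainer = pl.Trainer(ipus=8, max_epochs=1)
trainer.fit(LitClassifier(), datamodule=MNISTDataModule(batch_size=32))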

pl_examples/ipu_examples/__init__.py

Whitespace-only changes.

pl_examples/ipu_examples/mnist.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule


class LitClassifier(pl.LightningModule):

    def __init__(
        self,
        hidden_dim: int = 128,
        learning_rate: float = 0.0001,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        probs = self(x)
        # We return the accuracy here because validation_step/test_step run on the IPU devices.
        # Outputs from the step functions are sent back to the host device, where the metrics
        # are aggregated in validation_epoch_end and test_epoch_end.
        acc = self.accuracy(probs, y)
        return acc

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        acc = self.accuracy(logits, y)
        return acc

    def accuracy(self, logits, y):
        # poptorch does not currently implicitly convert bools to tensors,
        # so we compute the accuracy explicitly here. Once this is fixed in poptorch
        # we can use the Accuracy metric instead.
        acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y)
        return acc

    def validation_epoch_end(self, outputs) -> None:
        # Since training_step/validation_step/test_step run on the IPU device,
        # we must log the aggregated metric outside the step functions.
        self.log('val_acc', torch.stack(outputs).mean(), prog_bar=True)

    def test_epoch_end(self, outputs) -> None:
        self.log('test_acc', torch.stack(outputs).mean())

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


if __name__ == '__main__':
    dm = MNISTDataModule(batch_size=32)

    model = LitClassifier()

    trainer = pl.Trainer(max_epochs=2, ipus=8)

    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm)
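As the comments in `accuracy` note, poptorch cannot yet implicitly convert the boolean comparison to a tensor, so the example spells the metric out explicitly. A quick host-side sanity check (plain PyTorch, no IPU or Lightning required) showing the explicit formula matches the usual boolean-mean form it stands in for:

import torch

logits = torch.randn(8, 10)     # fake batch of class scores
y = torch.randint(0, 10, (8,))  # fake integer targets

# Explicit form used in LitClassifier.accuracy above.
manual = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y)

# Equivalent boolean-mean form that poptorch cannot trace yet.
boolean_mean = (torch.argmax(logits, -1) == y).float().mean()

assert torch.allclose(manual, boolean_mean)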

pytorch_lightning/accelerators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,4 +13,5 @@
 from pytorch_lightning.accelerators.accelerator import Accelerator  # noqa F401
 from pytorch_lightning.accelerators.cpu import CPUAccelerator  # noqa F401
 from pytorch_lightning.accelerators.gpu import GPUAccelerator  # noqa F401
+from pytorch_lightning.accelerators.ipu import IPUAccelerator  # noqa F401
 from pytorch_lightning.accelerators.tpu import TPUAccelerator  # noqa F401

pytorch_lightning/accelerators/ipu.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable

from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class IPUAccelerator(Accelerator):
    """ Accelerator for IPUs. """

    def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
        super().setup_optimizers(trainer)

        if len(self.optimizers) > 1:
            raise MisconfigurationException("IPUs currently only support one optimizer.")

    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
        # The optimizer step is handled internally by the IPU (poptorch); we only run the closure.
        lambda_closure()
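The behavioural additions here are the single-optimizer guard and delegating the optimizer step to the closure. As a hedged illustration of the kind of configuration the guard rejects (a hypothetical module, not part of this commit): a LightningModule returning two optimizers, as is common for GAN-style training, would trip `setup_optimizers` when run with `Trainer(ipus=...)`.

import torch
import pytorch_lightning as pl


class TwoOptimizerModel(pl.LightningModule):
    """Hypothetical module with two optimizers (e.g. generator/discriminator)."""

    def __init__(self):
        super().__init__()
        self.gen = torch.nn.Linear(4, 4)
        self.disc = torch.nn.Linear(4, 4)

    def configure_optimizers(self):
        # Fine on CPU/GPU, but IPUAccelerator.setup_optimizers raises
        # MisconfigurationException("IPUs currently only support one optimizer.")
        return (
            torch.optim.Adam(self.gen.parameters(), lr=1e-3),
            torch.optim.Adam(self.disc.parameters(), lr=1e-3),
        )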

pytorch_lightning/plugins/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,7 @@
 from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (  # noqa: F401
     FullyShardedNativeMixedPrecisionPlugin,
 )
+from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin  # noqa: F401
@@ -20,6 +21,7 @@
 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.ipu import IPUPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin  # noqa: F401
@@ -41,6 +43,8 @@
     "DeepSpeedPrecisionPlugin",
     "DoublePrecisionPlugin",
     "HorovodPlugin",
+    "IPUPlugin",
+    "IPUPrecisionPlugin",
     "NativeMixedPrecisionPlugin",
     "PrecisionPlugin",
     "ShardedNativeMixedPrecisionPlugin",

pytorch_lightning/plugins/precision/ipu_precision.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class IPUPrecisionPlugin(PrecisionPlugin):

    def __init__(self, precision: int) -> None:
        super().__init__()
        self.precision = precision

    def backward(
        self,
        model: 'pl.LightningModule',
        closure_loss: Tensor,
        optimizer: Optimizer,
        opt_idx: int,
        should_accumulate: bool,
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        # The IPU manages the backward step internally, so the loss is returned unchanged.
        return closure_loss

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
        model: Optional[Module] = None,
    ) -> None:
        """Clips the gradients."""
        if clip_val is None:
            return

        clip_val = float(clip_val)
        if clip_val <= 0:
            return

        raise MisconfigurationException("IPUs currently do not support clipping gradients.")
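A small sketch of the clipping behaviour defined above, assuming this branch of PyTorch Lightning imports cleanly (no IPU hardware is needed, since the check runs before any device work): a disabled clip value returns silently, while any positive value is rejected.

import torch
from pytorch_lightning.plugins import IPUPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

plugin = IPUPrecisionPlugin(precision=16)
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

# None or a non-positive clip value means "clipping disabled": no error is raised.
plugin.clip_gradients(optimizer, clip_val=0.0)

# Any positive clip value is rejected, as IPUs do not support gradient clipping yet.
try:
    plugin.clip_gradients(optimizer, clip_val=1.0)
except MisconfigurationException as err:
    print(err)  # IPUs currently do not support clipping gradients.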
