
Commit e8beceb

carmocca, rohitgr7, and awaelchli authored
Add TPUPrecisionPlugin (#10020)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 4aaca17 commit e8beceb

File tree

11 files changed: +79 −17 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -196,6 +196,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))


+- Added `TPUPrecisionPlugin` ([#10020](https://github.com/PyTorchLightning/pytorch-lightning/pull/#10020))
+
+
 - Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965))

docs/source/api_references.rst

Lines changed: 4 additions & 0 deletions
@@ -170,12 +170,16 @@ Precision Plugins
     :template: classtemplate.rst

     PrecisionPlugin
+    MixedPrecisionPlugin
     NativeMixedPrecisionPlugin
     ShardedNativeMixedPrecisionPlugin
     ApexMixedPrecisionPlugin
     DeepSpeedPrecisionPlugin
+    TPUPrecisionPlugin
     TPUHalfPrecisionPlugin
     DoublePrecisionPlugin
+    FullyShardedNativeMixedPrecisionPlugin
+    IPUPrecisionPlugin

 Cluster Environments
 ^^^^^^^^^^^^^^^^^^^^

docs/source/extensions/plugins.rst

Lines changed: 4 additions & 0 deletions
@@ -131,12 +131,16 @@ Precision Plugins
     :template: classtemplate.rst

     PrecisionPlugin
+    MixedPrecisionPlugin
     NativeMixedPrecisionPlugin
     ShardedNativeMixedPrecisionPlugin
     ApexMixedPrecisionPlugin
     DeepSpeedPrecisionPlugin
+    TPUPrecisionPlugin
     TPUHalfPrecisionPlugin
     DoublePrecisionPlugin
+    FullyShardedNativeMixedPrecisionPlugin
+    IPUPrecisionPlugin


 Cluster Environments

pytorch_lightning/accelerators/tpu.py

Lines changed: 11 additions & 11 deletions
@@ -18,12 +18,11 @@

 import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
-from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -35,18 +34,19 @@ class TPUAccelerator(Accelerator):
     def setup(self, trainer: "pl.Trainer") -> None:
         """
         Raises:
-            MisconfigurationException:
-                If AMP is used with TPU.
-            MisconfigurationException:
-                If TPUs are not using a single TPU core or TPU spawn training.
+            ValueError:
+                If the precision or training type plugin are unsupported.
         """
-        if isinstance(self.precision_plugin, MixedPrecisionPlugin):
-            raise MisconfigurationException(
-                "amp + tpu is not supported. Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
+        if not isinstance(self.precision_plugin, TPUPrecisionPlugin):
+            # this configuration should have been avoided in the accelerator connector
+            raise ValueError(
+                f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}."
             )
-
         if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
-            raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
+            raise ValueError(
+                "The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin,"
+                f" found {self.training_type_plugin}."
+            )
         return super().setup(trainer)

     def run_optimizer_step(
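
The stricter checks surface misconfiguration as a plain `ValueError` at setup time. A minimal sketch of what that looks like, assuming this commit's `Accelerator(precision_plugin=..., training_type_plugin=...)` constructor and a TPU/XLA environment (the `setup` call is left commented because it needs a real `Trainer`):

from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, TPUSpawnPlugin

# A plain PrecisionPlugin is not a TPUPrecisionPlugin, so setup() is expected
# to raise ValueError rather than the old MisconfigurationException.
accelerator = TPUAccelerator(
    precision_plugin=PrecisionPlugin(),
    training_type_plugin=TPUSpawnPlugin(),
)
# accelerator.setup(trainer)  # ValueError: The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin` ...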

pytorch_lightning/plugins/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
 from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
@@ -57,6 +58,7 @@
     "FullyShardedNativeMixedPrecisionPlugin",
     "SingleDevicePlugin",
     "SingleTPUPlugin",
+    "TPUPrecisionPlugin",
     "TPUHalfPrecisionPlugin",
     "TPUSpawnPlugin",
     "TrainingTypePlugin",

pytorch_lightning/plugins/precision/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,8 +4,10 @@
 from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (  # noqa: F401
     FullyShardedNativeMixedPrecisionPlugin,
 )
+from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin  # noqa: F401
+from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin  # noqa: F401
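
With both `__init__` modules updated, the new plugin is exposed from the public plugins namespace as well as the precision subpackage. A quick sanity check, assuming an installation that includes this commit:

from pytorch_lightning.plugins import TPUPrecisionPlugin
from pytorch_lightning.plugins.precision import TPUPrecisionPlugin as _TPUPrecisionPlugin

# Both import paths should resolve to the same class object.
assert TPUPrecisionPlugin is _TPUPrecisionPlugin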

pytorch_lightning/plugins/precision/tpu.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+
+
+class TPUPrecisionPlugin(PrecisionPlugin):
+    ...
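
The class body is intentionally empty: it gives TPU precision handling a common base that the accelerator and connector can check with `isinstance`, and an extension point for custom behaviour. A hypothetical example of such an extension; the subclass name and the `tpu_cores=` / `plugins=` arguments below are illustrative and require a TPU environment:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import TPUPrecisionPlugin


class MyTPUPrecisionPlugin(TPUPrecisionPlugin):
    """Custom TPU precision behaviour; inherits the standard PrecisionPlugin hooks."""


# Pass the plugin explicitly instead of relying on automatic selection.
trainer = Trainer(tpu_cores=8, plugins=[MyTPUPrecisionPlugin()])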

pytorch_lightning/plugins/precision/tpu_bfloat.py

Lines changed: 2 additions & 2 deletions
@@ -17,10 +17,10 @@
 import torch.nn as nn
 from torch.optim import Optimizer

-from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.plugins.precision import TPUPrecisionPlugin


-class TPUHalfPrecisionPlugin(PrecisionPlugin):
+class TPUHalfPrecisionPlugin(TPUPrecisionPlugin):
     """Plugin that enables bfloats on TPUs."""

     precision: int = 16
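
Because `TPUHalfPrecisionPlugin` now derives from `TPUPrecisionPlugin`, the `isinstance` check added to `TPUAccelerator.setup` accepts both the 32-bit and the bf16 configuration. A small check, assuming both constructors still take no required arguments in this version:

from pytorch_lightning.plugins import TPUHalfPrecisionPlugin, TPUPrecisionPlugin

# The bf16 plugin is now a TPUPrecisionPlugin as well.
assert isinstance(TPUHalfPrecisionPlugin(), TPUPrecisionPlugin)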

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 12 additions & 3 deletions
@@ -47,6 +47,7 @@
     SingleDevicePlugin,
     SingleTPUPlugin,
     TPUHalfPrecisionPlugin,
+    TPUPrecisionPlugin,
     TPUSpawnPlugin,
     TrainingTypePlugin,
     TrainingTypePluginsRegistry,
@@ -592,6 +593,17 @@ def select_precision_plugin(self) -> PrecisionPlugin:

         if self.use_ipu:
             return IPUPrecisionPlugin(self.precision)
+        if self.use_tpu:
+            if self.precision == 32:
+                return TPUPrecisionPlugin()
+            elif self.precision == 64:
+                raise MisconfigurationException(
+                    "`Trainer(accelerator='tpu', precision=64)` is not implemented."
+                    " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
+                    " requesting this feature."
+                )
+            elif self.precision in (16, "bf16"):
+                return TPUHalfPrecisionPlugin()

         if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
             return DeepSpeedPrecisionPlugin(self.precision)
@@ -601,9 +613,6 @@ def select_precision_plugin(self) -> PrecisionPlugin:
         if self.precision == 64:
             return DoublePrecisionPlugin()
         if self.precision in (16, "bf16"):
-            if self.use_tpu:
-                return TPUHalfPrecisionPlugin()
-
             if self.amp_type == AMPType.NATIVE:
                 if self.amp_level is not None:
                     raise MisconfigurationException(
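
For TPU runs the precision-to-plugin mapping is now explicit in the connector. A rough sketch of the resulting behaviour, assuming a TPU/XLA environment so that `use_tpu` is true (comments note the plugin the connector is expected to select):

from pytorch_lightning import Trainer

Trainer(tpu_cores=8, precision=32)      # selects TPUPrecisionPlugin()
Trainer(tpu_cores=8, precision=16)      # selects TPUHalfPrecisionPlugin()
Trainer(tpu_cores=8, precision="bf16")  # selects TPUHalfPrecisionPlugin()
# Trainer(tpu_cores=8, precision=64)    # raises MisconfigurationException (not implemented)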

tests/accelerators/test_accelerator_connector.py

Lines changed: 10 additions & 0 deletions
@@ -976,3 +976,13 @@ def on_fit_start(self, trainer, pl_module):

     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+def test_unsupported_tpu_choice(monkeypatch):
+    import pytorch_lightning.utilities.imports as imports
+    from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
+
+    monkeypatch.setattr(imports, "_XLA_AVAILABLE", True)
+    monkeypatch.setattr(AcceleratorConnector, "has_tpu", True)
+    with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
+        Trainer(accelerator="tpu", precision=64)
