
Commit 01f26b4

Commit message: tpu
1 parent 0714933 · commit 01f26b4

File tree: 2 files changed (+34, -2 lines changed)

pytorch_lightning/accelerators/tpu.py

Lines changed: 2 additions & 2 deletions
@@ -19,8 +19,8 @@ class TPUAccelerator(Accelerator):
     def setup(self, trainer, model):
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException(
-                "amp + tpu is not supported. "
-                "Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
+                "amp + tpu is not supported."
+                " Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
             )

         if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
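
Note that the two versions concatenate to the same final message; the commit only moves the separating space from the end of the first string literal to the start of the second. A quick, illustrative sanity check (not part of the commit):

# Illustrative only: the old and new literal concatenations yield an
# identical error message, so this change is purely cosmetic.
old = ("amp + tpu is not supported. "
       "Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin")
new = ("amp + tpu is not supported."
       " Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin")
assert old == new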

tests/accelerators/test_tpu.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+from unittest.mock import Mock
+
+import pytest
+
+from pytorch_lightning.accelerators import TPUAccelerator
+from pytorch_lightning.plugins import SingleTPUPlugin, DDPPlugin, PrecisionPlugin
+from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+def test_unsupported_precision_plugins():
+    """ Test error messages are raised for unsupported precision plugins with TPU. """
+    trainer = Mock()
+    model = Mock()
+    accelerator = TPUAccelerator(
+        training_type_plugin=SingleTPUPlugin(device=Mock()),
+        precision_plugin=MixedPrecisionPlugin(),
+    )
+    with pytest.raises(MisconfigurationException, match=r"amp \+ tpu is not supported."):
+        accelerator.setup(trainer=trainer, model=model)
+
+
+def test_unsupported_training_type_plugins():
+    """ Test error messages are raised for unsupported training type with TPU. """
+    trainer = Mock()
+    model = Mock()
+    accelerator = TPUAccelerator(
+        training_type_plugin=DDPPlugin(),
+        precision_plugin=PrecisionPlugin(),
+    )
+    with pytest.raises(MisconfigurationException, match="TPUs only support a single tpu core or tpu spawn training"):
+        accelerator.setup(trainer=trainer, model=model)
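
Since both tests build the accelerator from Mock objects and expect setup() to raise before any device is touched, they should not require TPU hardware to run. A minimal invocation sketch, assuming a local checkout with pytest installed:

# Illustrative only: run the new test module via pytest's Python entry point
# (equivalent to invoking `pytest tests/accelerators/test_tpu.py` from a shell).
import pytest

pytest.main(["tests/accelerators/test_tpu.py", "-v"])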
