2 files changed (+34 -2)

pytorch_lightning/accelerators:

@@ -19,8 +19,8 @@ class TPUAccelerator(Accelerator):
     def setup(self, trainer, model):
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException(
-                "amp + tpu is not supported. "
-                "Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
+                "amp + tpu is not supported."
+                " Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
             )
 
         if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
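
The edit only moves the space between the two concatenated string fragments, so the full error message is unchanged; the first literal now ends exactly at "supported.", the fragment the new test matches against. As a minimal sketch of the configuration this guard steers users toward (assuming TPUHalfPrecisionPlugin is exported from pytorch_lightning.plugins and does not subclass MixedPrecisionPlugin):

from unittest.mock import Mock

from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.plugins import SingleTPUPlugin, TPUHalfPrecisionPlugin

# TPUs support bfloat16 rather than CUDA-style AMP, so a TPU accelerator is
# paired with TPUHalfPrecisionPlugin. Assuming TPUHalfPrecisionPlugin is not a
# MixedPrecisionPlugin subclass, this construction passes the isinstance guard
# in setup() above.
accelerator = TPUAccelerator(
    training_type_plugin=SingleTPUPlugin(device=Mock()),
    precision_plugin=TPUHalfPrecisionPlugin(),
)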
New file (tests for TPUAccelerator):

@@ -0,0 +1,32 @@
+from unittest.mock import Mock
+
+import pytest
+
+from pytorch_lightning.accelerators import TPUAccelerator
+from pytorch_lightning.plugins import SingleTPUPlugin, DDPPlugin, PrecisionPlugin
+from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+def test_unsupported_precision_plugins():
+    """ Test error messages are raised for unsupported precision plugins with TPU. """
+    trainer = Mock()
+    model = Mock()
+    accelerator = TPUAccelerator(
+        training_type_plugin=SingleTPUPlugin(device=Mock()),
+        precision_plugin=MixedPrecisionPlugin(),
+    )
+    with pytest.raises(MisconfigurationException, match=r"amp \+ tpu is not supported."):
+        accelerator.setup(trainer=trainer, model=model)
+
+
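
pytest.raises applies match as a regular expression via re.search on the string form of the exception, which is why the plus sign is escaped in the test above. A quick standalone check of that escaping:

import re

message = "amp + tpu is not supported. Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"

# Unescaped, "+" is a quantifier (one or more of the preceding space),
# so the pattern cannot match the literal "+" in the message:
assert re.search(r"amp + tpu is not supported.", message) is None

# Escaped as in the test, "\+" matches the literal plus sign:
assert re.search(r"amp \+ tpu is not supported.", message) is not None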
+def test_unsupported_training_type_plugins():
+    """ Test error messages are raised for unsupported training type with TPU. """
+    trainer = Mock()
+    model = Mock()
+    accelerator = TPUAccelerator(
+        training_type_plugin=DDPPlugin(),
+        precision_plugin=PrecisionPlugin(),
+    )
+    with pytest.raises(MisconfigurationException, match="TPUs only support a single tpu core or tpu spawn training"):
+        accelerator.setup(trainer=trainer, model=model)
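
The matched message comes from the second guard in setup(), whose isinstance check is visible in the hunk above but whose raise falls outside it. A hedged reconstruction of that branch, inferred from the visible check and the regex this test uses (the wording past the matched fragment is an assumption):

# Reconstructed sketch of the guard this test exercises; only the isinstance
# line appears in the diff above, and the message is inferred from the test.
if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
    raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")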