Skip to content

Commit c3fc031

Browse files
lillekemikerMartin KristiansenawaelchliBorda
authored
Updating docs and error message: half precision not available on CPU (#7384)
* Updating docs and error message to specify that half precission not available on CPU * update messages Co-authored-by: Martin Kristiansen <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: jirka <[email protected]>
1 parent dea7a02 commit c3fc031

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

docs/source/common/trainer.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1156,7 +1156,7 @@ precision
11561156
|
11571157
11581158
Double precision (64), full precision (32) or half precision (16).
1159-
Can be used on CPU, GPU or TPUs.
1159+
Can all be used on GPU or TPUs. Only double (64) and full precision (32) available on CPU.
11601160

11611161
If used on TPU will use torch.bfloat16 but tensor printing
11621162
will still show torch.float32.

pytorch_lightning/accelerators/cpu.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,15 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
2727
If AMP is used with CPU, or if the selected device is not CPU.
2828
"""
2929
if isinstance(self.precision_plugin, MixedPrecisionPlugin):
30-
raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
30+
raise MisconfigurationException(
31+
" Mixed precision is currenty only supported with the AMP backend"
32+
" and AMP + CPU is not supported. Please use a GPU option or"
33+
" change precision setting."
34+
)
3135

3236
if "cpu" not in str(self.root_device):
33-
raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead")
37+
raise MisconfigurationException(
38+
f"Device should be CPU, got {self.root_device} instead."
39+
)
3440

3541
return super().setup(trainer, model)

tests/accelerators/test_cpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_unsupported_precision_plugins():
1818
accelerator = CPUAccelerator(
1919
training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
2020
)
21-
with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
21+
with pytest.raises(MisconfigurationException, match=r"AMP \+ CPU is not supported"):
2222
accelerator.setup(trainer=trainer, model=model)
2323

2424

0 commit comments

Comments
 (0)