Commit 2ee3127

Use torch.autocast (#10053)
1 parent 43c70ec commit 2ee3127

3 files changed: +11 / -8 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -206,6 +206,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added bfloat16 support for Lightning Trainer ([#9049](https://github.com/PyTorchLightning/pytorch-lightning/pull/9049))
 * Renamed `TPUHalfPrecisionPlugin` to `TPUBf16PrecisionPlugin` ([#10026](https://github.com/PyTorchLightning/pytorch-lightning/pull/10026))
 * Default to `precision=bf16` on CPU when `precision=16` is passed ([#10033](https://github.com/PyTorchLightning/pytorch-lightning/pull/10033))
+* Add support for `torch.autocast` ([#10053](https://github.com/PyTorchLightning/pytorch-lightning/pull/10053))


 - Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965))
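
For context, the changelog entries above surface through the Trainer's `precision` flag. A minimal usage sketch, not part of this commit, assuming torch >= 1.10 and the flag spellings from the PRs referenced above:

    from pytorch_lightning import Trainer

    # On CPU, precision="bf16" (or precision=16, which now falls back to bf16)
    # selects the native AMP plugin, which wraps forward passes in a
    # torch.autocast context on torch >= 1.10.
    trainer = Trainer(precision="bf16", max_epochs=1)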

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 9 additions & 6 deletions
@@ -24,6 +24,11 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

+if _TORCH_GREATER_EQUAL_DEV_1_10:
+    from torch import autocast
+else:
+    from torch.cuda.amp import autocast
+

 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
     """Plugin for native mixed precision training with :mod:`torch.cuda.amp`.
@@ -90,12 +95,10 @@ def pre_optimizer_step(
         self.scaler.update()
         return False

-    def autocast_context_manager(self) -> torch.cuda.amp.autocast:
-        if self.use_cpu:
-            return torch.cpu.amp.autocast(dtype=self._dtype)  # Only reached in pytorch==1.10 where this is ok. skipcq
-        if self.is_bfloat16:
-            return torch.cuda.amp.autocast(dtype=self._dtype)  # Only reached in pytorch==1.10 where this is ok. skipcq
-        return torch.cuda.amp.autocast()
+    def autocast_context_manager(self) -> autocast:
+        if _TORCH_GREATER_EQUAL_DEV_1_10:
+            return autocast("cpu" if self.use_cpu else "cuda", dtype=self._dtype)
+        return autocast()

     @contextmanager
     def forward_context(self) -> Generator[None, None, None]:
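
The diff above replaces the three device-specific branches with a single call to `torch.autocast` on PyTorch 1.10+. A minimal standalone sketch of that unified behaviour, assuming torch >= 1.10 (illustrative, not taken from the repository):

    import torch

    # One entry point for both backends: CPU autocast runs eligible ops in
    # bfloat16, while CUDA autocast defaults to float16.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

    model = torch.nn.Linear(8, 4).to(device_type)
    x = torch.randn(2, 8, device=device_type)

    with torch.autocast(device_type, dtype=dtype):
        out = model(x)

    print(out.dtype)  # torch.bfloat16 on CPU, torch.float16 on CUDA

The older `torch.cuda.amp.autocast` entry point remains as the fallback import for earlier PyTorch versions, as shown in the first hunk.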

tests/plugins/test_amp_plugins.py

Lines changed: 1 addition & 2 deletions
@@ -181,12 +181,11 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
 @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Torch CPU AMP is not available.")
 def test_cpu_amp_precision_context_manager(tmpdir):
     """Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set."""
-
     plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
     assert plugin.use_cpu
     assert not hasattr(plugin, "scaler")
     context_manager = plugin.autocast_context_manager()
-    assert isinstance(context_manager, torch.cpu.amp.autocast)
+    assert isinstance(context_manager, torch.autocast)
     assert context_manager.fast_dtype == torch.bfloat16
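
For readers without the test suite at hand, the updated assertions can be reproduced directly against `torch.autocast`; the `fast_dtype` attribute name is taken from the test above, and the snippet assumes torch >= 1.10:

    import torch

    cm = torch.autocast("cpu", dtype=torch.bfloat16)
    assert isinstance(cm, torch.autocast)   # the plugin now returns this type on 1.10+
    assert cm.fast_dtype == torch.bfloat16  # dtype requested for CPU autocast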
