3 files changed (+11, -8 lines)
CHANGELOG.md
@@ -206,6 +206,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Added bfloat16 support for Lightning Trainer ([#9049](https://github.com/PyTorchLightning/pytorch-lightning/pull/9049))
     * Renamed `TPUHalfPrecisionPlugin` to `TPUBf16PrecisionPlugin` ([#10026](https://github.com/PyTorchLightning/pytorch-lightning/pull/10026))
     * Default to `precision=bf16` on CPU when `precision=16` is passed ([#10033](https://github.com/PyTorchLightning/pytorch-lightning/pull/10033))
+    * Add support for `torch.autocast` ([#10053](https://github.com/PyTorchLightning/pytorch-lightning/pull/10053))


 - Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965))
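The bf16-related entries above map onto `Trainer` precision flags. A rough usage sketch (hypothetical user code, not part of this diff; bf16 autocast assumes PyTorch 1.10+ is installed):

```python
from pytorch_lightning import Trainer

# Request bfloat16 mixed precision explicitly (works on CPU, needs PyTorch 1.10+).
trainer = Trainer(precision="bf16")

# Per #10033 above: on a CPU-only machine, asking for 16-bit precision is
# treated as bf16, because fp16 autocast is CUDA-only.
cpu_trainer = Trainer(precision=16)
```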
pytorch_lightning/plugins/precision/native_amp.py
@@ -24,6 +24,11 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

+if _TORCH_GREATER_EQUAL_DEV_1_10:
+    from torch import autocast
+else:
+    from torch.cuda.amp import autocast
+

 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
     """Plugin for native mixed precision training with :mod:`torch.cuda.amp`.
@@ -90,12 +95,10 @@ def pre_optimizer_step(
             self.scaler.update()
         return False

-    def autocast_context_manager(self) -> torch.cuda.amp.autocast:
-        if self.use_cpu:
-            return torch.cpu.amp.autocast(dtype=self._dtype)  # Only reached in pytorch==1.10 where this is ok. skipcq
-        if self.is_bfloat16:
-            return torch.cuda.amp.autocast(dtype=self._dtype)  # Only reached in pytorch==1.10 where this is ok. skipcq
-        return torch.cuda.amp.autocast()
+    def autocast_context_manager(self) -> autocast:
+        if _TORCH_GREATER_EQUAL_DEV_1_10:
+            return autocast("cpu" if self.use_cpu else "cuda", dtype=self._dtype)
+        return autocast()

     @contextmanager
     def forward_context(self) -> Generator[None, None, None]:
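The substance of the change is in `autocast_context_manager`: on PyTorch 1.10+ the plugin returns a single `torch.autocast` context parameterized by device type and dtype, instead of picking between `torch.cpu.amp.autocast` and `torch.cuda.amp.autocast`. A standalone sketch of the same pattern, using a simplified version check and a hypothetical `make_autocast` helper in place of the plugin method:

```python
import torch

# Simplified stand-in for Lightning's _TORCH_GREATER_EQUAL_DEV_1_10 flag.
_TORCH_1_10 = tuple(int(v) for v in torch.__version__.split(".")[:2]) >= (1, 10)


def make_autocast(use_cpu: bool, dtype: torch.dtype):
    """Hypothetical helper mirroring the updated plugin method."""
    if _TORCH_1_10:
        # torch.autocast (new in 1.10) unifies CPU and CUDA autocast behind one API:
        # the device type is a string and the compute dtype is passed explicitly.
        return torch.autocast("cpu" if use_cpu else "cuda", dtype=dtype)
    # Before 1.10 only the CUDA context manager exists, with no dtype argument here.
    return torch.cuda.amp.autocast()


with make_autocast(use_cpu=True, dtype=torch.bfloat16):
    x = torch.randn(8, 8)
    y = x @ x  # eligible ops run in bfloat16 inside the context
```

The device-string form is also what lets the new code drop the separate `use_cpu`/`is_bfloat16` branches visible in the removed lines above.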
tests/plugins/test_amp_plugins.py
@@ -181,12 +181,11 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
 @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Torch CPU AMP is not available.")
 def test_cpu_amp_precision_context_manager(tmpdir):
     """Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set."""
-
     plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
     assert plugin.use_cpu
     assert not hasattr(plugin, "scaler")
     context_manager = plugin.autocast_context_manager()
-    assert isinstance(context_manager, torch.cpu.amp.autocast)
+    assert isinstance(context_manager, torch.autocast)
     assert context_manager.fast_dtype == torch.bfloat16
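The updated assertion works because, on 1.10+, `torch.autocast` is a class whose instances record the compute dtype as `fast_dtype`. A small sketch of the same checks with plain `torch`, assuming a PyTorch 1.10+ install and no Lightning objects:

```python
import torch

# Build the kind of context manager the plugin now returns for CPU + bf16.
ctx = torch.autocast("cpu", dtype=torch.bfloat16)

# torch.autocast is a class, so isinstance() works, and the instance exposes the
# dtype eligible ops will be cast to (this is what the test asserts via fast_dtype).
assert isinstance(ctx, torch.autocast)
assert ctx.fast_dtype == torch.bfloat16

with ctx:
    out = torch.randn(4, 4) @ torch.randn(4, 4)
    print(out.dtype)  # expected: torch.bfloat16
```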