 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator
+from typing import Any, Callable, Dict, Generator, Union

 import torch
 from torch.optim import LBFGS, Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
-    """Plugin for native mixed precision training with :mod:`torch.cuda.amp`."""
+    """
+    Plugin for native mixed precision training with :mod:`torch.cuda.amp`.

-    def __init__(self) -> None:
+    Args:
+        precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``"bf16"``).
+    """
+
+    def __init__(self, precision: Union[int, str] = 16) -> None:
         super().__init__()
+
         if not _NATIVE_AMP_AVAILABLE:
             raise MisconfigurationException(
                 "You have asked for native AMP but your PyTorch version does not support it."
                 " Consider upgrading with `pip install torch>=1.6`."
             )
-
+        self._fast_dtype = self._select_precision_dtype(precision)
         self.backend = AMPType.NATIVE
-        self.scaler = torch.cuda.amp.GradScaler()
+        if not self.is_bfloat16:
+            self.scaler = torch.cuda.amp.GradScaler()
+
+    def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
+        if precision == "bf16":
+            if not _TORCH_GREATER_EQUAL_1_10:
+                raise MisconfigurationException(
+                    "To use bfloat16 with native AMP, you must install torch 1.10 or newer."
+                )
+            return torch.bfloat16
+        return torch.float16
+
+    @property
+    def is_bfloat16(self) -> bool:
+        return self._fast_dtype == torch.bfloat16

     def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor:
+        if self.is_bfloat16:
+            return super().pre_backward(model, closure_loss)
         closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)

@@ -49,6 +71,9 @@ def pre_optimizer_step(
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
+        if self.is_bfloat16:
+            # skip the scaler logic, as bfloat16 does not require a GradScaler
+            return super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         if isinstance(optimizer, LBFGS):
             raise MisconfigurationException(
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
@@ -65,33 +90,39 @@ def pre_optimizer_step(
         self.scaler.update()
         return False

+    def autocast_context_manager(self) -> torch.cuda.amp.autocast:
+        if self.is_bfloat16:
+            return torch.cuda.amp.autocast(dtype=self._fast_dtype)
+        return torch.cuda.amp.autocast()
+
     @contextmanager
     def train_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
             yield

     @contextmanager
     def val_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
             yield

     @contextmanager
     def test_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
             yield

     @contextmanager
     def predict_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
             yield

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if "native_amp_scaling_state" in checkpoint:
+        if "native_amp_scaling_state" in checkpoint and not self.is_bfloat16:
             self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])

     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
+        if not self.is_bfloat16:
+            checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
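
For reference, the behaviour this change encodes can be sketched as a plain PyTorch training step: autocast runs in the requested dtype, and a `GradScaler` is only involved for `float16`, since `bfloat16` keeps the float32 exponent range and does not need loss scaling. The sketch below is illustrative only and not part of the diff; `model`, `optimizer`, `loss_fn`, and the input tensors are placeholder names, and it assumes torch >= 1.10 for the `dtype` argument of `torch.cuda.amp.autocast`.

```python
import torch


def training_step(model, optimizer, loss_fn, x, y, precision="bf16"):
    use_bf16 = precision == "bf16"
    autocast_dtype = torch.bfloat16 if use_bf16 else torch.float16
    # A GradScaler is only needed for float16; bfloat16 skips loss scaling,
    # mirroring the plugin's `is_bfloat16` branches above.
    scaler = None if use_bf16 else torch.cuda.amp.GradScaler()

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        loss = loss_fn(model(x), y)

    optimizer.zero_grad()
    if scaler is None:
        loss.backward()
        optimizer.step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
```

With the plugin itself, the same decision is made once at construction time: `NativeMixedPrecisionPlugin(precision="bf16")` never creates a scaler and skips the scaling calls in `pre_backward` and `pre_optimizer_step`.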