Commit 1feec8c

Author: Sean Naren
Add bfloat16 support to Lightning Trainer (#9049)
1 parent 208218b commit 1feec8c
File tree

11 files changed: +137, -72 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -69,6 +69,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `DataLoaderIterDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020))
 
 
+- Added bfloat16 support for Lightning Trainer ([#9049](https://github.com/PyTorchLightning/pytorch-lightning/pull/9049))
+
+
 - Added `DataFetcher` within `Fit / Evaluation` Loop ([#9047](https://github.com/PyTorchLightning/pytorch-lightning/pull/9047))
 
 

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 43 additions & 12 deletions
@@ -12,32 +12,54 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator
+from typing import Any, Callable, Dict, Generator, Union
 
 import torch
 from torch.optim import LBFGS, Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
-    """Plugin for native mixed precision training with :mod:`torch.cuda.amp`."""
+    """
+    Plugin for native mixed precision training with :mod:`torch.cuda.amp`.
 
-    def __init__(self) -> None:
+    Args:
+        precision: Whether to use torch.float16 (16) or torch.bfloat16 (bf16).
+    """
+
+    def __init__(self, precision: Union[int, str] = 16) -> None:
         super().__init__()
+
         if not _NATIVE_AMP_AVAILABLE:
             raise MisconfigurationException(
                 "You have asked for native AMP but your PyTorch version does not support it."
                 " Consider upgrading with `pip install torch>=1.6`."
             )
-
+        self._fast_dtype = self._select_precision_dtype(precision)
         self.backend = AMPType.NATIVE
-        self.scaler = torch.cuda.amp.GradScaler()
+        if not self.is_bfloat16:
+            self.scaler = torch.cuda.amp.GradScaler()
+
+    def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
+        if precision == "bf16":
+            if not _TORCH_GREATER_EQUAL_1_10:
+                raise MisconfigurationException(
+                    "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
+                )
+            return torch.bfloat16
+        return torch.float16
+
+    @property
+    def is_bfloat16(self) -> bool:
+        return self._fast_dtype == torch.bfloat16
 
     def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor:
+        if self.is_bfloat16:
+            return super().pre_backward(model, closure_loss)
         closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)
 
@@ -49,6 +71,9 @@ def pre_optimizer_step(
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
+        if self.is_bfloat16:
+            # skip scaler logic, as bfloat16 does not require scaler
+            return super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         if isinstance(optimizer, LBFGS):
             raise MisconfigurationException(
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
@@ -65,33 +90,39 @@ def pre_optimizer_step(
         self.scaler.update()
         return False
 
+    def autocast_context_manager(self) -> torch.cuda.amp.autocast:
+        if self.is_bfloat16:
+            return torch.cuda.amp.autocast(fast_dtype=self._fast_dtype)
+        return torch.cuda.amp.autocast()
+
     @contextmanager
     def train_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
            yield
 
     @contextmanager
     def val_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
            yield
 
     @contextmanager
     def test_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
            yield
 
     @contextmanager
     def predict_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
            yield
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if "native_amp_scaling_state" in checkpoint:
+        if "native_amp_scaling_state" in checkpoint and not self.is_bfloat16:
            self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])
 
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
+        if not self.is_bfloat16:
+            checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()

pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 2 additions & 2 deletions
@@ -24,8 +24,8 @@
 class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
     """Mixed Precision for Sharded Training"""
 
-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(self, precision: Union[int, str] = 16) -> None:
+        super().__init__(precision)
         self.scaler = ShardedGradScaler()
 
     def clip_grad_by_norm(

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 5 additions & 5 deletions
@@ -560,7 +560,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
             return PrecisionPlugin()
         if self.precision == 64:
             return DoublePrecisionPlugin()
-        if self.precision == 16:
+        if self.precision in (16, "bf16"):
             if self.use_tpu:
                 return TPUHalfPrecisionPlugin()
 
@@ -581,12 +581,12 @@ def select_precision_plugin(self) -> PrecisionPlugin:
                     else:
                         raise MisconfigurationException(msg)
                 else:
-                    log.info("Using native 16bit precision.")
+                    log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
                     if self._is_sharded_training_type:
-                        return ShardedNativeMixedPrecisionPlugin()
+                        return ShardedNativeMixedPrecisionPlugin(self.precision)
                     if self._is_fully_sharded_training_type:
-                        return FullyShardedNativeMixedPrecisionPlugin()
-                    return NativeMixedPrecisionPlugin()
+                        return FullyShardedNativeMixedPrecisionPlugin(self.precision)
+                    return NativeMixedPrecisionPlugin(self.precision)
 
             if self.amp_type == AMPType.APEX:
                 if not _APEX_AVAILABLE:
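The practical effect of the connector change is that a "bf16" flag now reaches the native AMP plugins instead of being rejected. A rough check of the resulting plugin, in line with the tests added later in this commit (assumes a CUDA machine with PyTorch >= 1.10; not part of the diff):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

trainer = Trainer(precision="bf16", gpus=1, max_epochs=1)

# the connector hands the precision flag through to the native plugin
assert isinstance(trainer.precision_plugin, NativeMixedPrecisionPlugin)
assert trainer.precision_plugin.is_bfloat16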

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 3 deletions
@@ -138,7 +138,7 @@ def __init__(
         log_every_n_steps: int = 50,
         accelerator: Optional[Union[str, Accelerator]] = None,
         sync_batchnorm: bool = False,
-        precision: int = 32,
+        precision: Union[int, str] = 32,
         weights_summary: Optional[str] = "top",
         weights_save_path: Optional[str] = None,
         num_sanity_val_steps: int = 2,
@@ -260,8 +260,8 @@ def __init__(
 
             plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
 
-            precision: Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or
-                TPUs.
+            precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
+                Can be used on CPU, GPU or TPUs.
 
             max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
                 If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.
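A quick sketch of the values the widened signature accepts; 16 and "bf16" assume a GPU run with native AMP, per the connector logic above, and "bf16" additionally requires PyTorch >= 1.10.

from pytorch_lightning import Trainer

Trainer(precision=32)               # default, full precision
Trainer(precision=16, gpus=1)       # float16 mixed precision
Trainer(precision="bf16", gpus=1)   # bfloat16 mixed precision (new in this commit)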

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
+    _TORCH_GREATER_EQUAL_1_10,
     _TORCH_QUANTIZE_AVAILABLE,
     _TORCH_SHARDED_TENSOR_AVAILABLE,
     _TORCHTEXT_AVAILABLE,

pytorch_lightning/utilities/argparse.py

Lines changed: 17 additions & 0 deletions
@@ -253,6 +253,10 @@ def add_argparse_args(
         if arg == "track_grad_norm":
             use_type = float
 
+        # hack for precision
+        if arg == "precision":
+            use_type = _precision_allowed_type
+
         parser.add_argument(
             f"--{arg}", dest=arg, default=arg_default, type=use_type, help=args_help.get(arg), **arg_kwargs
         )
@@ -302,3 +306,16 @@ def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
     if "." in str(x):
         return float(x)
     return int(x)
+
+
+def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
+    """
+    >>> _precision_allowed_type("32")
+    32
+    >>> _precision_allowed_type("bf16")
+    'bf16'
+    """
+    try:
+        return int(x)
+    except ValueError:
+        return x
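The argparse hook matters for CLI-driven scripts: without it, --precision bf16 would fail the int conversion that the auto-generated argument previously used. A small usage sketch of the behaviour the new type function enables (the script wiring is illustrative, not part of the commit):

from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)

# numeric strings are still converted to int, "bf16" is kept as a string
print(parser.parse_args(["--precision", "16"]).precision)    # 16
print(parser.parse_args(["--precision", "bf16"]).precision)  # 'bf16'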

pytorch_lightning/utilities/imports.py

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ def _compare_version(package: str, op, version) -> bool:
 _TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
 _TORCH_GREATER_EQUAL_1_8_1 = _compare_version("torch", operator.ge, "1.8.1")
 _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
+_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
+
 
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BOLTS_AVAILABLE = _module_available("pl_bolts")
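For reference, the flag added above amounts to a guarded version comparison. A rough stand-alone equivalent using packaging (the in-tree _compare_version helper additionally tolerates missing packages and nightly version strings):

import torch
from packaging.version import Version

# base_version drops any ".devYYYYMMDD" or "+cu113" suffix before comparing
_TORCH_GREATER_EQUAL_1_10 = Version(Version(torch.__version__).base_version) >= Version("1.10.0")
print(_TORCH_GREATER_EQUAL_1_10)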

tests/models/test_amp.py

Lines changed: 27 additions & 50 deletions
@@ -22,6 +22,7 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import SLURMEnvironment
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -31,7 +32,8 @@ class AMPTestModel(BoringModel):
     def _step(self, batch, batch_idx):
         assert torch.is_autocast_enabled()
         output = self(batch)
-        assert output.dtype == torch.float16
+        bfloat16 = self.trainer.precision_plugin.is_bfloat16
+        assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16
         loss = self.loss(batch, output)
         return loss
 
@@ -50,17 +52,35 @@ def test_step(self, batch, batch_idx):
     def predict(self, batch, batch_idx, dataloader_idx=None):
         assert torch.is_autocast_enabled()
         output = self(batch)
-        assert output.dtype == torch.float16
+        bfloat16 = self.trainer.precision_plugin.is_bfloat16
+        assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16
         return output
 
 
-@pytest.mark.skip(reason="dp + amp not supported currently")  # TODO
-@RunIf(min_gpus=1)
-def test_amp_single_gpu_dp(tmpdir):
-    """Make sure DP/DDP + AMP work."""
+@RunIf(min_gpus=2)
+@pytest.mark.parametrize(
+    "accelerator",
+    [
+        pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported currently")),  # TODO
+        "ddp_spawn",
+    ],
+)
+@pytest.mark.parametrize(
+    "precision",
+    [
+        16,
+        pytest.param(
+            "bf16",
+            marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_10, reason="torch.bfloat16 not available"),
+        ),
+    ],
+)
+@pytest.mark.parametrize("gpus", [1, 2])
+def test_amp_gpus(tmpdir, accelerator, precision, gpus):
+    """Make sure combinations of AMP and training types work if supported."""
     tutils.reset_seed()
 
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=1, accelerator="dp", precision=16)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=gpus, accelerator=accelerator, precision=precision)
 
     model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
@@ -71,49 +91,6 @@ def test_amp_single_gpu_dp(tmpdir):
     assert trainer.state.finished, f"Training failed with {trainer.state}"
 
 
-@RunIf(min_gpus=1)
-def test_amp_single_gpu_ddp_spawn(tmpdir):
-    """Make sure DP/DDP + AMP work."""
-    tutils.reset_seed()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=1, accelerator="ddp_spawn", precision=16)
-
-    model = AMPTestModel()
-    # tutils.run_model_test(trainer_options, model)
-    trainer.fit(model)
-    trainer.test(model)
-    trainer.predict(model, DataLoader(RandomDataset(32, 64)))
-    assert trainer.state.finished, f"Training failed with {trainer.state}"
-
-
-@pytest.mark.skip(reason="dp + amp not supported currently")  # TODO
-@RunIf(min_gpus=1)
-def test_amp_multi_gpu_dp(tmpdir):
-    """Make sure DP/DDP + AMP work."""
-    tutils.reset_seed()
-
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=2, accelerator="dp", precision=16)
-
-    model = AMPTestModel()
-    # tutils.run_model_test(trainer_options, model)
-    trainer.fit(model)
-
-    assert trainer.state.finished, f"Training failed with {trainer.state}"
-
-
-@RunIf(min_gpus=2)
-def test_amp_multi_gpu_ddp_spawn(tmpdir):
-    """Make sure DP/DDP + AMP work."""
-    tutils.reset_seed()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=2, accelerator="ddp_spawn", precision=16)
-
-    model = AMPTestModel()
-    # tutils.run_model_test(trainer_options, model)
-    trainer.fit(model)
-    trainer.test(model)
-    trainer.predict(model, DataLoader(RandomDataset(32, 64)))
-    assert trainer.state.finished, f"Training failed with {trainer.state}"
-
-
 @RunIf(min_gpus=2)
 @mock.patch.dict(
     os.environ,
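One reading note on the updated assertions in _step and predict: the conditional expression wraps the whole comparison, so when bfloat16 is true the statement asserts the truthiness of torch.bfloat16 (always true) rather than the output dtype. A sketch of how the line parses and of an explicit variant, on a dummy tensor (not part of the commit):

import torch

output = torch.zeros(2, dtype=torch.bfloat16)
bfloat16 = True

# what the diff asserts, with precedence made explicit:
assert (output.dtype == torch.float16) if not bfloat16 else torch.bfloat16

# a variant that also checks the dtype in the bf16 branch:
expected = torch.bfloat16 if bfloat16 else torch.float16
assert output.dtype == expected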

tests/plugins/test_amp_plugins.py

Lines changed: 16 additions & 0 deletions
@@ -21,6 +21,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -69,6 +70,8 @@ def test_amp_apex_ddp(
         plugins=[plugin_cls()] if custom_plugin else None,
     )
     assert isinstance(trainer.precision_plugin, plugin_cls)
+    if amp == "native":
+        assert not trainer.precision_plugin.is_bfloat16
 
 
 class GradientUnscaleBoringModel(BoringModel):
@@ -174,3 +177,16 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
     assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
     model = BoringModel()
     trainer.fit(model)
+
+
+@RunIf(min_gpus=1, amp_native=True, max_torch="1.9")
+def test_amp_precision_16_bfloat_throws_error(tmpdir):
+    with pytest.raises(
+        MisconfigurationException,
+        match="To use bfloat16 with native amp you must install torch greater or equal to 1.10",
+    ):
+        Trainer(
+            default_root_dir=tmpdir,
+            precision="bf16",
+            gpus=1,
+        )
