
Commit 93ce2d7

rohitgr7, justusschock, and awaelchli authored
Avoid torch amp cuda warning with bf16 on cpu (#11161)
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent b64dea9 commit 93ce2d7
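
For context, the warning this commit removes comes from entering `torch.cuda.amp.autocast()` on a machine without CUDA: the context manager warns and disables itself. A minimal, hypothetical repro (a sketch, not part of the commit; it assumes a CPU-only machine and torch >= 1.10, and the exact warning text varies across versions):

```python
# Hypothetical repro of the warning, assuming a CPU-only machine.
import torch

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

# Old code path: the model summary unconditionally wrapped `forward` in the
# CUDA autocast context, which warns and disables itself when CUDA is absent.
with torch.cuda.amp.autocast():
    model(x)  # emits a UserWarning about CUDA not being available

# bf16 on CPU is served by the device-agnostic autocast API instead.
with torch.autocast("cpu", dtype=torch.bfloat16):
    model(x)  # no warning; the matmul runs in bfloat16
```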

2 files changed: +21 -9 lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
@@ -333,6 +333,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))

+
+- Fixed an incorrect warning being produced by the model summary when using `bf16` precision on CPU ([#11161](https://github.com/PyTorchLightning/pytorch-lightning/pull/11161))
+
+
+-
+
+
+-
+
+
 ## [1.5.6] - 2021-12-15

 ### Fixed

pytorch_lightning/utilities/model_summary.py

Lines changed: 11 additions & 9 deletions
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import logging
+import sys
 from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union

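
The two new imports support a version-dependent no-op context used in the hunk further below: `contextlib.nullcontext` only exists on Python 3.7+, while a bare `contextlib.suppress()` (called with no exception types) is an equivalent do-nothing context manager on 3.6. A standalone illustration of the fallback (a sketch, not from the commit):

```python
# `suppress()` with no arguments suppresses nothing, so it behaves like
# `nullcontext()` on interpreters older than Python 3.7.
import contextlib
import sys

if sys.version_info >= (3, 7):
    noop = contextlib.nullcontext()
else:
    noop = contextlib.suppress()

with noop:
    print("body runs unchanged; the context manager does nothing")
```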

@@ -23,7 +25,6 @@
 from torch.utils.hooks import RemovableHandle

 import pytorch_lightning as pl
-from pytorch_lightning.utilities import _AcceleratorType, AMPType
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.warnings import WarningCache


@@ -261,16 +262,17 @@ def _forward_example_input(self) -> None:
         input_ = model.example_input_array
         input_ = model._apply_batch_transfer_handler(input_)

-        if (
-            trainer is not None
-            and trainer.amp_backend == AMPType.NATIVE
-            and trainer._device_type != _AcceleratorType.TPU
-        ):
-            model.forward = torch.cuda.amp.autocast()(model.forward)
-
         mode = model.training
         model.eval()
-        with torch.no_grad():
+
+        if trainer is not None:
+            forward_context = trainer.precision_plugin.forward_context()
+        elif sys.version_info >= (3, 7):
+            forward_context = contextlib.nullcontext()
+        else:
+            forward_context = contextlib.suppress()
+
+        with torch.no_grad(), forward_context:
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
                 model(*input_)
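
Taken together, the commit replaces the hard-coded CUDA autocast wrapping with whatever context the trainer's precision plugin provides, falling back to a no-op when no trainer is attached. A hedged, self-contained sketch of the resulting control flow (`summary_forward` and its arguments are simplified stand-ins for the real `_forward_example_input` method; only `trainer.precision_plugin.forward_context()` mirrors the actual hook the commit switches to):

```python
import contextlib
import sys

import torch


def summary_forward(model, example_input, trainer=None):
    # Prefer the trainer's precision plugin: it yields the autocast context
    # matching the configured precision and device (e.g. bf16 on CPU)
    # instead of hard-coding torch.cuda.amp.autocast.
    if trainer is not None:
        forward_context = trainer.precision_plugin.forward_context()
    elif sys.version_info >= (3, 7):
        forward_context = contextlib.nullcontext()
    else:
        forward_context = contextlib.suppress()

    mode = model.training
    model.eval()
    with torch.no_grad(), forward_context:
        # run a shape-collecting forward pass, gradients disabled
        out = model(example_input)
    model.train(mode)
    return out


# Example: no trainer attached, so the forward pass runs in a no-op context.
print(summary_forward(torch.nn.Linear(4, 4), torch.randn(2, 4)).shape)
```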
