diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8a70c9991abac..4a6efdae4b1f5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -322,6 +322,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))
 
+
+- Fixed an incorrect warning being produced by the model summary when using `bf16` precision on CPU ([#11161](https://github.com/PyTorchLightning/pytorch-lightning/pull/11161))
+
+
+-
+
+
+-
+
+
 
 ## [1.5.6] - 2021-12-15
 
 ### Fixed
diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py
index 37ff258436568..ccc81fc46d15f 100644
--- a/pytorch_lightning/utilities/model_summary.py
+++ b/pytorch_lightning/utilities/model_summary.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import logging
+import sys
 from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -23,7 +25,6 @@
 from torch.utils.hooks import RemovableHandle
 
 import pytorch_lightning as pl
-from pytorch_lightning.utilities import _AcceleratorType, AMPType
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -261,16 +262,17 @@ def _forward_example_input(self) -> None:
         input_ = model.example_input_array
         input_ = model._apply_batch_transfer_handler(input_)
 
-        if (
-            trainer is not None
-            and trainer.amp_backend == AMPType.NATIVE
-            and trainer._device_type != _AcceleratorType.TPU
-        ):
-            model.forward = torch.cuda.amp.autocast()(model.forward)
-
         mode = model.training
         model.eval()
-        with torch.no_grad():
+
+        if trainer is not None:
+            forward_context = trainer.precision_plugin.forward_context()
+        elif sys.version_info >= (3, 7):
+            forward_context = contextlib.nullcontext()
+        else:
+            forward_context = contextlib.suppress()
+
+        with torch.no_grad(), forward_context:
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
                 model(*input_)
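
For context, the patch stops unconditionally wrapping the summary's dry-run forward in `torch.cuda.amp.autocast()` (which warned on CPU) and instead enters whatever context the trainer's precision plugin provides, falling back to a no-op context manager when no trainer is attached. Below is a minimal, self-contained sketch of that fallback pattern; the `model`, `use_autocast`, and `_null_forward_context` names are illustrative stand-ins, not part of the patch:

```python
import contextlib
import sys

import torch


def _null_forward_context():
    # `contextlib.nullcontext` only exists on Python >= 3.7; on older
    # interpreters, `contextlib.suppress()` with no arguments is an
    # equivalent do-nothing context manager.
    if sys.version_info >= (3, 7):
        return contextlib.nullcontext()
    return contextlib.suppress()


# Illustrative stand-ins: a toy model and a flag that, in the real code,
# would be decided by the trainer's precision plugin.
model = torch.nn.Linear(4, 2)
use_autocast = False

forward_context = torch.cuda.amp.autocast() if use_autocast else _null_forward_context()

# Mirror the patched summary code: run the example forward pass with
# gradients disabled, inside whichever precision context applies.
with torch.no_grad(), forward_context:
    output = model(torch.randn(1, 4))

print(output.shape)  # torch.Size([1, 2])
```

Deferring to `precision_plugin.forward_context()` means the summary's dry run sees the same autocast behavior as actual training, which is why a `bf16`-on-CPU setup no longer triggers the spurious CUDA-autocast warning noted in the changelog entry.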