12 | 12 | # See the License for the specific language governing permissions and
13 | 13 | # limitations under the License.
14 | 14 |
   | 15 | +import contextlib
15 | 16 | import logging
   | 17 | +import sys
16 | 18 | from collections import OrderedDict
17 | 19 | from typing import Any, Dict, List, Optional, Tuple, Union
18 | 20 |
23 | 25 | from torch.utils.hooks import RemovableHandle
24 | 26 |
25 | 27 | import pytorch_lightning as pl
26 |    | -from pytorch_lightning.utilities import _AcceleratorType, AMPType
27 | 28 | from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
28 | 29 | from pytorch_lightning.utilities.warnings import WarningCache
29 | 30 |

@@ -261,16 +262,17 @@ def _forward_example_input(self) -> None:
261 | 262 |         input_ = model.example_input_array
262 | 263 |         input_ = model._apply_batch_transfer_handler(input_)
263 | 264 |
264 |     | -        if (
265 |     | -            trainer is not None
266 |     | -            and trainer.amp_backend == AMPType.NATIVE
267 |     | -            and trainer._device_type != _AcceleratorType.TPU
268 |     | -        ):
269 |     | -            model.forward = torch.cuda.amp.autocast()(model.forward)
270 |     | -
271 | 265 |         mode = model.training
272 | 266 |         model.eval()
273 |     | -        with torch.no_grad():
    | 267 | +
    | 268 | +        if trainer is not None:
    | 269 | +            forward_context = trainer.precision_plugin.forward_context()
    | 270 | +        elif sys.version_info >= (3, 7):
    | 271 | +            forward_context = contextlib.nullcontext()
    | 272 | +        else:
    | 273 | +            forward_context = contextlib.suppress()
    | 274 | +
    | 275 | +        with torch.no_grad(), forward_context:
274 | 276 |             # let the model hooks collect the input- and output shapes
275 | 277 |             if isinstance(input_, (list, tuple)):
276 | 278 |                 model(*input_)
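
For context on the added fallback: when no Trainer is attached, the diff picks a no-op context manager instead of a precision-plugin context, using `contextlib.nullcontext()` on Python 3.7+ and an argument-less `contextlib.suppress()` on 3.6, where `nullcontext` does not exist. Below is a minimal standalone sketch of the same pattern; `DummyModel` and the example input are hypothetical stand-ins, not part of the PR.

```python
import contextlib
import sys

import torch
import torch.nn as nn


def no_op_context():
    # contextlib.nullcontext was added in Python 3.7; an empty
    # contextlib.suppress() also enters and exits without doing anything,
    # so it works as a stand-in on Python 3.6.
    if sys.version_info >= (3, 7):
        return contextlib.nullcontext()
    return contextlib.suppress()


class DummyModel(nn.Module):
    # Hypothetical stand-in for a LightningModule with an example_input_array.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


model = DummyModel()
model.eval()

# Mirrors the patched summary code: with no trainer attached, fall back to a
# no-op context; with a trainer, `trainer.precision_plugin.forward_context()`
# would be used here so the dry-run forward honours the configured precision.
forward_context = no_op_context()

with torch.no_grad(), forward_context:
    model(torch.randn(1, 4))  # hooks (if registered) record input/output shapes
```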