| 
12 | 12 | # See the License for the specific language governing permissions and  | 
13 | 13 | # limitations under the License.  | 
14 | 14 | 
 
  | 
 | 15 | +import contextlib  | 
15 | 16 | import logging  | 
 | 17 | +import sys  | 
16 | 18 | from collections import OrderedDict  | 
17 | 19 | from typing import Any, Dict, List, Optional, Tuple, Union  | 
18 | 20 | 
 
  | 
 | 
23 | 25 | from torch.utils.hooks import RemovableHandle  | 
24 | 26 | 
 
  | 
25 | 27 | import pytorch_lightning as pl  | 
26 |  | -from pytorch_lightning.utilities import AMPType, DeviceType, ModelSummaryMode, rank_zero_deprecation  | 
 | 28 | +from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_deprecation  | 
27 | 29 | from pytorch_lightning.utilities.exceptions import MisconfigurationException  | 
28 | 30 | from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8  | 
29 | 31 | from pytorch_lightning.utilities.warnings import WarningCache  | 
@@ -282,12 +284,17 @@ def _forward_example_input(self) -> None:  | 
282 | 284 |         input_ = model.example_input_array  | 
283 | 285 |         input_ = model._apply_batch_transfer_handler(input_)  | 
284 | 286 | 
 
  | 
285 |  | -        if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU:  | 
286 |  | -            model.forward = torch.cuda.amp.autocast()(model.forward)  | 
287 |  | - | 
288 | 287 |         mode = model.training  | 
289 | 288 |         model.eval()  | 
290 |  | -        with torch.no_grad():  | 
 | 289 | + | 
 | 290 | +        if trainer is not None:  | 
 | 291 | +            forward_context = trainer.precision_plugin.forward_context()  | 
 | 292 | +        elif sys.version_info >= (3, 7):  | 
 | 293 | +            forward_context = contextlib.nullcontext()  | 
 | 294 | +        else:  | 
 | 295 | +            forward_context = contextlib.suppress()  | 
 | 296 | + | 
 | 297 | +        with torch.no_grad(), forward_context:  | 
291 | 298 |             # let the model hooks collect the input- and output shapes  | 
292 | 299 |             if isinstance(input_, (list, tuple)):  | 
293 | 300 |                 model(*input_)  | 
 | 
0 commit comments