@@ -16,16 +16,14 @@
 import threading
 from collections.abc import Iterable, Mapping
 from itertools import chain
-from typing import Optional
+from typing import Any, Optional

 import torch
-from torch.nn import Module
 from torch import Tensor
 from torch.cuda._utils import _get_device_index
-from torch.nn import DataParallel
-from torch.nn.parallel._functions import Gather
-from typing import Any
+from torch.nn import DataParallel, Module
 from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel._functions import Gather

 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
@@ -170,7 +168,8 @@ def forward(self, *inputs, **kwargs):
             warn_if_output_is_none(output, "validation_step")
         return output

-# In manual_optimization, we need to call reducer prepare_for_backward.
+
+# In manual_optimization, we need to call reducer prepare_for_backward.
 # TODO: Keep track of Pytorch DDP and update if there is a change
 # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/torch/nn/parallel/distributed.py#L692
 def prepare_for_backward(model: DistributedDataParallel, output: Any):
@@ -186,7 +185,7 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any):
         else:
             model.reducer.prepare_for_backward([])
     else:
-    model.require_forward_param_sync = False
+        model.require_forward_param_sync = False

 #
 # class LightningDistributedDataParallel(DistributedDataParallel):
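
Note on the hunk above, beyond what the diff shows: in manual optimization the loss used for the backward pass is not necessarily the object that went through DistributedDataParallel.forward, so the DDP reducer has to be primed explicitly with the current forward pass's tensors before backward() runs. The snippet below is only a minimal sketch of that call ordering, not part of this change; it assumes a process group is already initialized, that `ddp_model` wraps a LightningModule whose forward returns a scalar loss tensor, and the names `training_loop_step`, `batch`, and `optimizer` are hypothetical.

from torch.nn.parallel import DistributedDataParallel


def training_loop_step(ddp_model: DistributedDataParallel, batch, optimizer):
    # Forward through the DDP-wrapped module (process group assumed to be
    # initialized elsewhere; forward is assumed to return the loss tensor).
    loss = ddp_model(batch)

    # Prime the reducer with the tensors produced by this forward pass before
    # the manual backward call, so unused parameters are marked ready and the
    # gradient buckets are set up correctly (see prepare_for_backward above).
    prepare_for_backward(ddp_model, loss)

    optimizer.zero_grad()
    loss.backward()  # DDP's autograd hooks all-reduce gradients across ranks
    optimizer.step()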