Skip to content

Commit c839b0d

Browse files
committed
resolve flake8
1 parent 8adc02c commit c839b0d

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

pytorch_lightning/overrides/data_parallel.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,14 @@
1616
import threading
1717
from collections.abc import Iterable, Mapping
1818
from itertools import chain
19-
from typing import Optional
19+
from typing import Any, Optional
2020

2121
import torch
22-
from torch.nn import Module
2322
from torch import Tensor
2423
from torch.cuda._utils import _get_device_index
25-
from torch.nn import DataParallel
26-
from torch.nn.parallel._functions import Gather
27-
from typing import Any
24+
from torch.nn import DataParallel, Module
2825
from torch.nn.parallel import DistributedDataParallel
26+
from torch.nn.parallel._functions import Gather
2927

3028
from pytorch_lightning.core.lightning import LightningModule
3129
from pytorch_lightning.core.step_result import Result
@@ -170,7 +168,8 @@ def forward(self, *inputs, **kwargs):
170168
warn_if_output_is_none(output, "validation_step")
171169
return output
172170

173-
# In manual_optimization, we need to call reducer prepare_for_backward.
171+
172+
# In manual_optimization, we need to call reducer prepare_for_backward.
174173
# TODO: Keep track of Pytorch DDP and update if there is a change
175174
# https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/torch/nn/parallel/distributed.py#L692
176175
def prepare_for_backward(model: DistributedDataParallel, output: Any):
@@ -186,7 +185,7 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any):
186185
else:
187186
model.reducer.prepare_for_backward([])
188187
else:
189-
model.require_forward_param_sync = False
188+
model.require_forward_param_sync = False
190189

191190
#
192191
# class LightningDistributedDataParallel(DistributedDataParallel):

pytorch_lightning/plugins/sharded_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import List, Optional, Union, Any
14+
from typing import Any, List, Optional, Union
1515

1616
from pytorch_lightning.core.lightning import LightningModule
1717
from pytorch_lightning.core.optimizer import is_lightning_optimizer
@@ -96,4 +96,4 @@ def required_plugins(self, amp_backend: AMPType, trainer) -> list:
9696
return []
9797

9898
def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any):
99-
pass
99+
pass

0 commit comments

Comments
 (0)