
Commit a1e038e (parent: 7903625)

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

File tree: 4 files changed (+24, -29 lines)

pytorch_lightning/strategies/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -19,8 +19,8 @@
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy  # noqa: F401
 from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy  # noqa: F401
 from pytorch_lightning.strategies.dp import DataParallelStrategy  # noqa: F401
-from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy  # noqa: F401
 from pytorch_lightning.strategies.fully_sharded import DDPFullyShardedStrategy  # noqa: F401
+from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy  # noqa: F401
 from pytorch_lightning.strategies.horovod import HorovodStrategy  # noqa: F401
 from pytorch_lightning.strategies.ipu import IPUStrategy  # noqa: F401
 from pytorch_lightning.strategies.parallel import ParallelStrategy  # noqa: F401
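
The single change here is the swap of the two fully-sharded imports: import sorters order module paths lexicographically, and `fully_sharded` sorts before `fully_sharded_native` because it is a strict prefix of it. A small illustration of that ordering (not part of the commit; the exact hook that enforced it, presumably isort, is configured outside this diff):

modules = ["pytorch_lightning.strategies.fully_sharded_native", "pytorch_lightning.strategies.fully_sharded"]
assert sorted(modules)[0].endswith(".fully_sharded")  # the shorter prefix sorts first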

pytorch_lightning/strategies/fully_sharded_native.py

Lines changed: 11 additions & 23 deletions

@@ -14,39 +14,31 @@
 import contextlib
 import logging
 import os
-from typing import Union, Any, Generator, Dict, List, Optional
+from typing import Any, Dict, Generator, List, Optional, Union
 
-import pytorch_lightning as pl
 import torch
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, FullyShardedDataParallel
+from torch.distributed.fsdp.wrap import enable_wrap
+
+import pytorch_lightning as pl
 from pytorch_lightning.overrides.distributed import prepare_for_backward
-from pytorch_lightning.plugins.environments.cluster_environment import (
-    ClusterEnvironment,
-)
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.distributed import (
-    init_dist_connection,
-    sync_ddp_if_available,
-    ReduceOp,
-    group as _group,
     _get_process_group_backend_from_env,
     distributed_available,
     get_default_process_group_backend_for_device,
 )
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.seed import reset_seed
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    BackwardPrefetch,
-    CPUOffload,
-    FullyShardedDataParallel,
-)
-from torch.distributed.fsdp.wrap import enable_wrap
-
 
 log = logging.getLogger(__name__)
 

@@ -103,9 +95,7 @@ def __init__(
             precision_plugin=precision_plugin,
         )
         self._process_group = None
-        self.num_processes = (
-            len(self.parallel_devices) if self.parallel_devices is not None else 0
-        )
+        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._has_loaded_state_dict: bool = False
         self._process_group_backend: Optional[str] = process_group_backend
         self.cpu_offload = cpu_offload

@@ -173,9 +163,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
 
     def model_to_device(self) -> None:
         # ensure we update the device type in the lightning module
-        log.info(
-            f"{self.__class__.__name__}: moving model to device [{self.root_device}]..."
-        )
+        log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
         self.lightning_module.to(self.root_device)
 
     @contextlib.contextmanager
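
For context on the torch imports that moved up into the third-party block: `FullyShardedDataParallel`, `enable_wrap`, `CPUOffload` and `BackwardPrefetch` are the torch-native FSDP primitives this strategy builds on. A minimal sketch of how `enable_wrap` and `wrap` fit together, assuming a distributed process group is already initialized and a CUDA device is available (neither is set up by this snippet; the strategy handles both internally):

import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import enable_wrap, wrap

# enable_wrap records default FSDP constructor kwargs; every wrap(...) call made
# inside the context shards that submodule with those defaults.
with enable_wrap(wrapper_cls=FullyShardedDataParallel, cpu_offload=CPUOffload(offload_params=False)):
    layer = wrap(nn.Linear(32, 32))  # layer is now a FullyShardedDataParallel instance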

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 1 deletion

@@ -728,7 +728,10 @@ def _validate_precision_choice(self) -> None:
                 "it's not supported. Try using `amp_type='native'` instead."
             )
         if self._precision_flag in (16, "bf16") and self._amp_type_flag == AMPType.APEX:
-            if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy, DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
+            if isinstance(
+                self.strategy,
+                (DDPShardedStrategy, DDPSpawnShardedStrategy, DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy),
+            ):
                 raise MisconfigurationException(
                     "Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
                 )
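
The re-wrapped `isinstance` check is the guard that rejects apex AMP in combination with any sharded strategy, now including the native FSDP one. From the user's side the rejected configuration looks roughly like the sketch below (hedged: it reuses the 1.6-era Trainer arguments that appear elsewhere in this commit); building such a Trainer raises the MisconfigurationException from the hunk above:

from pytorch_lightning import Trainer

# 16-bit precision + the apex backend + an FSDP/sharded strategy is the combination
# the connector refuses; the error message suggests amp_backend="native" instead.
trainer = Trainer(gpus=2, strategy="fsdp_native", precision=16, amp_backend="apex")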

tests/strategies/test_ddp_fully_sharded_native.py

Lines changed: 8 additions & 4 deletions

@@ -4,21 +4,23 @@
 
 import pytest
 import torch
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+from torch.distributed.fsdp.wrap import wrap
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin
 from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
-from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
-from torch.distributed.fsdp.wrap import wrap
 
 
 def test_invalid_on_cpu(tmpdir):
     """Test to ensure that to raise Misconfiguration for Native FSDP on CPU."""
     with pytest.raises(
-        MisconfigurationException, match="You selected strategy to be `ddp_fully_sharded_native`, but GPU is not available."
+        MisconfigurationException,
+        match="You selected strategy to be `ddp_fully_sharded_native`, but GPU is not available.",
     ):
         trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native")
         assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)

@@ -102,7 +104,9 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
 
     model = TestFSDPModel()
     ck = ModelCheckpoint(save_last=True)
-    trainer = Trainer(default_root_dir=tmpdir, gpus=2, strategy="fsdp_native", precision=16, max_epochs=1, callbacks=[ck])
+    trainer = Trainer(
+        default_root_dir=tmpdir, gpus=2, strategy="fsdp_native", precision=16, max_epochs=1, callbacks=[ck]
+    )
     _run_multiple_stages(trainer, model)
 
 
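Taken together, these tests also show the user-facing way to enable the strategy. A trimmed-down sketch of the flag combination exercised above (assumes two CUDA GPUs and 16-bit native AMP, as in the multi-GPU checkpoint test):

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

trainer = Trainer(gpus=2, strategy="fsdp_native", precision=16, max_epochs=1)
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)  # "fsdp_native" resolves to this strategy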
