@@ -14,39 +14,31 @@
 import contextlib
 import logging
 import os
-from typing import Union, Any, Generator, Dict, List, Optional
+from typing import Any, Dict, Generator, List, Optional, Union
 
-import pytorch_lightning as pl
 import torch
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, FullyShardedDataParallel
+from torch.distributed.fsdp.wrap import enable_wrap
+
+import pytorch_lightning as pl
 from pytorch_lightning.overrides.distributed import prepare_for_backward
-from pytorch_lightning.plugins.environments.cluster_environment import (
-    ClusterEnvironment,
-)
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.distributed import (
-    init_dist_connection,
-    sync_ddp_if_available,
-    ReduceOp,
-    group as _group,
     _get_process_group_backend_from_env,
     distributed_available,
     get_default_process_group_backend_for_device,
 )
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.seed import reset_seed
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    BackwardPrefetch,
-    CPUOffload,
-    FullyShardedDataParallel,
-)
-from torch.distributed.fsdp.wrap import enable_wrap
-
 
 log = logging.getLogger(__name__)
 
@@ -103,9 +95,7 @@ def __init__(
             precision_plugin=precision_plugin,
         )
         self._process_group = None
-        self.num_processes = (
-            len(self.parallel_devices) if self.parallel_devices is not None else 0
-        )
+        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._has_loaded_state_dict: bool = False
         self._process_group_backend: Optional[str] = process_group_backend
         self.cpu_offload = cpu_offload
@@ -173,9 +163,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
 
     def model_to_device(self) -> None:
         # ensure we update the device type in the lightning module
-        log.info(
-            f"{self.__class__.__name__}: moving model to device [{self.root_device}]..."
-        )
+        log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
         self.lightning_module.to(self.root_device)
 
     @contextlib.contextmanager
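
For context, here is a minimal sketch (not part of the diff) of how the FSDP symbols that this change hoists to the top of the module (enable_wrap, FullyShardedDataParallel, CPUOffload, BackwardPrefetch) are typically combined. It assumes torch >= 1.11 with an already-initialized default process group; TwoLayerNet and shard are hypothetical names used only for illustration.

import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import enable_wrap, wrap


class TwoLayerNet(nn.Module):
    """Hypothetical example module; any nn.Module works."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.layers(x)


def shard(module: nn.Module) -> FullyShardedDataParallel:
    # enable_wrap() records default FSDP constructor kwargs; every wrap() call made
    # inside the block applies them, so nested submodules get consistent settings.
    # Requires torch.distributed to be initialized (e.g. via init_process_group).
    with enable_wrap(
        wrapper_cls=FullyShardedDataParallel,
        cpu_offload=CPUOffload(offload_params=True),
        backward_prefetch=BackwardPrefetch.BACKWARD_POST,
    ):
        return wrap(module)

Wrapping inside enable_wrap is what lets a single set of FSDP defaults (CPU offload, prefetch policy, and so on) be passed down to every module wrapped within the block.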