|
17 | 17 |
|
18 | 18 | import torch |
19 | 19 | import torch.distributed as torch_distrib |
20 | | -import torch.distributed as dist |
21 | 20 | import torch.multiprocessing as mp |
22 | 21 | from torch.nn.parallel import DistributedDataParallel |
23 | 22 |
|
24 | 23 | from pytorch_lightning import _logger as log |
25 | 24 | from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp |
| 25 | +from pytorch_lightning.cluster_environments import ClusterEnvironment |
26 | 26 | from pytorch_lightning.core.lightning import LightningModule |
27 | 27 | from pytorch_lightning.distributed import LightningDistributed |
| 28 | +from pytorch_lightning.plugins.ddp_plugin import DDPPlugin |
28 | 29 | from pytorch_lightning.plugins.rpc_plugin import RPCPlugin |
29 | 30 | from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType |
30 | 31 | from pytorch_lightning.utilities.cloud_io import atomic_save |
31 | 32 | from pytorch_lightning.utilities.cloud_io import load as pl_load |
32 | 33 | from pytorch_lightning.utilities.distributed import ( |
| 34 | + all_gather_ddp_if_available, |
33 | 35 | find_free_network_port, |
34 | 36 | rank_zero_only, |
35 | 37 | rank_zero_warn, |
36 | 38 | sync_ddp_if_available, |
37 | | - all_gather_ddp_if_available, |
38 | 39 | ) |
39 | 40 | from pytorch_lightning.utilities.seed import seed_everything |
40 | 41 |
|
|
45 | 46 |
|
46 | 47 | class DDPSpawnAccelerator(Accelerator): |
47 | 48 |
|
48 | | - def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None): |
| 49 | + def __init__(self, |
| 50 | + trainer, |
| 51 | + nprocs: int, |
| 52 | + cluster_environment: Optional[ClusterEnvironment] = None, |
| 53 | + ddp_plugin: Optional[DDPPlugin] = None): |
49 | 54 | """ |
50 | 55 | Runs training using DDP with mp.spawn via manual launch (not cluster launch) |
51 | 56 |
|
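The first hunk tightens the constructor signature: `cluster_environment` and `ddp_plugin` become typed `Optional[ClusterEnvironment]` / `Optional[DDPPlugin]` parameters instead of untyped `None` defaults. A minimal usage sketch, assuming the 1.0-era module layout implied by the imports above; the concrete `TorchElasticEnvironment` choice, the Trainer arguments, and the manual construction itself are illustrative only (normally the Trainer selects and builds the accelerator internally):

```python
# Hedged sketch: constructing the accelerator with the newly typed arguments.
# Import paths assume the 1.0-era pytorch_lightning layout shown in this diff.
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.ddp_spawn_accelerator import DDPSpawnAccelerator
from pytorch_lightning.cluster_environments import TorchElasticEnvironment
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

trainer = Trainer(gpus=2)  # illustrative Trainer; flags are not part of this diff

accelerator = DDPSpawnAccelerator(
    trainer,
    nprocs=2,                                       # one spawned process per device
    cluster_environment=TorchElasticEnvironment(),  # any ClusterEnvironment subclass, or None
    ddp_plugin=DDPPlugin(),                         # default DistributedDataParallel wrapping, or None
)
```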
@@ -226,8 +231,8 @@ def barrier(self, name: Optional[str] = None): |
226 | 231 |
|
227 | 232 | def early_stopping_should_stop(self, pl_module): |
228 | 233 | stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) |
229 | | - dist.all_reduce(stop, op=dist.reduce_op.SUM) |
230 | | - dist.barrier() |
| 234 | + torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM) |
| 235 | + torch_distrib.barrier() |
231 | 236 | should_stop = stop == self.trainer.world_size |
232 | 237 | return should_stop |
233 | 238 |
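The second hunk only swaps the removed `dist` alias for `torch_distrib` in the early-stopping consensus check: every rank contributes its local stop flag, the flags are summed across the world with an all-reduce, and training halts only when the sum equals `world_size`, i.e. when all ranks agree. A standalone sketch of the same pattern, written against the current `torch.distributed` API (`ReduceOp` instead of the deprecated `reduce_op` alias); the function name and arguments are illustrative and the default process group is assumed to be initialized already:

```python
import torch
import torch.distributed as torch_distrib

def early_stopping_consensus(local_should_stop: bool, device: torch.device) -> bool:
    # Each rank contributes 0 or 1; the SUM all-reduce counts how many ranks want to stop.
    stop = torch.tensor(int(local_should_stop), device=device)
    torch_distrib.all_reduce(stop, op=torch_distrib.ReduceOp.SUM)
    torch_distrib.barrier()
    # Stop only when every rank voted to stop.
    return int(stop.item()) == torch_distrib.get_world_size()
```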
|
|