
Commit 0461107

Move init_ddp_connection to distributed utilities (#9044)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8a93173 commit 0461107

File tree: 3 files changed, +42 / -44 lines
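The net effect of the three diffs below: the `init_ddp_connection` method that was duplicated across the DDP plugins is deleted and replaced by one shared function in `pytorch_lightning.utilities.distributed`. A minimal sketch of the resulting call pattern, using a hypothetical `MyDDPLikePlugin` class (the class and its attributes are illustrative; only the `init_ddp_connection` call itself comes from the hunks below):

import pytorch_lightning as pl
from pytorch_lightning.utilities.distributed import init_ddp_connection


class MyDDPLikePlugin:
    """Hypothetical stand-in for a DDP-style plugin, used only to show the new call pattern."""

    def __init__(self, cluster_environment: "pl.plugins.environments.ClusterEnvironment") -> None:
        self.cluster_environment = cluster_environment
        self.torch_distributed_backend = "nccl"  # backend string, mirroring the real plugins

    def setup_distributed(self) -> None:
        # Before this commit the plugin called its own method: self.init_ddp_connection()
        # After this commit it delegates to the shared utility:
        init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)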

pytorch_lightning/plugins/training_type/ddp.py
Lines changed: 2 additions & 21 deletions

@@ -45,7 +45,7 @@
 )
 from pytorch_lightning.utilities.distributed import (
     distributed_available,
-    rank_zero_info,
+    init_ddp_connection,
     rank_zero_only,
     ReduceOp,
     sync_ddp_if_available,
@@ -253,7 +253,7 @@ def setup_distributed(self):
         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
         # where to store ip_table
-        self.init_ddp_connection()
+        init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)

         # set the ranks and devices
         self.dist.rank = self.global_rank
@@ -316,25 +316,6 @@ def determine_ddp_device_ids(self):
             return None
         return [self.root_device.index]

-    def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None:
-        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
-        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
-        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
-        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
-        if torch.distributed.is_available() and not torch.distributed.is_initialized():
-            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
-            torch.distributed.init_process_group(
-                self.torch_distributed_backend, rank=global_rank, world_size=world_size
-            )
-
-            # on rank=0 let everyone know training is starting
-            rank_zero_info(
-                f"{'-' * 100}\n"
-                f"distributed_backend={self.torch_distributed_backend}\n"
-                f"All DDP processes registered. Starting ddp with {self.world_size} processes\n"
-                f"{'-' * 100}\n"
-            )
-
     def pre_dispatch(self):
         # move the model to the correct device
         self.model_to_device()

pytorch_lightning/plugins/training_type/ddp_spawn.py
Lines changed: 2 additions & 23 deletions

@@ -40,7 +40,7 @@
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.distributed import (
     distributed_available,
-    rank_zero_info,
+    init_ddp_connection,
     rank_zero_only,
     ReduceOp,
     sync_ddp_if_available,
@@ -185,7 +185,7 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ
         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
         # where to store ip_table
-        self.init_ddp_connection(self.global_rank, self.world_size)
+        init_ddp_connection(self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size)

         # TODO: we moved it to the trainer.fit after calling pre_dispatch
         # ... need to double check that it is the correct place
@@ -261,27 +261,6 @@ def configure_ddp(self):
         )
         self._register_ddp_hooks()

-    def init_ddp_connection(self, global_rank: Optional[int], world_size: Optional[int]) -> None:
-        # TODO: this code is duplicated in DDP and DDPSpawn, make this a function
-        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
-        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
-        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
-        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
-
-        if not torch.distributed.is_initialized():
-            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
-            torch.distributed.init_process_group(
-                self.torch_distributed_backend, rank=global_rank, world_size=world_size
-            )
-
-        # on rank=0 let everyone know training is starting
-        rank_zero_info(
-            f"{'-' * 100}\n"
-            f"distributed_backend={self.torch_distributed_backend}\n"
-            f"All DDP processes registered. Starting ddp with {self.world_size} processes\n"
-            f"{'-' * 100}\n"
-        )
-
     def determine_ddp_device_ids(self):
         if self.root_device.type == "cpu":
             return None

pytorch_lightning/utilities/distributed.py
Lines changed: 38 additions & 0 deletions

@@ -21,6 +21,7 @@
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel

+import pytorch_lightning as pl
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE

 if _TPU_AVAILABLE:
@@ -345,3 +346,40 @@ def register_ddp_comm_hook(

 def tpu_distributed() -> bool:
     return _TPU_AVAILABLE and xm.xrt_world_size() > 1
+
+
+def init_ddp_connection(
+    cluster_environment: "pl.plugins.environments.ClusterEnvironment",
+    torch_distributed_backend: str,
+    global_rank: Optional[int] = None,
+    world_size: Optional[int] = None,
+    **kwargs,
+) -> None:
+    """
+    Utility function to initialize DDP connection by setting env variables
+    and initializing the distributed process group.
+
+    Args:
+        cluster_environment: ``ClusterEnvironment`` instance
+        torch_distributed_backend: backend to use (includes `nccl` and `gloo`)
+        global_rank: rank of the current process
+        world_size: number of processes in the group
+        kwargs: kwargs for ``init_process_group``
+    """
+    global_rank = global_rank if global_rank is not None else cluster_environment.global_rank()
+    world_size = world_size if world_size is not None else cluster_environment.world_size()
+    os.environ["MASTER_ADDR"] = cluster_environment.master_address()
+    os.environ["MASTER_PORT"] = str(cluster_environment.master_port())
+    if torch.distributed.is_available() and not torch.distributed.is_initialized():
+        log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
+        torch.distributed.init_process_group(
+            torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs
+        )
+
+        # on rank=0 let everyone know training is starting
+        rank_zero_info(
+            f"{'-' * 100}\n"
+            f"distributed_backend={torch_distributed_backend}\n"
+            f"All DDP processes registered. Starting ddp with {world_size} processes\n"
+            f"{'-' * 100}\n"
+        )
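For completeness, a hedged usage sketch of the new module-level utility outside any plugin. The `LightningEnvironment` default environment, the "gloo" backend, and the `timeout` keyword are illustrative assumptions, not part of this commit; only the function and its signature are taken from the diff above. Extra keyword arguments are forwarded to `torch.distributed.init_process_group`.

# Hypothetical usage example, not part of this commit.
from datetime import timedelta

from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.utilities.distributed import init_ddp_connection

init_ddp_connection(
    cluster_environment=LightningEnvironment(),  # assumed single-node default environment
    torch_distributed_backend="gloo",            # "nccl" is the usual choice for multi-GPU runs
    global_rank=0,                               # explicit values take precedence over the environment's
    world_size=1,
    timeout=timedelta(minutes=5),                # **kwargs are forwarded to init_process_group
)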
