Commit ff38b10

Reuse code
1 parent 04644c6 commit ff38b10
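
This commit replaces direct torch.distributed.is_initialized() checks in the plugins' barrier() methods with the existing distributed_available() helper from pytorch_lightning.utilities.distributed. As a minimal sketch of what the reused helper checks (an assumption for illustration, not code from this commit), it behaves roughly like:

    import torch

    def distributed_available() -> bool:
        # True only when torch was built with distributed support and a
        # process group has actually been initialized for this run.
        return torch.distributed.is_available() and torch.distributed.is_initialized()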

3 files changed: +18 -7 lines changed

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 8 additions & 2 deletions
@@ -41,7 +41,13 @@
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    distributed_available,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed

@@ -333,7 +339,7 @@ def post_dispatch(self) -> None:
         self.cluster_environment.teardown()

     def barrier(self, *args, **kwargs) -> None:
-        if not torch_distrib.is_initialized():
+        if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
             torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())
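
The resulting guard in both DDP plugins follows the pattern sketched below (illustrative only; guarded_barrier and its device_ids parameter are hypothetical stand-ins, while the real methods use self.determine_ddp_device_ids() and the _TORCH_GREATER_EQUAL_1_8 flag):

    import torch.distributed as dist

    def guarded_barrier(device_ids=None) -> None:
        # No-op when there is nothing to synchronize, instead of raising
        # because no process group was ever initialized.
        if not (dist.is_available() and dist.is_initialized()):
            return
        if device_ids is not None and dist.get_backend() == "nccl":
            # torch >= 1.8 accepts device_ids for NCCL barriers.
            dist.barrier(device_ids=device_ids)
        else:
            dist.barrier()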

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 8 additions & 2 deletions
@@ -36,7 +36,13 @@
 )
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    distributed_available,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.seed import reset_seed

 if _TORCH_GREATER_EQUAL_1_8:

@@ -310,7 +316,7 @@ def __recover_child_process_weights(self, best_path, last_path):
         self.lightning_module.load_state_dict(ckpt)

     def barrier(self, *args, **kwargs) -> None:
-        if not torch_distrib.is_initialized():
+        if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
             torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())

pytorch_lightning/plugins/training_type/horovod.py

Lines changed: 2 additions & 3 deletions
@@ -15,13 +15,12 @@
 from typing import Any, List, Optional, Union

 import torch
-import torch.distributed as torch_distrib
 from torch.optim.lr_scheduler import _LRScheduler, Optimizer

 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
-from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp
+from pytorch_lightning.utilities.distributed import distributed_available, group, rank_zero_only, ReduceOp

 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd

@@ -125,7 +124,7 @@ def start_predicting(self, trainer):
         self.join()

     def barrier(self, *args, **kwargs):
-        if torch_distrib.is_initialized():
+        if distributed_available():
             self.join()

     def broadcast(self, obj: object, src: int = 0) -> object:
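
In the Horovod plugin the barrier itself is unchanged: it still delegates to self.join(), only the guard now goes through distributed_available(). A minimal standalone sketch of that join-based barrier (hypothetical helper; assumes Horovod is installed and hvd.init() has been called):

    import horovod.torch as hvd

    hvd.init()  # required before any Horovod collective

    def horovod_barrier() -> None:
        # hvd.join() returns only after every Horovod rank has reached it.
        hvd.join()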
