25 changes: 11 additions & 14 deletions src/lightning/fabric/plugins/collectives/torch_collective.py
@@ -24,6 +24,8 @@ class TorchCollective(Collective):
     """
 
     manages_default_group = False
+    addr_key = "MASTER_ADDR"
+    port_key = "MASTER_PORT"
 
     def __init__(self) -> None:
         if not dist.is_available():
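
(Annotation, not part of the diff: because setup() below now reads these names through self, a subclass could presumably redirect which rendezvous variables are used without copying the method. A hypothetical sketch; the subclass name and variable names are illustrative only.)

from lightning.fabric.plugins.collectives.torch_collective import TorchCollective

class CustomCollective(TorchCollective):
    # Overriding the new class attributes changes which env vars setup() exports and cleans up.
    addr_key = "MY_MASTER_ADDR"
    port_key = "MY_MASTER_PORT"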
@@ -136,26 +138,21 @@ def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = N
         if self.is_initialized():
             return self
         # maybe set addr
-        set_addr = False
-        addr_key = "MASTER_ADDR"
-        if main_address is not None and addr_key not in os.environ:
-            os.environ[addr_key] = main_address
-            set_addr = True
+        setting_env = []
+        if main_address is not None and self.addr_key not in os.environ:
+            os.environ[self.addr_key] = main_address
+            setting_env.append(self.addr_key)
         # maybe set port
-        set_port = False
-        port_key = "MASTER_PORT"
-        if main_port is not None and port_key not in os.environ:
-            os.environ[port_key] = str(main_port)
-            set_port = True
+        if main_port is not None and self.port_key not in os.environ:
+            os.environ[self.port_key] = str(main_port)
+            setting_env.append(self.port_key)
         # this will `init_group`
         super().setup(**kwargs)
         # set as a class attribute so any instance can know whether we initialized the default process group
         TorchCollective.manages_default_group = True
         # cleanup
-        if set_addr:
-            os.environ.pop("MASTER_ADDR", None)
-        if set_port:
-            os.environ.pop("MASTER_PORT", None)
+        for kenv in setting_env:
+            os.environ.pop(kenv, None)
         return self
 
     @override
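
(Annotation, not part of the diff: a minimal standalone sketch of the bookkeeping pattern the rewritten setup() uses — export MASTER_ADDR/MASTER_PORT only when absent, remember exactly which keys were exported, and pop only those afterwards. The helper name run_with_rendezvous_env and its init_fn parameter are illustrative, not Lightning API.)

import os
from typing import Callable, Optional

def run_with_rendezvous_env(
    init_fn: Callable[[], None],
    main_address: Optional[str] = None,
    main_port: Optional[int] = None,
    addr_key: str = "MASTER_ADDR",
    port_key: str = "MASTER_PORT",
) -> None:
    # Remember exactly which variables this call exported so cleanup never
    # touches values that were already present in the environment.
    setting_env = []
    if main_address is not None and addr_key not in os.environ:
        os.environ[addr_key] = main_address
        setting_env.append(addr_key)
    if main_port is not None and port_key not in os.environ:
        os.environ[port_key] = str(main_port)
        setting_env.append(port_key)
    init_fn()  # e.g. torch.distributed.init_process_group(...)
    # Cleanup: pop only the keys this call set.
    for key in setting_env:
        os.environ.pop(key, None)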
6 changes: 1 addition & 5 deletions src/lightning/fabric/utilities/distributed.py
@@ -319,11 +319,7 @@ def _destroy_dist_connection() -> None:
 
 
 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    """Return corresponding distributed backend for a given device."""
-    device_backend_map = torch.distributed.Backend.default_device_backend_map
-    if device.type in device_backend_map:
-        return device_backend_map[device.type]
-    return "gloo"
+    return "nccl" if device.type == "cuda" else "gloo"
 
 
 class _DatasetSamplerWrapper(Dataset):
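
(Annotation, not part of the diff: a minimal sketch of how such a per-device default is typically consumed when creating the default process group. The helper name default_backend_for is illustrative; the commented usage assumes the usual env:// rendezvous variables.)

import torch

def default_backend_for(device: torch.device) -> str:
    # Mirrors the simplified rule above: NCCL for CUDA devices, Gloo for everything else.
    return "nccl" if device.type == "cuda" else "gloo"

# Usage sketch, commented out because it needs MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE set:
# import torch.distributed as dist
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dist.init_process_group(backend=default_backend_for(device), init_method="env://")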
@@ -46,7 +46,8 @@ def test_memory_sharing_disabled(strategy):
 
 def _test_memory_sharing_disabled(fabric, tensor, model):
     is_spawn = fabric.strategy.launcher._start_method == "spawn"
-    assert not is_spawn or tensor.is_shared()
+    if is_spawn:
+        assert tensor.is_shared()
     assert not model.layer.weight.is_shared()
     assert not model.tied_layer.weight.is_shared()
     assert not model.buffer.is_shared()
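
(Annotation, not part of the diff: the restructured test checks the same condition — `assert not is_spawn or x` is equivalent to `if is_spawn: assert x` — but fails with a clearer message. For context, a tiny standalone example of the shared-memory API it exercises; the tensor is illustrative, not from the test.)

import torch

t = torch.zeros(4)
assert not t.is_shared()  # a freshly created CPU tensor is not in shared memory
t.share_memory_()         # move the underlying storage into shared memory (what spawn-based workers rely on)
assert t.is_shared()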
22 changes: 0 additions & 22 deletions tests/tests_fabric/utilities/test_distributed.py
@@ -17,7 +17,6 @@
 from lightning.fabric.utilities.distributed import (
     _destroy_dist_connection,
     _gather_all_tensors,
-    _get_default_process_group_backend_for_device,
     _InfiniteBarrier,
     _init_dist_connection,
     _is_dtensor,
@@ -244,27 +243,6 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
     atexit_mock.register.assert_not_called()
 
 
-def test_get_default_process_group_backend_for_device():
-    """Test that each device type maps to its correct default process group backend."""
-    # register a custom backend for test
-    torch.utils.rename_privateuse1_backend("pcu")
-
-    def mock_backend(store, group_rank, group_size, timeout):
-        pass
-
-    torch.distributed.Backend.register_backend(
-        "pccl",
-        lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout),
-        devices=["pcu"],
-    )
-
-    # test that the default backend is correctly set for each device
-    devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")]
-    backends = ["gloo", "nccl", "pccl"]
-    for device, backend in zip(devices, backends):
-        assert _get_default_process_group_backend_for_device(device) == backend
-
-
 @RunIf(min_torch="2.4")
 def test_is_dtensor(monkeypatch):
     from torch.distributed._tensor import DTensor