Skip to content

let _get_default_process_group_backend_for_device support more hardware platforms #21057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/lightning/fabric/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
return "nccl" if device.type == "cuda" else "gloo"
"""Return corresponding distributed backend for a given device."""
device_backend_map = torch.distributed.Backend.default_device_backend_map
if device.type in device_backend_map:
return device_backend_map[device.type]
return "gloo"


class _DatasetSamplerWrapper(Dataset):
Expand Down
22 changes: 22 additions & 0 deletions tests/tests_fabric/utilities/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from lightning.fabric.utilities.distributed import (
_destroy_dist_connection,
_gather_all_tensors,
_get_default_process_group_backend_for_device,
_InfiniteBarrier,
_init_dist_connection,
_is_dtensor,
Expand Down Expand Up @@ -243,6 +244,27 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
atexit_mock.register.assert_not_called()


def test_get_default_process_group_backend_for_device():
"""Test that each device type maps to its correct default process group backend."""
# register a custom backend for test
torch.utils.rename_privateuse1_backend("pcu")

def mock_backend(store, group_rank, group_size, timeout):
pass

torch.distributed.Backend.register_backend(
"pccl",
lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout),
devices=["pcu"],
)

# test that the default backend is correctly set for each device
devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")]
backends = ["gloo", "nccl", "pccl"]
for device, backend in zip(devices, backends):
assert _get_default_process_group_backend_for_device(device) == backend


@RunIf(min_torch="2.4")
def test_is_dtensor(monkeypatch):
from torch.distributed._tensor import DTensor
Expand Down
Loading