From cd29a80a3b01cdd5c6c5bca3cc0db817674da57f Mon Sep 17 00:00:00 2001
From: Tanima Dey
Date: Tue, 10 Jun 2025 03:47:50 +0000
Subject: [PATCH 1/2] [XPU] Add a skip test that detects a non-power-of-two
 world size for the XPU device and XCCL backend

The skip test relies on WORLD_SIZE being set; if WORLD_SIZE is not set,
the test case is skipped. The skip test is applied to
test/distributed/fsdp/test_fsdp_tp_integration.py, which is not designed
to run with a non-power-of-two WORLD_SIZE (e.g., 12 ranks). The test case
is not skipped when 4 or 8 ranks are used on XPU devices, and it is never
skipped on non-XPU devices.
---
 .../fsdp/test_fsdp_tp_integration.py           |  7 ++++++-
 torch/testing/_internal/common_distributed.py  | 21 +++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/test/distributed/fsdp/test_fsdp_tp_integration.py b/test/distributed/fsdp/test_fsdp_tp_integration.py
index 01d23a0e64e72..2d15d0c8d6890 100644
--- a/test/distributed/fsdp/test_fsdp_tp_integration.py
+++ b/test/distributed/fsdp/test_fsdp_tp_integration.py
@@ -25,7 +25,11 @@
     parallelize_module,
     RowwiseParallel,
 )
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    skip_if_lt_x_gpu,
+    skip_if_not_powerof2_worldsize_xpu,
+
+)
 from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -230,6 +234,7 @@ def _get_grads_as_flattened(
         return torch.cat(all_grads_per_param).contiguous()
 
     @skip_if_lt_x_gpu(4)
+    @skip_if_not_powerof2_worldsize_xpu()
     def test_fsdp_tp_integration(self):
         self.run_subtests(
             {
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 254dbbe2a2817..85e154d536af6 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -85,6 +85,8 @@ class TestSkip(NamedTuple):
     "importerror": TestSkip(88, "Test skipped due to missing import"),
     "no_accelerator": TestSkip(89, "accelerator is not available."),
     "not-support-multithread": TestSkip(90, "backend not support multithread."),
+    "power-of-two": TestSkip(91, "world size needs to be power of two on xpu."),
+    "worldsize-not-set": TestSkip(92, "world size not set"),
 }
 
 
@@ -210,6 +212,25 @@
     return decorator
 
 
+def skip_if_not_powerof2_worldsize_xpu():
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if TEST_XPU:
+                if os.getenv("WORLD_SIZE") is not None:
+                    x = int(os.getenv("WORLD_SIZE"))
+                    print(f"skip_if_not_powerof2_worldsize_xpu x: {x}")
+                    if TEST_XPU and ((x & (x - 1)) == 0):
+                        return func(*args, **kwargs)
+                    sys.exit(TEST_SKIPS[f"power-of-two"].exit_code)
+                sys.exit(TEST_SKIPS[f"worldsize-not-set"].exit_code)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 # This decorator helps avoiding initializing cuda while testing other backends
 def nccl_skip_if_lt_x_gpu(backend, x):
     def decorator(func):

From 42401e4b5cf52527324a0de11e40ec6d13a14b9a Mon Sep 17 00:00:00 2001
From: Tanima Dey
Date: Sat, 14 Jun 2025 03:54:00 +0000
Subject: [PATCH 2/2] Remove the WORLD_SIZE checks; derive the world size from
 torch.xpu.device_count() instead

---
 torch/testing/_internal/common_distributed.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 85e154d536af6..bbc5dfdbfdfd9 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -86,7 +86,6 @@ class TestSkip(NamedTuple):
     "no_accelerator": TestSkip(89, "accelerator is not available."),
     "not-support-multithread": TestSkip(90, "backend not support multithread."),
     "power-of-two": TestSkip(91, "world size needs to be power of two on xpu."),
-    "worldsize-not-set": TestSkip(92, "world size not set"),
 }
 
 
@@ -217,13 +216,11 @@ def decorator(func):
         @wraps(func)
         def wrapper(*args, **kwargs):
             if TEST_XPU:
-                if os.getenv("WORLD_SIZE") is not None:
-                    x = int(os.getenv("WORLD_SIZE"))
-                    print(f"skip_if_not_powerof2_worldsize_xpu x: {x}")
-                    if TEST_XPU and ((x & (x - 1)) == 0):
-                        return func(*args, **kwargs)
-                    sys.exit(TEST_SKIPS[f"power-of-two"].exit_code)
-                sys.exit(TEST_SKIPS[f"worldsize-not-set"].exit_code)
+                x = torch.xpu.device_count()
+                print(f"skip_if_not_powerof2_worldsize_xpu x: {x}")
+                if (x & (x - 1)) == 0:
+                    return func(*args, **kwargs)
+                sys.exit(TEST_SKIPS[f"power-of-two"].exit_code)
             return func(*args, **kwargs)
 
         return wrapper
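
Note (illustrative only, not part of the patch series): a minimal standalone sketch of the
power-of-two check the decorator relies on. The helper name is_power_of_two is hypothetical;
the bit trick mirrors the patch, with a lower bound added because 0 also satisfies
(x & (x - 1)) == 0.

    def is_power_of_two(x: int) -> bool:
        # Same bit trick as the patch; x >= 1 guards the x == 0 edge case,
        # which would otherwise pass the check.
        return x >= 1 and (x & (x - 1)) == 0

    assert is_power_of_two(4) and is_power_of_two(8)   # test runs on 4 or 8 ranks
    assert not is_power_of_two(12)                     # test is skipped on 12 ranks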