Commit 1dee95a

Merge pull request #5 from daisyden/tdey/newskiptest-12ranks
[XPU] Added a new skip test to detect if the world size is set to a non-power of 2
2 parents: 7b81caa + 42401e4

File tree

2 files changed (+24 additions, -1 deletion)
test/distributed/fsdp/test_fsdp_tp_integration.py

Lines changed: 6 additions & 1 deletion
@@ -25,7 +25,11 @@
     parallelize_module,
     RowwiseParallel,
 )
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    skip_if_lt_x_gpu,
+    skip_if_not_powerof2_worldsize_xpu,
+
+)
 from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -230,6 +234,7 @@ def _get_grads_as_flattened(
         return torch.cat(all_grads_per_param).contiguous()
 
     @skip_if_lt_x_gpu(4)
+    @skip_if_not_powerof2_worldsize_xpu()
     def test_fsdp_tp_integration(self):
         self.run_subtests(
             {
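For context, the test now carries two stacked guards: skip_if_lt_x_gpu(4) skips unless at least four accelerators are present, and the new decorator additionally requires the XPU device count to be a power of two (the 12-rank case from the branch name, 12 = 0b1100, fails that check). A minimal sketch of how the stacked guards read in a test class; DemoTest and test_demo are hypothetical names, not part of this commit:

# Minimal sketch (hypothetical, not from this commit): stacking the two guards.
from torch.testing._internal.common_distributed import (
    skip_if_lt_x_gpu,
    skip_if_not_powerof2_worldsize_xpu,
)
from torch.testing._internal.common_fsdp import FSDPTest


class DemoTest(FSDPTest):
    @skip_if_lt_x_gpu(4)                   # skip unless >= 4 devices are present
    @skip_if_not_powerof2_worldsize_xpu()  # on XPU, device count must be 2**k
    def test_demo(self):
        ...                                # runs only if both guards pass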

torch/testing/_internal/common_distributed.py

Lines changed: 18 additions & 0 deletions
@@ -85,6 +85,7 @@ class TestSkip(NamedTuple):
     "importerror": TestSkip(88, "Test skipped due to missing import"),
     "no_accelerator": TestSkip(89, "accelerator is not available."),
     "not-support-multithread": TestSkip(90, "backend not support multithread."),
+    "power-of-two": TestSkip(91, "world size needs to be power of two on xpu."),
 }
 
 
@@ -210,6 +211,23 @@ def wrapper(*args, **kwargs):
     return decorator
 
 
+def skip_if_not_powerof2_worldsize_xpu():
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if TEST_XPU:
+                x = torch.xpu.device_count()
+                print(f"skip_if_not_powerof2_worldsize_xpu x: {x}")
+                if (x & (x - 1)) == 0:
+                    return func(*args, **kwargs)
+                sys.exit(TEST_SKIPS[f"power-of-two"].exit_code)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 # This decorator helps avoiding initializing cuda while testing other backends
 def nccl_skip_if_lt_x_gpu(backend, x):
     def decorator(func):
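The skip condition is the standard bit trick: x & (x - 1) clears the lowest set bit of x, so the expression is zero exactly when x has at most one bit set, i.e. when x is a power of two (or zero; the decorator tolerates zero because the stacked skip_if_lt_x_gpu guard already rules out small device counts). When the check fails, the wrapper exits with the exit code registered under "power-of-two" (91), which the multiprocess test harness reports as a skip rather than a failure. A standalone sketch of the check; is_power_of_two is an illustrative helper, not part of the commit, and unlike the decorator it rejects zero explicitly:

def is_power_of_two(x: int) -> bool:
    # x & (x - 1) drops the lowest set bit; a power of two has exactly one
    # bit set, so dropping it leaves zero. Guard x > 0 because 0 & -1 == 0
    # would otherwise let zero pass as well.
    return x > 0 and (x & (x - 1)) == 0


# 12 ranks (the case in the branch name) is not a power of two:
assert is_power_of_two(8) and not is_power_of_two(12)
assert [n for n in range(1, 20) if is_power_of_two(n)] == [1, 2, 4, 8, 16]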
