File tree (expand/collapse): 2 files changed, +24 −1 lines changed
lines changed Original file line number Diff line number Diff line change 2525 parallelize_module ,
2626 RowwiseParallel ,
2727)
28- from torch .testing ._internal .common_distributed import skip_if_lt_x_gpu
28+ from torch .testing ._internal .common_distributed import (
29+ skip_if_lt_x_gpu ,
30+ skip_if_not_powerof2_worldsize_xpu ,
31+
32+ )
2933from torch .testing ._internal .common_fsdp import FSDPTest
3034from torch .testing ._internal .common_utils import (
3135 instantiate_parametrized_tests ,
@@ -230,6 +234,7 @@ def _get_grads_as_flattened(
230234 return torch .cat (all_grads_per_param ).contiguous ()
231235
232236 @skip_if_lt_x_gpu (4 )
237+ @skip_if_not_powerof2_worldsize_xpu ()
233238 def test_fsdp_tp_integration (self ):
234239 self .run_subtests (
235240 {
Original file line number Diff line number Diff line change @@ -85,6 +85,7 @@ class TestSkip(NamedTuple):
8585 "importerror" : TestSkip (88 , "Test skipped due to missing import" ),
8686 "no_accelerator" : TestSkip (89 , "accelerator is not available." ),
8787 "not-support-multithread" : TestSkip (90 , "backend not support multithread." ),
88+ "power-of-two" : TestSkip (91 , "world size needs to be power of two on xpu." ),
8889}
8990
9091
@@ -210,6 +211,23 @@ def wrapper(*args, **kwargs):
210211 return decorator
211212
212213
def skip_if_not_powerof2_worldsize_xpu():
    """Skip a distributed test on XPU unless the device count is a power of two.

    On non-XPU builds (``TEST_XPU`` is falsy) the wrapped test runs
    unconditionally.  On XPU, the test runs only when
    ``torch.xpu.device_count()`` is a positive power of two; otherwise the
    process exits with the ``"power-of-two"`` skip exit code so the test
    harness reports a skip rather than a failure.

    Returns:
        A decorator that wraps the test function with the XPU world-size gate.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not TEST_XPU:
                # Nothing to gate on non-XPU builds.
                return func(*args, **kwargs)
            count = torch.xpu.device_count()
            # x is a power of two iff x > 0 and (x & (x - 1)) == 0.  The
            # x > 0 guard matters: 0 & -1 == 0, so without it a host with
            # zero visible XPU devices would be treated as a power of two.
            if count > 0 and (count & (count - 1)) == 0:
                return func(*args, **kwargs)
            sys.exit(TEST_SKIPS["power-of-two"].exit_code)

        return wrapper

    return decorator
229+
230+
213231# This decorator helps avoiding initializing cuda while testing other backends
214232def nccl_skip_if_lt_x_gpu (backend , x ):
215233 def decorator (func ):
You can’t perform that action at this time.
0 commit comments