From 0dbdf5ffade791acb1fed80d29a5bb3256a6c18c Mon Sep 17 00:00:00 2001
From: Tanima Dey
Date: Fri, 5 Sep 2025 22:13:27 +0000
Subject: [PATCH] unit test case fix to run 8- and 12-rank configs

---
 .../_composable/fsdp/test_fully_shard_compile.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
index 6411c54cd13ec..eda1cd9e723e9 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
@@ -489,8 +489,13 @@ def test_compiled_autograd_ctx(self):
             torch._dynamo.config.patch(skip_fsdp_hooks=False),
             torch._functorch.config.patch(recompute_views=True),
         ):
-            inputs = torch.randn(8, 8)
-            model = torch.nn.Linear(8, 8)
+            device_type = torch.accelerator.current_accelerator().type
+            if device_type == "xpu" and self.world_size == 12:
+                dim = self.world_size
+            else:
+                dim = 8
+            inputs = torch.randn(dim, dim)
+            model = torch.nn.Linear(dim, dim)
             fully_shard(model)
             model_compiled = torch.compile(model, backend="inductor")
             for i in range(10):
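
Note (not part of the patch): a minimal standalone sketch of the size-selection
logic the diff introduces, assuming torch.accelerator.current_accelerator() is
available in the installed PyTorch build; the helper name pick_dim is
hypothetical and stands in for the test's inline code, where world_size comes
from the FSDP test harness.

    import torch

    def pick_dim(world_size: int, default_dim: int = 8) -> int:
        # Active accelerator device (e.g. cuda, xpu); may be None on CPU-only builds.
        accel = torch.accelerator.current_accelerator()
        device_type = accel.type if accel is not None else "cpu"
        # Mirrors the patched test: on XPU with a 12-rank run, size the
        # Linear layer by world_size; otherwise keep the original 8x8 size.
        if device_type == "xpu" and world_size == 12:
            return world_size
        return default_dim

    # Example: returns 12 on an XPU host with 12 ranks, otherwise 8.
    print(pick_dim(world_size=12))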