diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
index 6411c54cd13e..eda1cd9e723e 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
@@ -489,8 +489,13 @@ def test_compiled_autograd_ctx(self):
             torch._dynamo.config.patch(skip_fsdp_hooks=False),
             torch._functorch.config.patch(recompute_views=True),
         ):
-            inputs = torch.randn(8, 8)
-            model = torch.nn.Linear(8, 8)
+            device_type = torch.accelerator.current_accelerator().type
+            if device_type == "xpu" and self.world_size == 12:
+                dim = self.world_size
+            else:
+                dim = 8
+            inputs = torch.randn(dim, dim)
+            model = torch.nn.Linear(dim, dim)
             fully_shard(model)
             model_compiled = torch.compile(model, backend="inductor")
             for i in range(10):
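
The change makes the test's tensor dimension device-aware: on XPU hosts that run with a world size of 12, an 8-row weight cannot be sharded evenly across the 12 ranks by `fully_shard`, so the test sizes the tensor to the world size instead. Below is a minimal standalone sketch of that dimension pick, pulled out of the test for clarity; `pick_dim` is a hypothetical helper (not part of the PR), and it assumes a torch version that provides the `torch.accelerator` API.

```python
import torch

def pick_dim(world_size: int, default_dim: int = 8) -> int:
    # torch.accelerator.current_accelerator() returns a torch.device,
    # or None when no accelerator backend is available.
    acc = torch.accelerator.current_accelerator()
    device_type = acc.type if acc is not None else "cpu"
    # Mirror the test's special case: with 12 XPU ranks, size the
    # tensor to world_size so dim-0 sharding divides evenly.
    if device_type == "xpu" and world_size == 12:
        return world_size
    return default_dim

# Usage, matching the test body:
# dim = pick_dim(self.world_size)
# inputs = torch.randn(dim, dim)
# model = torch.nn.Linear(dim, dim)
```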