diff --git a/test/distributed/tensor/test_dtensor.py b/test/distributed/tensor/test_dtensor.py index d1b883be1fb32..edaedadcee5bc 100644 --- a/test/distributed/tensor/test_dtensor.py +++ b/test/distributed/tensor/test_dtensor.py @@ -942,7 +942,7 @@ def test_split_tensor_1D(self) -> None: mesh = self.build_device_mesh() shard_placement = Shard(0) - for size in range(8): + for size in range(self.world_size): tensor = self._create_tensor(size) splitted_tensor_list, pad_sizes = shard_placement._split_tensor( tensor,