diff --git a/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py b/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py index 2b99bebe51c5b..f5d0badacba27 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py @@ -177,7 +177,7 @@ def test_dp_state_dict_cpu_offload(self): def _test_dp_state_dict_cpu_offload( self, offload_policy: CPUOffloadPolicy, cpu_state_dict: bool ): - mlp_dim = 4 + mlp_dim = self.world_size torch.manual_seed(42) with torch.device("meta"): model = nn.Sequential(