From 3ba5afbf98f33b5b7add3f59a8dc37a01c2b5b8c Mon Sep 17 00:00:00 2001
From: jemitche1
Date: Tue, 1 Jul 2025 12:54:53 -0700
Subject: [PATCH] Fix unit test for test_fully_shard_grad_scaler.py

Compute the (dp, tp) mesh shape for the 2D case from the world size and
scale the MLP dimension with tp_size, so the test no longer assumes a
fixed world size.
---
 .../fsdp/test_fully_shard_grad_scaler.py | 104 ++++++++++++++++++++++-
 1 file changed, 103 insertions(+), 1 deletion(-)

diff --git a/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py
index edf556b847f86..01615a2df1c26 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py
@@ -26,7 +26,109 @@ def test_gradient_scaler(self):
             {"has_inf": [True, False], "test_2d": [True, False]},
             self._test_gradient_scaler,
         )
 
+    def _test_gradient_scaler(self, has_inf: bool, test_2d: bool):
+        torch.manual_seed(0)
+
+        def closest_factor_pair(n):
+            # Return the most balanced (a, b) factor pair with a * b == n.
+            return max(
+                ((i, n // i) for i in range(1, int(n**0.5) + 1) if n % i == 0),
+                key=lambda pair: min(pair),
+            )
+
+        if test_2d:
+            # Derive a balanced 2D mesh from the world size.
+            dp_size, tp_size = closest_factor_pair(self.world_size)
+            mesh_2d = init_device_mesh(
+                device_type.type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
+            )
+            dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]
+
+            # Keep the model dimension divisible by tp_size.
+            model_dim = tp_size
+            model = nn.Sequential(MLP(model_dim), MLP(model_dim), MLP(model_dim))
+
+            tp_parallelize_plan = {
+                "0.in_proj": ColwiseParallel(),
+                "0.out_proj": RowwiseParallel(),
+                "1.in_proj": ColwiseParallel(),
+                "1.out_proj": RowwiseParallel(),
+                "2.in_proj": ColwiseParallel(),
+                "2.out_proj": RowwiseParallel(),
+            }
+            model = parallelize_module(
+                model,
+                device_mesh=tp_mesh,
+                parallelize_plan=tp_parallelize_plan,
+            )
+            for module in model:
+                fully_shard(module, mesh=dp_mesh)
+            fully_shard(model, mesh=dp_mesh)
+
+            # Input shape is (batch_size, feature_dim).
+            input = torch.randn((2, model_dim), device=device_type)
+        else:
+            # Default path: 1D FSDP only.
+            model = nn.Sequential(
+                *[nn.Linear(4, 4, device=device_type, bias=False) for _ in range(2)]
+            )
+            for layer in model:
+                fully_shard(layer)
+            fully_shard(model)
+            input = torch.randn([4, 4], device=device_type)
+
+        loss = model(input).sum()
+        scaler = GradScaler(init_scale=2.0, enabled=True, device=device_type.type)
+        opt = torch.optim.Adam(model.parameters(), lr=1e-2)
+        scaler.scale(loss).backward()
+        inv_scale = scaler._scale.double().reciprocal().float()
+
+        if (
+            has_inf
+            and opt.param_groups[0]["params"][0].grad._local_tensor.device.index == 1
+        ):
+            # Inject an inf into the local gradient shard on device index 1.
+            opt.param_groups[0]["params"][0].grad._local_tensor[0, 0].fill_(
+                float("inf")
+            )
+
+        initial_grad = opt.param_groups[0]["params"][0].grad.to_local().clone()
+
+        scaler.unscale_(opt)
+        for found_inf in scaler._per_optimizer_states[id(opt)][
+            "found_inf_per_device"
+        ].values():
+            self.assertEqual(found_inf, has_inf)
+        self.assertEqual(
+            scaler._per_optimizer_states[id(opt)]["stage"].value,
+            OptState.UNSCALED.value,
+        )
+
+        unscaled_grad = opt.param_groups[0]["params"][0].grad.to_local().clone()
+        self.assertEqual(unscaled_grad, initial_grad * inv_scale)
+        initial_scale = scaler.get_scale()
+        initial_state = copy.copy(opt.state)
+
+        scaler.step(opt)
+        stepped_state = copy.copy(opt.state)
+        if has_inf:
+            # With an inf gradient, the optimizer step is skipped.
+            self.assertEqual(stepped_state, initial_state)
+        else:
+            self.assertNotEqual(stepped_state.items(), initial_state.items())
+
+        scaler.update()
+        updated_scale = scaler.get_scale()
+        if has_inf:
+            # The scale is backed off after an inf is found.
+            backoff_factor = scaler.get_backoff_factor()
+            self.assertEqual(updated_scale, initial_scale * backoff_factor)
+        else:
+            # The scale is not updated.
+            self.assertEqual(updated_scale, initial_scale)
+
+    '''
     def _test_gradient_scaler(self, has_inf: bool, test_2d: bool):
         torch.manual_seed(0)
         model = nn.Sequential(
@@ -108,6 +210,6 @@ def _test_gradient_scaler(self, has_inf: bool, test_2d: bool):
 
         # scale is not updated
         self.assertEqual(updated_scale, initial_scale)
-
+    '''
 if __name__ == "__main__":
     run_tests()
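
Note on the 2D path: the test shapes its device mesh with a small helper that
picks the most balanced factor pair of the world size. The standalone sketch
below shows that logic in isolation; the world sizes in the demo loop are
illustrative values, not taken from any CI configuration.

def closest_factor_pair(n: int) -> tuple[int, int]:
    # Among all factor pairs (i, n // i), keep the one whose smaller member is
    # largest, which gives the most "square" (dp_size, tp_size) split.
    return max(
        ((i, n // i) for i in range(1, int(n**0.5) + 1) if n % i == 0),
        key=lambda pair: min(pair),
    )


if __name__ == "__main__":
    for world_size in (2, 4, 6, 8, 12):
        dp_size, tp_size = closest_factor_pair(world_size)
        # e.g. 8 -> (2, 4), 12 -> (3, 4); prime world sizes fall back to (1, n)
        print(f"world_size={world_size}: dp={dp_size}, tp={tp_size}")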