diff --git a/test/test_fused_rms_norm.py b/test/test_dtensor_rmsnorm.py
similarity index 63%
rename from test/test_fused_rms_norm.py
rename to test/test_dtensor_rmsnorm.py
index 9bd7e3732c..da08f42a3d 100644
--- a/test/test_fused_rms_norm.py
+++ b/test/test_dtensor_rmsnorm.py
@@ -12,6 +12,7 @@
     Shard,
 )
 from torch.distributed._tensor.debug import CommDebugMode
+from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -19,17 +20,17 @@
     with_comms,
 )
 
-from torchtitan.models.norms import fused_rms_norm_fn
+from torchtitan.models.norms import fused_rms_norm_fn, RMSNorm
 
 
-class TestFusedRMSNorm(DTensorTestBase):
+class TestDTensorRMSNorm(DTensorTestBase):
     @property
     def world_size(self):
         return 4
 
     @skip_if_lt_x_gpu(4)
     @with_comms
-    def test_fused_rms_norm(self):
+    def test_dtensor_fused_rmsnorm(self):
         mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
         )
@@ -67,6 +68,34 @@ def test_fused_rms_norm(self):
         self.assertEqual(dist_out.full_tensor(), out)
         self.assertEqual(dist_grad_out.full_tensor(), grad_out)
 
+    @skip_if_lt_x_gpu(4)
+    @with_comms
+    def test_dtensor_compiled_rmsnorm(self):
+        mesh = init_device_mesh(
+            device_type=self.device_type, mesh_shape=(self.world_size,)
+        )
+        x = torch.randn(1, 4, 4, device=self.device_type)
+        dist_x = distribute_tensor(x, mesh, [Shard(1)])
+
+        x = x.clone().detach()
+        self.assertEqual(dist_x.full_tensor(), x)
+
+        dist_norm = RMSNorm(4, compile=True)
+        dist_norm.to(self.device_type)
+        norm = RMSNorm(4, compile=False)  # single-gpu eager as baseline
+        norm.to(self.device_type)
+
+        parallelize_module(dist_norm, mesh, SequenceParallel())
+
+        dist_out = dist_norm(dist_x)
+        out = norm(x)
+
+        dist_out.sum().backward()
+        out.sum().backward()
+
+        self.assertEqual(dist_out.full_tensor(), out)
+        self.assertEqual(dist_norm.weight.grad.full_tensor(), norm.weight.grad)
+
 
 if __name__ == "__main__":
     run_tests()