Commit 5613c12

add unit test for compiled_rmsnorm
ghstack-source-id: d3e5dbd
Pull Request resolved: #506

1 parent: f9e114b

File tree: 1 file changed (+32, -3 lines)


test/test_fused_rms_norm.py renamed to test/test_dtensor_rmsnorm.py

Lines changed: 32 additions & 3 deletions
@@ -12,24 +12,25 @@
     Shard,
 )
 from torch.distributed._tensor.debug import CommDebugMode
+from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     skip_if_lt_x_gpu,
     with_comms,
 )

-from torchtitan.models.norms import fused_rms_norm_fn
+from torchtitan.models.norms import fused_rms_norm_fn, RMSNorm


-class TestFusedRMSNorm(DTensorTestBase):
+class TestDTensorRMSNorm(DTensorTestBase):
     @property
     def world_size(self):
         return 4

     @skip_if_lt_x_gpu(4)
     @with_comms
-    def test_fused_rms_norm(self):
+    def test_dtensor_fused_rmsnorm(self):
         mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
         )
@@ -67,6 +68,34 @@ def test_fused_rms_norm(self):
         self.assertEqual(dist_out.full_tensor(), out)
         self.assertEqual(dist_grad_out.full_tensor(), grad_out)

+    @skip_if_lt_x_gpu(4)
+    @with_comms
+    def test_dtensor_compiled_rmsnorm(self):
+        mesh = init_device_mesh(
+            device_type=self.device_type, mesh_shape=(self.world_size,)
+        )
+        x = torch.randn(1, 4, 4, device=self.device_type)
+        dist_x = distribute_tensor(x, mesh, [Shard(1)])
+
+        x = x.clone().detach()
+        self.assertEqual(dist_x.full_tensor(), x)
+
+        dist_norm = RMSNorm(4, compile=True)
+        dist_norm.to(self.device_type)
+        norm = RMSNorm(4, compile=False)  # single-gpu eager as baseline
+        norm.to(self.device_type)
+
+        parallelize_module(dist_norm, mesh, SequenceParallel())
+
+        dist_out = dist_norm(dist_x)
+        out = norm(x)
+
+        dist_out.sum().backward()
+        out.sum().backward()
+
+        self.assertEqual(dist_out.full_tensor(), out)
+        self.assertEqual(dist_norm.weight.grad.full_tensor(), norm.weight.grad)
+

 if __name__ == "__main__":
     run_tests()
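
Note on the new test: it constructs RMSNorm(4, compile=True) from torchtitan.models.norms, marks the module as sequence-parallel with parallelize_module(dist_norm, mesh, SequenceParallel()) so the input sharded on the sequence dimension (Shard(1)) flows through as a DTensor, and then checks numerical parity of both the outputs and the weight gradients against a single-GPU eager RMSNorm baseline. The test assumes the compile flag wraps the norm computation in torch.compile; a minimal sketch of such a module is below. This is an illustration only, not the torchtitan source, and the class and attribute names here (SketchRMSNorm, rmsnorm_fn) are hypothetical.

# Minimal sketch (assumption): an RMSNorm module with an optional
# torch.compile wrapper, mirroring how RMSNorm(4, compile=True) is used above.
import torch
import torch.nn as nn


class SketchRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # Optionally compile the core norm computation; eager otherwise.
        self.rmsnorm_fn = torch.compile(self._norm) if compile else self._norm

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # y = x / sqrt(mean(x^2) + eps) * weight, normalized over the last dim
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.rmsnorm_fn(x)

Since the file keeps its run_tests() entry point and the new test is guarded by skip_if_lt_x_gpu(4) with world_size 4, it can be launched directly (e.g. python test/test_dtensor_rmsnorm.py) on a host with at least 4 GPUs.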
