     Shard,
 )
 from torch.distributed._tensor.debug import CommDebugMode
+from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     skip_if_lt_x_gpu,
     with_comms,
 )
 
-from torchtitan.models.norms import fused_rms_norm_fn
+from torchtitan.models.norms import fused_rms_norm_fn, RMSNorm
 
 
-class TestFusedRMSNorm(DTensorTestBase):
+class TestDTensorRMSNorm(DTensorTestBase):
     @property
     def world_size(self):
         return 4
 
     @skip_if_lt_x_gpu(4)
     @with_comms
-    def test_fused_rms_norm(self):
+    def test_dtensor_fused_rmsnorm(self):
         mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
         )
@@ -67,6 +68,34 @@ def test_fused_rms_norm(self): |
         self.assertEqual(dist_out.full_tensor(), out)
         self.assertEqual(dist_grad_out.full_tensor(), grad_out)
 
+    @skip_if_lt_x_gpu(4)
+    @with_comms
+    def test_dtensor_compiled_rmsnorm(self):
+        mesh = init_device_mesh(
+            device_type=self.device_type, mesh_shape=(self.world_size,)
+        )
+        x = torch.randn(1, 4, 4, device=self.device_type)
+        dist_x = distribute_tensor(x, mesh, [Shard(1)])
+
+        x = x.clone().detach()
+        self.assertEqual(dist_x.full_tensor(), x)
+
+        dist_norm = RMSNorm(4, compile=True)
+        dist_norm.to(self.device_type)
+        norm = RMSNorm(4, compile=False)  # single-gpu eager as baseline
+        norm.to(self.device_type)
+
+        parallelize_module(dist_norm, mesh, SequenceParallel())
+
+        dist_out = dist_norm(dist_x)
+        out = norm(x)
+
+        dist_out.sum().backward()
+        out.sum().backward()
+
+        self.assertEqual(dist_out.full_tensor(), out)
+        self.assertEqual(dist_norm.weight.grad.full_tensor(), norm.weight.grad)
+
 
 if __name__ == "__main__":
     run_tests()