Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions test/test_fused_rms_norm.py → test/test_dtensor_rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,25 @@
Shard,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)

from torchtitan.models.norms import fused_rms_norm_fn
from torchtitan.models.norms import fused_rms_norm_fn, RMSNorm


class TestFusedRMSNorm(DTensorTestBase):
class TestDTensorRMSNorm(DTensorTestBase):
@property
def world_size(self):
return 4

@skip_if_lt_x_gpu(4)
@with_comms
def test_fused_rms_norm(self):
def test_dtensor_fused_rmsnorm(self):
mesh = init_device_mesh(
device_type=self.device_type, mesh_shape=(self.world_size,)
)
Expand Down Expand Up @@ -67,6 +68,34 @@ def test_fused_rms_norm(self):
self.assertEqual(dist_out.full_tensor(), out)
self.assertEqual(dist_grad_out.full_tensor(), grad_out)

@skip_if_lt_x_gpu(4)
@with_comms
def test_dtensor_compiled_rmsnorm(self):
mesh = init_device_mesh(
device_type=self.device_type, mesh_shape=(self.world_size,)
)
x = torch.randn(1, 4, 4, device=self.device_type)
dist_x = distribute_tensor(x, mesh, [Shard(1)])

x = x.clone().detach()
self.assertEqual(dist_x.full_tensor(), x)

dist_norm = RMSNorm(4, compile=True)
dist_norm.to(self.device_type)
norm = RMSNorm(4, compile=False) # single-gpu eager as baseline
norm.to(self.device_type)

parallelize_module(dist_norm, mesh, SequenceParallel())

dist_out = dist_norm(dist_x)
out = norm(x)

dist_out.sum().backward()
out.sum().backward()

self.assertEqual(dist_out.full_tensor(), out)
self.assertEqual(dist_norm.weight.grad.full_tensor(), norm.weight.grad)


if __name__ == "__main__":
run_tests()