@@ -41,32 +41,51 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 # Test #
 ########
 if __name__ == "__main__":
+    # To make sure different ranks create the same module
+    torch.manual_seed(5)
+
     m = M()
-    example_inputs = (100 * torch.randn(128, 1024),)
-    m(*example_inputs)
+    example_input = 100 * torch.randn(128, 1024)
+    m(example_input)
 
 
     import os
-    import torch
-    from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor
+    from torch.distributed._tensor import DTensor, Replicate, Shard
     import torch.distributed as dist
+
-    # initialize a fake process group
+    # initialize a real process group (NCCL)
-    store = torch.testing._internal.distributed.fake_pg.FakeStore()
-    dist.init_process_group(
-        backend="fake",
-        world_size=2,
-        rank=0,
-        store=store,
-    )
-    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
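+    # WORLD_SIZE and RANK are set by the distributed launcher (e.g. torchrun)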
+    world_size = int(os.environ["WORLD_SIZE"])
+    rank = int(os.environ["RANK"])
+    dist.init_process_group(backend="nccl")
+    mesh = dist.init_device_mesh("cuda", (world_size,))
+
-    # Shard this tensor over the mesh by sharding `big_tensor`'s 0th dimension over the 0th dimension of `mesh`.
+    # Shard the quantized weight over the mesh by sharding its 0th dimension over the 0th dimension of `mesh`.
-    quantized_weight = to_my_dtype_tp(m.linear.weight)
+    orig_weight = m.linear.weight
+    quantized_weight = to_my_dtype_tp(orig_weight)
     print("quantized weight:", quantized_weight)
-    quantized_weight_dtensor = distribute_tensor(quantized_weight, mesh, [Shard(dim=0)])
-    print("quantized weight dtensor:", quantized_weight_dtensor)
+    # Number of rows per rank
+    n_local_rows = orig_weight.size(0) // world_size
+    # TODO: add support for aten.slice.Tensor
+    quantized_shard = quantized_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
+    print("quantized shard:", quantized_shard)
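+    # from_local builds the DTensor directly from each rank's pre-sliced
+    # shard, avoiding distribute_tensor, which would need ops like
+    # aten.slice.Tensor to work on the quantized subclass (see TODO above)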
+    # Construct a DTensor from the local shard
+    quantized_dtensor = DTensor.from_local(quantized_shard, mesh, [Shard(0)])
+    print("quantized dtensor:", quantized_dtensor)
 
+    # Replace the parameter in the module
     m.linear.weight = torch.nn.Parameter(
-        quantized_weight_dtensor, requires_grad=False
+        quantized_dtensor, requires_grad=False
     )
 
-    m(*example_inputs)
+    # We need to turn inputs into DTensor form as well -- just a format change
+    input_dtensor = DTensor.from_local(
+        example_input, mesh, [Replicate()]
+    )
+    print("input dtensor:", input_dtensor)
+
+    m(input_dtensor)
+
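With the fake process group gone, the test now needs a real launcher to set WORLD_SIZE and RANK, plus GPUs for the NCCL backend. A minimal launch sketch, assuming the script is saved as tp_test.py (hypothetical filename) and two GPUs are available:

    torchrun --nproc_per_node=2 tp_test.py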