@@ -85,7 +85,7 @@ def _(func, types, args, kwargs):
 class M(torch.nn.Module):
     def __init__(self, in_features, out_features, **kwargs) -> None:
         super().__init__(**kwargs)
-        self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
+        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
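Dropping device="cuda" from the constructor is what makes per-rank placement possible: with the kwarg, every rank allocates the weight on its current default CUDA device (cuda:0 unless torch.cuda.set_device has been called), so multiple ranks on one node would pile onto the same GPU. A minimal sketch of the deferred-placement pattern the next hunk adopts (the cuda:0 device here is a placeholder; the real script derives it from RANK):

import torch

# Construct on CPU; no device is baked into the module.
m = torch.nn.Linear(1024, 2048, bias=False)
# Each rank then moves its replica to its own GPU.
device = torch.device("cuda:0")  # placeholder; derived from RANK in the script below
m = m.to(device)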
@@ -144,10 +144,15 @@ def main():
     # To make sure different ranks create the same module
     torch.manual_seed(5)
 
+    # Get rank and device
+    world_size = int(os.environ["WORLD_SIZE"])
+    rank = int(os.environ["RANK"])
+    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+
     # Original model
-    proj_up = M(1024, 2048)
-    proj_dn = M(2048, 1024)
-    example_input = 100 * torch.randn(128, 1024, device="cuda")
+    proj_up = M(1024, 2048).to(device)
+    proj_dn = M(2048, 1024).to(device)
+    example_input = 100 * torch.randn(128, 1024, device=device)
     y = proj_dn(proj_up(example_input))
 
     # Quantize the model
@@ -157,8 +162,6 @@ def main():
     print("Quantization works!")
 
     # Create a device mesh
-    world_size = int(os.environ["WORLD_SIZE"])
-    rank = int(os.environ["RANK"])
     dist.init_process_group(backend="nccl")
     mesh = dist.init_device_mesh("cuda", (world_size,))
 
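These hunks assume RANK and WORLD_SIZE are already present in the environment before main() runs; a launcher such as torchrun sets them for each process (e.g. torchrun --nproc_per_node=2 <script>.py, where the script name is a placeholder). A minimal, self-contained sketch of the setup the patched main() performs, with torch.cuda.set_device added here as an assumption to pin NCCL work to each rank's GPU:

import os
import torch
import torch.distributed as dist

world_size = int(os.environ["WORLD_SIZE"])  # populated by the launcher
rank = int(os.environ["RANK"])
# Map the global rank onto a local GPU; the modulo handles multi-node
# launches where global ranks exceed the per-node GPU count.
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.cuda.set_device(device)  # assumption: not in the patch, but keeps NCCL on this rank's GPU
dist.init_process_group(backend="nccl")
mesh = dist.init_device_mesh("cuda", (world_size,))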