@@ -250,7 +250,6 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         m.linear.weight = torch.nn.Parameter(
             dtensor, requires_grad=False
         )
-        print('colwise shard Shapeof m.linear.weight : ', m.linear.weight.shape)
         return m
 
     @staticmethod
@@ -265,15 +264,11 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         rank = mesh.get_local_rank()
         local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
         # Construct DTensor from local shard
-        dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
-        print(f'dtensor shape: {dtensor.shape}')
-        print(f'Other dtensor values: {local_shard.original_weight_tensor.tensor_impl.float8_data.shape}, {mesh}, {[Shard(1)]}')
+        dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
         # Replace parameter in module
         m.linear.weight = torch.nn.Parameter(
             dtensor, requires_grad=False
         )
-        print('rowwise shard Shapeof m.linear.weight : ', m.linear.weight.shape)
-
         return m
 
     def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
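For reference, here is a minimal standalone sketch of the `DTensor.from_local` pattern used in `rowwise_shard` above. It assumes a single-process gloo group on CPU purely for illustration; in the test itself, the `DTensorTestBase` harness provides the process group and the mesh spans all ranks.

# Minimal sketch (not from the patch): DTensor.from_local with a Shard(1)
# placement, run as a single-process gloo group for illustration.
import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
orig_weight = torch.randn(1024, 2048)

# Each rank keeps its slice of the columns; Shard(1) marks dim 1 as sharded.
rank = mesh.get_local_rank()
n_local_cols = orig_weight.shape[1] // mesh.size()
local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]

dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
assert dtensor.shape == orig_weight.shape  # global shape is stitched back together

dist.destroy_process_group()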
@@ -306,15 +301,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         proj_up = M(1024, 2048).to(device).to(dtype)
         proj_dn = M(2048, 1024).to(device).to(dtype)
         example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
+        print('Run y')
         y = proj_dn(proj_up(example_input))
-        print('Run before y')
+
         # Quantize the model
         up_quant = self.quantize(proj_up)
         dn_quant = self.quantize(proj_dn)
-        print('Run before y_q')
         y_q = dn_quant(up_quant(example_input))
-        print('Executed y_q')
-
+
         mesh = self.build_device_mesh()
         mesh.device_type = "cuda"