@@ -78,7 +78,6 @@ def _(func, types, args, kwargs):
         args[1],
         None
     )
-    print("mm weight transposed:", weight_tensor.layout_tensor.transposed)
     weight_tensor = weight_tensor.dequantize()
     return aten.mm(input_tensor, weight_tensor)
 
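For context, the override edited above is the subclass's fallback path for aten.mm: it unpacks the input and the quantized weight from args, dequantizes the weight, and defers to a plain matmul. A minimal sketch of that pattern, assuming only a weight that exposes dequantize(); mm_fallback is a hypothetical stand-in name, not the registered override itself:

import torch

# Hedged sketch: dequantize the (possibly subclassed) weight, then run
# an ordinary matmul, mirroring the fallback in the hunk above.
def mm_fallback(input_tensor: torch.Tensor, weight_tensor: torch.Tensor) -> torch.Tensor:
    if hasattr(weight_tensor, "dequantize"):
        weight_tensor = weight_tensor.dequantize()  # back to a plain float tensor
    return torch.mm(input_tensor, weight_tensor)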
@@ -127,13 +126,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     # Row-wise is w.r.t. A^T, so for A it is column-wise.
     # Number of local columns per rank
     orig_weight = m.linear.weight
-    print("rowwise original:", orig_weight.shape)
     n_local_cols = orig_weight.size(1) // mesh.size()
     rank = mesh.get_local_rank()
-    print("rowwise n_local_cols:", n_local_cols)
     local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
-    # BUG: `local_shard` has the same shape as the original tensor
-    print("rowwise local shard:", local_shard.shape)
     # Construct DTensor from local shard
     dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
     # Replace parameter in module
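The slice above is the core of the row-wise shard: nn.Linear computes x @ W^T, so taking rows of A^T = W^T means taking columns of W, i.e. splitting dim 1 across ranks. A self-contained sketch of just that indexing, with a hypothetical 2-rank mesh:

import torch

world_size, rank = 2, 0  # hypothetical; the script reads these from the mesh
W = torch.randn(8, 6)    # (out_features, in_features)
n_local_cols = W.size(1) // world_size
local_shard = W[:, rank * n_local_cols : (rank + 1) * n_local_cols]
assert local_shard.shape == (8, 3)  # a column slice of W == a row slice of W^T

Wrapping local_shard with DTensor.from_local(..., [Shard(1)]) then records dim 1 as the sharded dimension, matching the slice.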
@@ -156,9 +151,9 @@ def main():
     y = proj_dn(proj_up(example_input))
 
     # Quantize the model
-    q_up = quantize(proj_up)
-    q_dn = quantize(proj_dn)
-    y_q = q_dn(q_up(example_input))
+    up_quant = quantize(proj_up)
+    dn_quant = quantize(proj_dn)
+    y_q = dn_quant(up_quant(example_input))
     print("Quantization works!")
 
     # Create a device mesh
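quantize itself is defined earlier in the file and not shown in this diff; judging from the mm override above, it swaps each linear weight for the quantized tensor subclass. A hedged sketch of that wiring, with the subclass constructor passed in as a stand-in since the real constructor isn't visible here:

import torch

def quantize_sketch(m: torch.nn.Module, to_quantized) -> torch.nn.Module:
    # Swap the float weight for its quantized counterpart, keeping it a
    # Parameter so colwise_shard/rowwise_shard can still reach m.linear.weight.
    q_weight = to_quantized(m.linear.weight)
    m.linear.weight = torch.nn.Parameter(q_weight, requires_grad=False)
    return m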
@@ -168,26 +163,22 @@ def main():
     mesh = dist.init_device_mesh("cuda", (world_size,))
 
     # Shard the models
-    d_up = colwise_shard(q_up, mesh)
-    print("d_up weight shape:", d_up.linear.weight.shape)
-    d_dn = rowwise_shard(q_dn, mesh)
+    up_dist = colwise_shard(up_quant, mesh)
+    dn_dist = rowwise_shard(dn_quant, mesh)
 
     # We need to turn inputs into DTensor form as well -- just a format change
     input_dtensor = DTensor.from_local(
         example_input, mesh, [Replicate()]
     )
 
-    y_colwise = d_up(input_dtensor)
-    print("y_colwise:", y_colwise.shape)
-    print("result:", d_dn(y_colwise))
+    y_d = dn_dist(up_dist(input_dtensor))
+    print("Distributed result:", y_d)
     print("Distributed works!")
 
-    c_up = torch.compile(d_up)
-    y_up = c_up(input_dtensor)
-    print("y_up:", y_up.shape)
-    c_dn = torch.compile(d_dn)
-    y_dn = c_dn(y_up)
-    print("y_dn:", y_dn.shape)
+    up_compiled = torch.compile(up_dist)
+    y_up = up_compiled(input_dtensor)
+    dn_compiled = torch.compile(dn_dist)
+    y_dn = dn_compiled(y_up)
     print("compiled result:", y_dn)
     print("torch.compile works!")
 
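Taken together, the renamed pipeline reads: quantize, shard (column-wise then row-wise), replicate the input as a DTensor, and optionally torch.compile the sharded modules. A condensed sketch of the input-side conversion, assuming a torchrun launch with one GPU per rank (the DTensor import path varies across PyTorch versions):

import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate

# Assumes e.g. `torchrun --nproc_per_node=2 script.py`.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = dist.init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

torch.manual_seed(0)  # every rank generates the same input
example_input = torch.randn(4, 16, device="cuda")
# Replicate() records that each rank holds the full tensor; from_local
# only attaches mesh/placement metadata, no communication happens here.
input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()])

torch.compile is then applied directly to the sharded modules, as in the hunk above: DTensor inputs trace through like ordinary tensors.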