@@ -57,10 +57,8 @@ def _(func, types, args, kwargs):
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     tensor = args[0]
-    print("before transpose, ", tensor.shape)
     shape = tensor.shape[::-1]
     new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
-    print("after transpose:", new.shape)
     return return_and_correct_aliasing(func, args, kwargs, new)

 @implements(aten.addmm.default)
@@ -80,8 +78,7 @@ def _(func, types, args, kwargs):
         args[1],
         None
     )
-    print("mm input tensor shape:", input_tensor.shape)
-    print("mm weight tensor shape:", weight_tensor.shape)
+    print("mm weight transposed:", weight_tensor.layout_tensor.transposed)
     weight_tensor = weight_tensor.dequantize()
     return aten.mm(input_tensor, weight_tensor)

@@ -172,6 +169,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:

 # Shard the models
 d_up = colwise_shard(q_up, mesh)
+print("d_up weight shape:", d_up.linear.weight.shape)
 d_dn = rowwise_shard(q_dn, mesh)

 # We need to turn inputs into DTensor form as well -- just a format change
@@ -188,10 +186,10 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 # [rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1,
 # 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(128, 1024)), shape=torch.Size([1024, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
 # [rank0]: a and b must have same reduction dim, but got [128, 1024] X [128, 1024].
-c_up = torch.compile(d_up, backend="eager")
+c_up = torch.compile(d_up)
 y_up = c_up(input_dtensor)
 print("y_up:", y_up.shape)
-c_dn = torch.compile(d_dn, backend="eager")
+c_dn = torch.compile(d_dn)
 y_dn = c_dn(y_up)
 print("y_dn:", y_dn.shape)
 print("compiled result:", y_dn)
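
Note: below is a minimal, single-process sketch of the sharding math that the colwise_shard -> rowwise_shard pipeline in the diff relies on. It is illustrative only -- no DeviceMesh, DTensor, or quantized tensor subclass is involved; the world size, tensor shapes, and manual chunking are assumptions chosen to mirror the (128, 1024) activations in the error comment above, not the tutorial's actual API.

import torch

# Assumed sizes: 4 simulated ranks, (128, 1024) input as in the error log above.
torch.manual_seed(5)
world_size = 4
x = torch.randn(128, 1024)       # replicated input
w_up = torch.randn(1024, 1024)   # "up" linear weight, laid out (out_features, in_features)
w_dn = torch.randn(1024, 1024)   # "down" linear weight

# Column-wise shard: split w_up along dim 0 (output features), so each
# simulated rank produces its own slice of the output columns.
y_up_shards = [x @ w.t() for w in w_up.chunk(world_size, dim=0)]   # each (128, 256)

# Row-wise shard: split w_dn along dim 1 (input features); each simulated
# rank consumes the matching slice of y_up, and the partial products are
# summed -- the role the all-reduce plays in the real DTensor run.
y_dn = sum(y @ w.t() for y, w in zip(y_up_shards, w_dn.chunk(world_size, dim=1)))

# Matches the unsharded computation up to float32 reduction-order error.
ref = (x @ w_up.t()) @ w_dn.t()
print("sharded matches reference:", torch.allclose(y_dn, ref, rtol=1e-4))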