@@ -57,8 +57,10 @@ def _(func, types, args, kwargs):
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     tensor = args[0]
+    print("before transpose:", tensor.shape)
     shape = tensor.shape[::-1]
     new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
+    print("after transpose:", new.shape)
     return return_and_correct_aliasing(func, args, kwargs, new)

 @implements(aten.addmm.default)
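For context, a minimal, self-contained sketch of the `__torch_dispatch__` override pattern this hunk is instrumenting. `ShapeLoggingTensor`, the toy `implements` helper, and the shapes below are illustrative stand-ins, not the tutorial's `MyDTypeTensorTP`; the sketch also returns the new wrapper directly instead of going through `return_and_correct_aliasing`, to keep the toy subclass small.

```python
# Illustrative stand-in (NOT the tutorial's MyDTypeTensorTP): a tiny wrapper
# subclass whose __torch_dispatch__ routes aten.t.default through a lookup
# table, with the same before/after shape prints added in the hunk above.
import torch

aten = torch.ops.aten
OVERRIDES = {}  # maps aten overloads to handler functions


def implements(op):
    """Register a handler for a specific aten overload (toy version)."""
    def register(fn):
        OVERRIDES[op] = fn
        return fn
    return register


class ShapeLoggingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data):
        # Wrapper subclass: outer metadata mirrors the inner plain tensor.
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    def __init__(self, data):
        self.inner = data  # plain tensor actually holding the values

    def __repr__(self):
        return f"ShapeLoggingTensor(shape={tuple(self.shape)})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in OVERRIDES:
            return OVERRIDES[func](func, types, args, kwargs)
        raise NotImplementedError(f"no override for {func}")


@implements(aten.t.default)
def _(func, types, args, kwargs):
    tensor = args[0]
    print("before transpose:", tensor.shape)
    # Toy version returns the new wrapper directly; the tutorial additionally
    # routes it through return_and_correct_aliasing (see the hunk above).
    new = tensor.__class__(tensor.inner.t())
    print("after transpose:", new.shape)
    return new


x = ShapeLoggingTensor(torch.randn(128, 1024))
print(x.t())  # prints the shape before and after the transpose
```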
@@ -78,6 +80,8 @@ def _(func, types, args, kwargs):
         args[1],
         None
     )
+    print("input tensor shape:", input_tensor.shape)
+    print("weight tensor shape:", weight_tensor.shape)
     weight_tensor = weight_tensor.dequantize()
     return aten.mm(input_tensor, weight_tensor)

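The two added prints surface the operand shapes feeding `aten.mm`. As a plain-tensor illustration of the reduction-dim rule behind the error quoted in the next hunk (the `weight_shard` name and the shapes are copied from that message, not from running the tutorial):

```python
# Plain-tensor illustration only; the (128, 1024) shapes are taken from the
# error message quoted in the next hunk, not from running the tutorial.
import torch

input_tensor = torch.randn(128, 1024)
weight_shard = torch.randn(128, 1024)
print("input tensor shape:", input_tensor.shape)
print("weight tensor shape:", weight_shard.shape)

# torch.mm needs input_tensor.shape[1] == other.shape[0] (a shared reduction
# dim), so these two only multiply once one of them is transposed.
out = torch.mm(input_tensor, weight_shard.t())
print("out shape:", out.shape)  # torch.Size([128, 128])
```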
@@ -184,9 +188,13 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 # [rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1,
 # 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(128, 1024)), shape=torch.Size([1024, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
 # [rank0]: a and b must have same reduction dim, but got [128, 1024] X [128, 1024].
-# c_up = torch.compile(d_up)
-# c_dn = torch.compile(d_dn)
-# print("compiled result:", c_dn(c_up(input_dtensor)))
-# print("torch.compile works!")
+c_up = torch.compile(d_up)
+y_up = c_up(input_dtensor)
+print("y_up:", y_up.shape)
+c_dn = torch.compile(d_dn)
+y_dn = c_dn(y_up)
+print("y_dn:", y_dn.shape)
+print("compiled result:", y_dn)
+print("torch.compile works!")

 dist.destroy_process_group()
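A DTensor-free stand-in for the compiled path above, just to show the intended call pattern in a single process; `d_up`/`d_dn` are assumed to be the up- and down-projection linears from the tutorial, and the layer sizes here are guesses based on the shapes in the error message, not values taken from the sharded model.

```python
# Single-process stand-in for the compiled up/down path; module names and
# sizes are assumptions, not taken from the tutorial's sharded model.
import torch

d_up = torch.nn.Linear(1024, 1024, bias=False)
d_dn = torch.nn.Linear(1024, 1024, bias=False)
x = torch.randn(128, 1024)

c_up = torch.compile(d_up)
y_up = c_up(x)
print("y_up:", y_up.shape)  # torch.Size([128, 1024])
c_dn = torch.compile(d_dn)
y_dn = c_dn(y_up)
print("y_dn:", y_dn.shape)  # torch.Size([128, 1024])
print("compiled result:", y_dn)
```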