Commit 859ddce

fake tensor didn't pick up shape changes from transpose
1 parent dc217fe commit 859ddce
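
For context: torch.compile traces tensor subclasses under FakeTensorMode, which propagates only metadata such as shape and dtype; a transpose that does not update the wrapper's stored shape is invisible to the trace. A minimal sketch with a plain tensor (not the tutorial's subclass) of the shape propagation a transpose handler must reproduce:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Fake tensors carry only metadata, so t() is still expected to flip the
# reported shape even though no real data moves.
with FakeTensorMode():
    x = torch.empty(128, 1024)   # FakeTensor(..., size=(128, 1024))
    print(x.t().shape)           # torch.Size([1024, 128])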


tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 12 additions & 4 deletions
@@ -57,8 +57,10 @@ def _(func, types, args, kwargs):
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     tensor = args[0]
+    print("before transpose, ", tensor.shape)
     shape = tensor.shape[::-1]
     new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
+    print("after transpose:", new.shape)
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 @implements(aten.addmm.default)
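
The handler above rebuilds the wrapper with shape[::-1] because the subclass stores its shape explicitly; transposing only the inner layout_tensor would leave the reported shape stale, which is the bug named in the commit title. A minimal sketch of the same pattern, with a hypothetical Wrapper class standing in for the tutorial's subclass:

import torch

# Hypothetical minimal wrapper subclass: the wrapper's shape is stored
# explicitly at construction, so a transpose must rebuild the wrapper
# with the reversed shape rather than rely on the inner tensor's metadata.
class Wrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, inner, shape, dtype):
        return torch.Tensor._make_wrapper_subclass(cls, shape, dtype=dtype)

    def __init__(self, inner, shape, dtype):
        self.inner = inner

w = Wrapper(torch.randn(2, 3), torch.Size([2, 3]), torch.float32)
wt = Wrapper(w.inner.t(), w.shape[::-1], w.dtype)  # shape[::-1] mirrors the handler above
print(wt.shape)  # torch.Size([3, 2])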
@@ -78,6 +80,8 @@ def _(func, types, args, kwargs):
         args[1],
         None
     )
+    print("input tensor shape:", input_tensor.shape)
+    print("weight tensor shape:", weight_tensor.shape)
     weight_tensor = weight_tensor.dequantize()
     return aten.mm(input_tensor, weight_tensor)
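
The two prints added here expose the invariant the mm fallback needs: the input's last dim must equal the dequantized weight's first dim. A hedged sketch of the same fallback pattern, using torch's built-in per-tensor quantization as a stand-in for the tutorial's subclass:

import torch

# Fallback pattern: no custom kernel for the subclass, so dequantize to a
# dense tensor and reuse the stock matmul.
def addmm_via_dequant(input_tensor, weight_tensor):
    dense = weight_tensor.dequantize()
    assert input_tensor.shape[-1] == dense.shape[0], "reduction dims must match"
    return torch.mm(input_tensor, dense)

x = torch.randn(128, 1024)
qw = torch.quantize_per_tensor(torch.randn(1024, 1024), scale=0.1, zero_point=0, dtype=torch.qint8)
print(addmm_via_dequant(x, qw).shape)   # torch.Size([128, 1024])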

@@ -184,9 +188,13 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 # [rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1,
 # 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(128, 1024)), shape=torch.Size([1024, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
 # [rank0]: a and b must have same reduction dim, but got [128, 1024] X [128, 1024].
-# c_up = torch.compile(d_up)
-# c_dn = torch.compile(d_dn)
-# print("compiled result:", c_dn(c_up(input_dtensor)))
-# print("torch.compile works!")
+c_up = torch.compile(d_up)
+y_up = c_up(input_dtensor)
+print("y_up:", y_up.shape)
+c_dn = torch.compile(d_dn)
+y_dn = c_dn(y_up)
+print("y_dn:", y_dn.shape)
+print("compiled result:", y_dn)
+print("torch.compile works!")
 
 dist.destroy_process_group()
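
For reference, the error quoted above is the mm shape check failing: before this fix, the subclass's transpose under fake tracing kept reporting [128, 1024] for the weight, so both operands appeared to share the same first dim. The invariant, shown with plain tensors:

import torch

a = torch.randn(128, 1024)
b = torch.randn(128, 1024)
# torch.mm(a, b) fails: a's reduction dim (1024) must equal b.shape[0] (128).
out = torch.mm(a, b.t())   # [128, 1024] x [1024, 128] -> [128, 128]
print(out.shape)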
