
Commit 01bb3ac

compile still not working yet

1 parent 6ccf610


tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 4 additions & 5 deletions
@@ -177,17 +177,16 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 
 y_colwise = d_up(input_dtensor)
 print("y_colwise:", y_colwise.shape)
-# doesn't work, see BUG in rowwise_shard()
 print("result:", d_dn(y_colwise))
 print("Distributed works!")
 
 # doesn't work
 # [rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1,
 # 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(128, 1024)), shape=torch.Size([1024, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
 # [rank0]: a and b must have same reduction dim, but got [128, 1024] X [128, 1024].
-c_up = torch.compile(d_up)
-c_dn = torch.compile(d_dn)
-print("compiled result:", c_dn(c_up(input_dtensor)))
-print("torch.compile works!")
+# c_up = torch.compile(d_up)
+# c_dn = torch.compile(d_dn)
+# print("compiled result:", c_dn(c_up(input_dtensor)))
+# print("torch.compile works!")
 
 dist.destroy_process_group()
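
For context, below is a minimal sketch of the colwise/rowwise sharding pattern the changed hunk exercises (the hunk header shows it sits just below rowwise_shard). The helper bodies, the .linear attribute, and the mesh handling are assumptions based on the surrounding tutorial, not taken from this commit; in the real file the weight is additionally wrapped in the MyDTypeTensorTP subclass, which is what the quoted torch.compile failure involves.

# Hedged sketch (not from this commit): DTensor-based sharding of a
# bias-free nn.Linear weight, assuming the module exposes a .linear
# submodule. The tutorial wraps the weight in MyDTypeTensorTP; a plain
# tensor is used here for brevity, so this alone won't reproduce the bug.
import torch
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import DeviceMesh


def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    # Column-wise parallel linear: shard the weight along dim 0 (output
    # features), so each rank produces a slice of the output columns.
    orig_weight = m.linear.weight
    n_local_rows = orig_weight.size(0) // mesh.size()
    rank = mesh.get_local_rank()
    local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
    m.linear.weight = torch.nn.Parameter(
        DTensor.from_local(local_shard, mesh, [Shard(0)]), requires_grad=False
    )
    return m


def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    # Row-wise parallel linear: shard the weight along dim 1 (input
    # features); the matmul then reduces partial sums across ranks.
    orig_weight = m.linear.weight
    n_local_cols = orig_weight.size(1) // mesh.size()
    rank = mesh.get_local_rank()
    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
    m.linear.weight = torch.nn.Parameter(
        DTensor.from_local(local_shard, mesh, [Shard(1)]), requires_grad=False
    )
    return m

With helpers like these, d_up(input_dtensor) followed by d_dn works in eager mode, which is the path the diff keeps; this commit only comments out the torch.compile path that still fails.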
