
Commit 2870da5

Clean up
1 parent b44dbe0 commit 2870da5

File tree

1 file changed: +11 -20 lines changed


tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 11 additions & 20 deletions
@@ -78,7 +78,6 @@ def _(func, types, args, kwargs):
         args[1],
         None
     )
-    print("mm weight transposed:", weight_tensor.layout_tensor.transposed)
     weight_tensor = weight_tensor.dequantize()
     return aten.mm(input_tensor, weight_tensor)

@@ -127,13 +126,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     # Row-wise is wrt to A^T, so for A it is column-wise.
     # Number of rows per rank
     orig_weight = m.linear.weight
-    print("rowwise original:", orig_weight.shape)
     n_local_cols = orig_weight.size(1) // mesh.size()
     rank = mesh.get_local_rank()
-    print("rowwise n_local_cols:", n_local_cols)
     local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
-    # BUG: `local_shard` has the same shape as the original tensor
-    print("rowwise local shard:", local_shard.shape)
     # Construct DTensor from local shard
     dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
     # Replace parameter in module
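
Aside on the hunk above: rowwise_shard keeps one contiguous block of weight columns per rank (row-wise with respect to A^T is column-wise for A). A minimal single-process sketch of that slicing arithmetic; the 16x8 weight and 4-rank world size here are made-up values for illustration, not from the tutorial:

import torch

world_size = 4               # hypothetical number of ranks
weight = torch.randn(16, 8)  # hypothetical linear weight
n_local_cols = weight.size(1) // world_size

# Each rank keeps one contiguous block of columns, mirroring rowwise_shard.
shards = [
    weight[:, r * n_local_cols : (r + 1) * n_local_cols]
    for r in range(world_size)
]

# Every shard keeps all rows but 1/world_size of the columns, and
# concatenating the shards along dim 1 recovers the original weight.
assert all(s.shape == (16, 2) for s in shards)
assert torch.equal(torch.cat(shards, dim=1), weight)
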
@@ -156,9 +151,9 @@ def main():
     y = proj_dn(proj_up(example_input))

     # Quantize the model
-    q_up = quantize(proj_up)
-    q_dn = quantize(proj_dn)
-    y_q = q_dn(q_up(example_input))
+    up_quant = quantize(proj_up)
+    dn_quant = quantize(proj_dn)
+    y_q = dn_quant(up_quant(example_input))
     print("Quantization works!")

     # Create a device mesh

@@ -168,26 +163,22 @@ def main():
     mesh = dist.init_device_mesh("cuda", (world_size,))

     # Shard the models
-    d_up = colwise_shard(q_up, mesh)
-    print("d_up weight shape:", d_up.linear.weight.shape)
-    d_dn = rowwise_shard(q_dn, mesh)
+    up_dist = colwise_shard(up_quant, mesh)
+    dn_dist = rowwise_shard(dn_quant, mesh)

     # We need to turn inputs into DTensor form as well -- just a format change
     input_dtensor = DTensor.from_local(
         example_input, mesh, [Replicate()]
     )

-    y_colwise = d_up(input_dtensor)
-    print("y_colwise:", y_colwise.shape)
-    print("result:", d_dn(y_colwise))
+    y_d = dn_dist(up_dist(input_dtensor))
+    print("Distributed result:", y_d)
     print("Distributed works!")

-    c_up = torch.compile(d_up)
-    y_up = c_up(input_dtensor)
-    print("y_up:", y_up.shape)
-    c_dn = torch.compile(d_dn)
-    y_dn = c_dn(y_up)
-    print("y_dn:", y_dn.shape)
+    up_compiled = torch.compile(up_dist)
+    y_up = up_compiled(input_dtensor)
+    dn_compiled = torch.compile(dn_dist)
+    y_dn = dn_compiled(y_up)
     print("compiled result:", y_dn)
     print("torch.compile works!")

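A quick way to sanity-check the sharded path against the single-device quantized path is to gather the DTensor output back into a plain tensor and compare. A sketch, assuming the variables from the diff above, a PyTorch version where DTensor exposes full_tensor(), and tolerances picked arbitrarily for illustration:

# y_q: single-device quantized output; y_d: distributed DTensor output.
y_full = y_d.full_tensor()  # gathers/replicates back to a regular torch.Tensor
torch.testing.assert_close(y_full, y_q, rtol=1e-2, atol=1e-2)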