
Commit 9394b93

Use DTensor.from instead of distribute_tensor
kwen2501 authored and jerryzh168 committed
1 parent 8e4a0fc commit 9394b93
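
In short, the tutorial's test section stops handing the full quantized weight to distribute_tensor and instead has each rank slice out its own rows and wrap that local shard with DTensor.from_local. The sketch below is not part of the commit; it only contrasts the two call patterns on a plain float tensor (names such as `w` are illustrative), assuming the script is launched with torchrun so RANK and WORLD_SIZE are set and each rank owns a CUDA device.

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DTensor, Shard, distribute_tensor

    dist.init_process_group(backend="nccl")
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    torch.cuda.set_device(rank)  # single-node assumption: global rank == local device index
    mesh = dist.init_device_mesh("cuda", (world_size,))

    # Same seed on every rank, so every rank materializes the same full tensor.
    torch.manual_seed(5)
    w = torch.randn(1024, 1024, device="cuda")  # illustrative stand-in for the quantized weight

    # Old pattern: hand distribute_tensor the full tensor and let DTensor shard dim 0.
    w_distributed = distribute_tensor(w, mesh, [Shard(0)])

    # New pattern: slice this rank's rows by hand, then wrap the local shard.
    n_local_rows = w.size(0) // world_size
    w_local = w[rank * n_local_rows : (rank + 1) * n_local_rows, :]
    w_from_local = DTensor.from_local(w_local, mesh, [Shard(0)])

    print(rank, w_distributed.placements, w_from_local.placements)
    dist.destroy_process_group()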

1 file changed: +32 -17 lines changed


tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 32 additions & 17 deletions
@@ -41,32 +41,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 # Test #
 ########
 if __name__ == "__main__":
+    # To make sure different ranks create the same module
+    torch.manual_seed(5)
+
     m = M()
-    example_inputs = (100 * torch.randn(128, 1024),)
-    m(*example_inputs)
+    example_input = 100 * torch.randn(128, 1024)
+    m(example_input)


     import os
-    import torch
-    from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor
+    from torch.distributed._tensor import DTensor, Replicate, Shard
     import torch.distributed as dist
+
     # initialize a fake process group
-    store = torch.testing._internal.distributed.fake_pg.FakeStore()
-    dist.init_process_group(
-        backend="fake",
-        world_size=2,
-        rank=0,
-        store=store,
-    )
-    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
+    world_size = int(os.environ["WORLD_SIZE"])
+    rank = int(os.environ["RANK"])
+    dist.init_process_group(backend="nccl")
+    mesh = dist.init_device_mesh("cuda", (world_size,))
+
     # Shard this tensor over the mesh by sharding `big_tensor`'s 0th dimension over the 0th dimension of `mesh`.
-    quantized_weight = to_my_dtype_tp(m.linear.weight)
+    orig_weight = m.linear.weight
+    quantized_weight = to_my_dtype_tp(orig_weight)
     print("quantized weight:", quantized_weight)
-    quantized_weight_dtensor = distribute_tensor(quantized_weight, mesh, [Shard(dim=0)])
-    print("quantized weight dtensor:", quantized_weight_dtensor)
+    # Number of rows per rank
+    n_local_rows = orig_weight.size(0) // world_size
+    # TODO: add support for aten.slice.Tensor
+    quantized_shard = quantized_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
+    print("quantized shard:", quantized_shard)
+    # Construct DTensor from local shard
+    quantized_dtensor = DTensor.from_local(quantized_shard, mesh, [Shard(0)])
+    print("quantized dtensor:", quantized_dtensor)

+    # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
-        quantized_weight_dtensor, requires_grad=False
+        quantized_dtensor, requires_grad=False
     )

-    m(*example_inputs)
+    # We need to turn inputs into DTensor form as well -- just a format change
+    input_dtensor = DTensor.from_local(
+        example_input, mesh, [Replicate()]
+    )
+    print("input dtensor:", input_dtensor)
+
+    m(input_dtensor)
+
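
With the fake process group gone, the reworked test reads WORLD_SIZE and RANK from the environment and initializes NCCL, so it presumably has to be run under a distributed launcher with one CUDA device per rank (e.g. torchrun; the exact command in the comment below is an assumption, not taken from the commit). The following end-to-end sketch mirrors the new pattern on a plain nn.Linear instead of the tutorial's quantized subclass, with illustrative names and sizes: seed identically on every rank, row-shard the weight via DTensor.from_local, replicate the input, and run forward.

    # Assumed launch: torchrun --nproc_per_node=2 <this_script>.py
    import os
    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DTensor, Replicate, Shard

    dist.init_process_group(backend="nccl")
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    torch.cuda.set_device(rank)  # single-node assumption
    mesh = dist.init_device_mesh("cuda", (world_size,))

    torch.manual_seed(5)  # same seed -> identical weights on every rank
    m = torch.nn.Linear(1024, 1024, bias=False, device="cuda")

    # Row-shard the weight: each rank keeps its own block of output rows.
    weight = m.weight.detach()
    n_local_rows = weight.size(0) // world_size
    local_shard = weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
    weight_dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
    m.weight = torch.nn.Parameter(weight_dtensor, requires_grad=False)

    # Inputs become DTensors too -- replicated, since every rank sees the full batch.
    x = torch.randn(128, 1024, device="cuda")
    x_dtensor = DTensor.from_local(x, mesh, [Replicate()])

    y = m(x_dtensor)  # DTensor output, sharded along the feature (last) dimension
    print(rank, y.shape, y.placements)
    dist.destroy_process_group()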

0 commit comments
