Skip to content

Commit 1074ecf

Browse files
committed
Fix device id
1 parent 2870da5 commit 1074ecf

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _(func, types, args, kwargs):
8585
class M(torch.nn.Module):
8686
def __init__(self, in_features, out_features, **kwargs) -> None:
8787
super().__init__(**kwargs)
88-
self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
88+
self.linear = torch.nn.Linear(in_features, out_features, bias=False)
8989

9090
def forward(self, x: torch.Tensor) -> torch.Tensor:
9191
return self.linear(x)
@@ -144,10 +144,15 @@ def main():
144144
# To make sure different ranks create the same module
145145
torch.manual_seed(5)
146146

147+
# Get rank and device
148+
world_size = int(os.environ["WORLD_SIZE"])
149+
rank = int(os.environ["RANK"])
150+
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
151+
147152
# Original model
148-
proj_up = M(1024, 2048)
149-
proj_dn = M(2048, 1024)
150-
example_input = 100 * torch.randn(128, 1024, device="cuda")
153+
proj_up = M(1024, 2048).to(device)
154+
proj_dn = M(2048, 1024).to(device)
155+
example_input = 100 * torch.randn(128, 1024, device=device)
151156
y = proj_dn(proj_up(example_input))
152157

153158
# Quantize the model
@@ -157,8 +162,6 @@ def main():
157162
print("Quantization works!")
158163

159164
# Create a device mesh
160-
world_size = int(os.environ["WORLD_SIZE"])
161-
rank = int(os.environ["RANK"])
162165
dist.init_process_group(backend="nccl")
163166
mesh = dist.init_device_mesh("cuda", (world_size,))
164167

0 commit comments

Comments
 (0)