Commit 397fb27

shintaro-iwasaki authored and pytorchmergebot committed
[DTensor] Fix DeviceMesh (pytorch#96861)
Summary:
This diff fixes some DeviceMesh issues that block internal DTensor integration. Specifically, when `self.mesh = [2, 3]` while `world_size = 4`, `unique_mesh_values[-1] == 3`, so the constructor takes the first shortcut branch and uses `default_pg` even though the mesh does not cover every rank. Let's check the length of `unique_mesh_values` instead of its last value.

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D44079872

Pull Request resolved: pytorch#96861
Approved by: https://github.com/wanchaol
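The case described in the summary can be reproduced with a minimal sketch in plain Python, without torch. The helper names `old_check` and `new_check` are hypothetical, the `self.mesh.ndim == 1` half of the condition is left out, and `unique_mesh_values` is assumed to be the sorted list of distinct ranks in the mesh. The two expressions are copied from the removed and added lines of the diff.

```python
def old_check(unique_mesh_values, world_size):
    # Removed condition: inspects only the largest rank in the mesh.
    return unique_mesh_values[-1] == world_size - 1


def new_check(unique_mesh_values, world_size):
    # Added condition: inspects how many distinct ranks the mesh contains.
    return len(unique_mesh_values) == world_size - 1


# Scenario from the commit summary: a 1-D mesh over ranks 2 and 3 in a world of 4.
mesh = [2, 3]
world_size = 4

# Assumption: stand-in for however DeviceMesh derives unique_mesh_values.
unique_mesh_values = sorted(set(mesh))

old_result = old_check(unique_mesh_values, world_size)
new_result = new_check(unique_mesh_values, world_size)

print("old condition takes the default_pg shortcut:", old_result)
print("new condition takes the default_pg shortcut:", new_result)
```

With the old condition the partial mesh `[2, 3]` is treated like the whole world because its last value happens to equal `world_size - 1`. The length-based condition does not take the shortcut for this mesh.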
1 parent 6718e3c commit 397fb27

1 file changed: +1 -1 lines changed


torch/distributed/_tensor/device_mesh.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ def __init__(
             self._dim_groups = dim_groups
             return
 
-        if self.mesh.ndim == 1 and unique_mesh_values[-1] == world_size - 1:
+        if self.mesh.ndim == 1 and len(unique_mesh_values) == world_size - 1:
             # if the mesh is the same as world_pg, we just append the default
             # pg to the first dim goups, as new_group cannot have the exact
             # same ranks as world
