@@ -1,8 +1,9 @@
 import os
 import torch
 import torch.distributed as dist
+from typing import Sequence
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
 
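For context, a hedged sketch (not part of this diff) of how the DeviceMesh and placements consumed below are typically constructed; the backend, device type, and 1-D mesh shape are assumptions:

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard

    dist.init_process_group()  # backend/world size taken from torchrun env vars
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    colwise_placements = [Shard(0)]  # shard dim 0 of the weight across the mesh
    rowwise_placements = [Shard(1)]  # shard dim 1 of the weight across the mesh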
@@ -101,18 +102,40 @@ def quantize(m: torch.nn.Module) -> torch.nn.Module:
     )
     return m
 
+def shard(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[Placement],
+) -> DTensor:
+    """
+    Shard a full tensor into a DTensor according to the indicated placements.
+    This helper simplifies both colwise_shard and rowwise_shard; the goal is
+    to eventually move it into DTensor as a static method, e.g.
+        dtensor = DTensor.shard(full_tensor, device_mesh, placements)
+    """
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
+
+    # Compute this rank's local shard shape and its offset into the full tensor
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    # Build per-dim slices selecting this rank's shard out of the full tensor
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return DTensor.from_local(local_tensor, device_mesh, placements)
+
 def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     """
     Shard linear layer of the model in column-wise fashion
     """
     # Column-wise is wrt to A^T, so for A it is row-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_rows = orig_weight.size(0) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+    dtensor = shard(orig_weight, mesh, [Shard(0)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
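The comments in these two hunks define "column-wise" with respect to A^T, since nn.Linear computes y = x @ A.T + b. A small illustrative check (the shapes are arbitrary assumptions) of why that makes Shard(0) the column-wise placement and Shard(1) the row-wise one:

    import torch

    # Chunking the columns of A.T equals chunking the rows of A, so
    # column-wise TP shards A on dim 0; symmetrically, chunking the rows
    # of A.T equals chunking the columns of A, so row-wise TP uses dim 1.
    A = torch.randn(6, 4)             # (out_features, in_features)
    cols_of_AT = A.T.chunk(2, dim=1)  # column-wise split of A.T
    rows_of_A = A.chunk(2, dim=0)     # row-wise split of A
    assert torch.equal(cols_of_AT[0], rows_of_A[0].T)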
@@ -124,13 +147,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     Shard linear layer of the model in row-wise fashion
     """
     # Row-wise is wrt to A^T, so for A it is column-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_cols = orig_weight.size(1) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+    dtensor = shard(orig_weight, mesh, [Shard(1)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
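As a hedged sanity check that the refactor is behavior-preserving (the 2-rank world and the (16, 8) weight shape below are illustrative assumptions): on a 1-D mesh, compute_local_shape_and_global_offset reproduces exactly the per-rank slice that the removed code computed by hand:

    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

    # Assuming world_size == 2, `mesh` as above, and [Shard(0)] placements:
    full = torch.randn(16, 8)
    shape, offset = compute_local_shape_and_global_offset(
        full.shape, mesh, [Shard(0)]
    )
    # rank 0 -> shape == (8, 8), offset == (0, 0)
    # rank 1 -> shape == (8, 8), offset == (8, 0)
    # i.e. the same rows as the removed manual slice:
    #     orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
    local = full[tuple(slice(o, o + s) for s, o in zip(shape, offset))]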