
Commit 3dfc799

Added support for PerRow granularity
1 parent 26d84b5 commit 3dfc799
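
For context, the headline change is that float8 dynamic quantization can now be configured with per-row weight scales. Below is a minimal sketch (not part of the commit) of how the new granularity might be used; it mirrors the imports and kwargs exercised in the test diff further down, and assumes an H100-class GPU since float8 kernels require it. The tiny module here is illustrative only.

import torch
from torchao.quantization import float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow
from torchao.quantization.quant_api import quantize_

# Illustrative module; the tests below wrap a bias-free Linear the same way.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1024, 2048, bias=False)

    def forward(self, x):
        return self.linear(x)

m = M().to(device="cuda", dtype=torch.bfloat16)
# PerRow() keeps one float8 scale per weight row instead of a single per-tensor scale.
quantize_(m, float8_dynamic_activation_float8_weight(granularity=PerRow()))
y = m(torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16))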

4 files changed, 145 additions and 40 deletions

Lines changed: 132 additions & 6 deletions
@@ -1,8 +1,20 @@
 import torch
+import unittest
 from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
 from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal import common_utils
 from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight
 from torchao.quantization.observer import PerRow, PerTensor
+import torch.distributed as dist
+from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+    NUM_DEVICES,
+)
+from torchao.quantization.quant_api import quantize_
+from torchao.dtypes import AffineQuantizedTensor
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
     QUANT_METHOD_FN = staticmethod(int8_weight_only)
@@ -16,17 +28,131 @@ class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):

 # Run only on H100
 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
-    class TestFloat8dqTensorAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+    class TestFloat8dqAffineQuantizedTensorParallel(DTensorTestBase):
+        """Basic test case for tensor subclasses
+        """
+        COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
+        TENSOR_SUBCLASS = AffineQuantizedTensor
+        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
+        QUANT_METHOD_KWARGS = {}
+
+        @staticmethod
+        def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+            """
+            Shard linear layer of the model in column-wise fashion
+            """
+            # Column-wise is wrt to A^T, so for A it is row-wise.
+            # Number of rows per rank
+            orig_weight = m.linear.weight
+            n_local_rows = orig_weight.size(0) // mesh.size()
+            rank = mesh.get_local_rank()
+            local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
+            # Construct DTensor from local shard
+            dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+            # Replace parameter in module
+            m.linear.weight = torch.nn.Parameter(
+                dtensor, requires_grad=False
+            )
+            return m
+
+        @staticmethod
+        def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+            """
+            Shard linear layer of the model in row-wise fashion
+            """
+            # Row-wise is wrt to A^T, so for A it is column-wise.
+            # Number of rows per rank
+            orig_weight = m.linear.weight
+            n_local_cols = orig_weight.size(1) // mesh.size()
+            rank = mesh.get_local_rank()
+            local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
+            # Construct DTensor from local shard
+            dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
+            # Replace parameter in module
+            m.linear.weight = torch.nn.Parameter(
+                dtensor, requires_grad=False
+            )
+            return m
+
+        def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
+            """
+            Quantize the model
+            """
+            quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
+            return m
+
+        def _test_tp(self, dtype):
+            device = "cuda"
+            # To make sure different ranks create the same module
+            torch.manual_seed(5)
+
+            class M(torch.nn.Module):
+                def __init__(self, in_features, out_features, **kwargs) -> None:
+                    super().__init__(**kwargs)
+                    self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    return self.linear(x)
+
+            # Get rank and device
+            device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
+
+            # Original model
+            proj_up = M(1024, 2048).to(device).to(dtype)
+            proj_dn = M(2048, 1024).to(device).to(dtype)
+            example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
+            y = proj_dn(proj_up(example_input))
+            # Quantize the model
+            up_quant = self.quantize(proj_up)
+            dn_quant = self.quantize(proj_dn)
+            y_q = dn_quant(up_quant(example_input))
+
+            mesh = self.build_device_mesh()
+            mesh.device_type = "cuda"
+
+            # Shard the models
+            up_dist = self.colwise_shard(up_quant, mesh)
+            dn_dist = self.rowwise_shard(dn_quant, mesh)
+
+            # We need to turn inputs into DTensor form as well -- just a format change
+            input_dtensor = DTensor.from_local(
+                example_input, mesh, [Replicate()]
+            )
+
+            y_d = dn_dist(up_dist(input_dtensor))
+
+            if not TORCH_VERSION_AT_LEAST_2_5:
+                # Need torch 2.5 to support compiled tensor parallelism
+                return
+
+            up_compiled = torch.compile(up_dist)
+            y_up = up_compiled(input_dtensor)
+            dn_compiled = torch.compile(dn_dist)
+            y_dn = dn_compiled(y_up)
+
+    class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
         QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
         QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
-    copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqTensorAffineQuantizedTensorParallel, "fp8dqt_tp")
+        COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]

-# Run only on H100
-if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
-    class TestFloat8dqRowAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+        @common_utils.parametrize("dtype", COMMON_DTYPES)
+        @with_comms
+        @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+        def test_tp(self, dtype):
+            return self._test_tp(dtype)
+
+    class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
         QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
         QUANT_METHOD_KWARGS = {"granularity": PerRow()}
-    copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqRowAffineQuantizedTensorParallel, "fp8dqr_tp")
+        COMMON_DTYPES = [torch.bfloat16]

+        @common_utils.parametrize("dtype", COMMON_DTYPES)
+        @with_comms
+        @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+        def test_tp(self, dtype):
+            return self._test_tp(dtype)
+
+common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel)
+common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel)
 if __name__ == "__main__":
     run_tests()
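
A side note on the sharding scheme used in the test above: colwise_shard splits proj_up's weight along dim 0 (its output features, Shard(0)) and rowwise_shard splits proj_dn's weight along dim 1 (its input features, Shard(1)), so each rank's column-sharded activation feeds its own row shard and the partial outputs only need a final reduction. A small CPU-only sketch of that identity, assuming 2 ranks and the 1024/2048 sizes from the test; plain float tensors stand in for the quantized DTensors:

import torch

world_size = 2
x = torch.randn(128, 1024)
w_up = torch.randn(2048, 1024)  # proj_up.linear.weight, sharded with Shard(0) above
w_dn = torch.randn(1024, 2048)  # proj_dn.linear.weight, sharded with Shard(1) above

# Rank-local shards: rows of w_up, and the matching columns of w_dn.
w_up_shards = w_up.chunk(world_size, dim=0)
w_dn_shards = w_dn.chunk(world_size, dim=1)

# Each rank computes a partial result; summing the partials plays the role of the all-reduce.
partials = [(x @ w_up_shards[r].T) @ w_dn_shards[r].T for r in range(world_size)]
y_tp = sum(partials)

y_ref = (x @ w_up.T) @ w_dn.T
torch.testing.assert_close(y_tp, y_ref, rtol=1e-4, atol=1e-2)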

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 10 additions & 22 deletions
@@ -1062,9 +1062,12 @@ def __init__(

     def _apply_fn_to_data(self, fn):
         """ Applys a fn to all tensor components stored on this class"""
-        fn(self.float8_data)
-        fn(self.scale)
-        return self
+        return self.__class__(
+            fn(self.float8_data),
+            fn(self.scale),
+            self.transposed,
+            self._layout,
+        )

     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -1109,19 +1112,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             if dim == 0:
                 #TODO: scale replecation should be dependent on block size
                 if self.scale.ndim == 1:
-                    print("slice for dim 0, scale is 1")
                     return return_and_correct_aliasing(
                         func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
                     )
-                else:
-                    print("slice for dim 0, scale != 1")
+                elif self.scale.ndim == 0:
                     return return_and_correct_aliasing(
                         func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
                     )
+                else:
+                    raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported")
             elif dim == 1:
-                print("slice for dim 1")
                 return return_and_correct_aliasing(
-                    func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
+                    func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(), self.scale, None, self._layout)
                 )
             else:
                 raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported")
@@ -1653,15 +1655,6 @@ def _linear_fp8_act_fp8_weight_impl(

     # Preprocess data
     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
-
-    print(f"out_shape: {out_shape}")
-    print(f"input_tensor: {input_tensor.shape}, weight_tensor: {weight_tensor.shape}")
-    print(f"inpt_data: {inpt_data.shape}, w_data: {w_data.shape}")
-
-
-    print(f"out_shape: {out_shape}")
-    print(f"input_tensor: {input_tensor.shape}, weight_tensor: {weight_tensor.shape}")
-    print(f"inpt_data: {inpt_data.shape}, w_data: {w_data.shape}")

     # Perform the computation
     return addmm_float8_unwrapped_inference(
@@ -1877,17 +1870,12 @@ def _(func, types, args, kwargs):
         end = self.shape[dim]
     shape = list(self.shape)
     shape[dim] = end - start
-    print(f"Shape: {self.shape} -> {shape}")
-    print(f"Block size: {self.block_size} -> {self.block_size}")
-    print(f"end: {end}, start: {start}")
     block_size = self.block_size
     assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}"
     # with slice, some shape dimension might be smaller than block_size dimension, so
     # we need to make sure there is no overflow
     block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1]))
     new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())
-    print(f"slice (Outer tensor shape): {self.shape} -> {new.shape}")
-    print(f"slice (Inner data shape): {self.tensor_impl.float8_data.shape} -> {new.tensor_impl.float8_data.shape}")
     return return_and_correct_aliasing(func, args, kwargs, new)

 # this is needed for DTensor.from_local() and for flattening tensor
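
The dim-0 slice branch above distinguishes the two granularities by scale.ndim: PerRow produces a 1-D scale with one entry per row, so _apply_fn_to_data slices the scale together with the float8 data, while a PerTensor 0-D scale passes through unchanged. A rough stand-in illustration of that rule with plain tensors (names and shapes are illustrative only, not torchao API):

import torch

data = torch.randn(8, 4)              # stands in for float8_data
per_row_scale = torch.rand(8)         # PerRow(): scale.ndim == 1, one scale per row
per_tensor_scale = torch.tensor(0.5)  # PerTensor(): scale.ndim == 0, one shared scale

start, end = 0, 4
sliced_data = data[start:end]                # slice along dim 0
sliced_row_scale = per_row_scale[start:end]  # per-row scale must follow the data
sliced_tensor_scale = per_tensor_scale       # per-tensor scale is left as-is

assert sliced_row_scale.shape[0] == sliced_data.shape[0]
assert sliced_tensor_scale.ndim == 0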

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 2 additions & 10 deletions
@@ -124,7 +124,6 @@ def _(func, types, args, kwargs):
         return func(bias, aqt, original_weight_tensor)
     else:
         # aten.mm.default
-        print('Args: ', args[0].shape, args[1].shape, type(args[0]), type(args[1]))
         assert args[0].shape[-1] == args[1].shape[0], (
             f"need mat1 shape: {args[0].shape} final dim"
             f"to match mat2 shape: {args[1].shape} first dim"
@@ -168,24 +167,17 @@ def _(func, types, args, kwargs):

 @implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
-    print('Input quant func: ', args[0].input_quant_func)
-    x = return_and_correct_aliasing(
+    return return_and_correct_aliasing(
         func, args, kwargs, LinearActivationQuantizedTensor(
             func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func)
     )
-    print(f'Linear act Post slice: {x.original_weight_tensor.shape} {x.original_weight_tensor.tensor_impl.float8_data.shape}')
-    return x

 # this is needed for DTensor.from_local() and for flattening tensor
 @implements(aten.view.default)
 def _(func, types, args, kwargs):
-    print('Linear view args:', args[1:])
-    print('Device: ', args[0].original_weight_tensor.device)
-    x= return_and_correct_aliasing(
+    return return_and_correct_aliasing(
         func, args, kwargs, LinearActivationQuantizedTensor(func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func)
     )
-    print(f'Linear act Post view: {x.original_weight_tensor.shape} {x.original_weight_tensor.tensor_impl.float8_data.shape}')
-    return x

 to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float

torchao/testing/utils.py

Lines changed: 1 addition & 2 deletions
@@ -301,14 +301,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         proj_up = M(1024, 2048).to(device).to(dtype)
         proj_dn = M(2048, 1024).to(device).to(dtype)
         example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
-        print('Run y')
         y = proj_dn(proj_up(example_input))

         # Quantize the model
         up_quant = self.quantize(proj_up)
         dn_quant = self.quantize(proj_dn)
         y_q = dn_quant(up_quant(example_input))
-
+
         mesh = self.build_device_mesh()
         mesh.device_type = "cuda"
0 commit comments
