
Commit 6314d88

Float8 tensor parallel for aqt_dynamic_act_weight
1 parent e7b33bc commit 6314d88

4 files changed: 71 additions & 7 deletions

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 16 additions & 1 deletion
@@ -1,7 +1,8 @@
 import torch
 from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
 from torch.testing._internal.common_utils import run_tests
-from torchao.quantization import int8_weight_only, float8_weight_only
+from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight
+from torchao.quantization.observer import PerRow, PerTensor
 
 class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
     QUANT_METHOD_FN = staticmethod(int8_weight_only)
@@ -13,5 +14,19 @@ class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
     QUANT_METHOD_FN = staticmethod(float8_weight_only)
 copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp")
 
+# Run only on H100
+if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
+    class TestFloat8dqTensorAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
+        QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
+    copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqTensorAffineQuantizedTensorParallel, "fp8dqt_tp")
+
+# Run only on H100
+if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
+    class TestFloat8dqRowAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
+        QUANT_METHOD_KWARGS = {"granularity": PerRow()}
+    copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqRowAffineQuantizedTensorParallel, "fp8dqr_tp")
+
 if __name__ == "__main__":
     run_tests()
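
Editor's note: for context, a minimal sketch of how the quantization method these new test classes exercise is applied outside the test harness. It assumes torchao's quantize_ entry point; the module shapes are illustrative and not part of this diff.

import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow

linear = torch.nn.Linear(1024, 2048, dtype=torch.bfloat16, device="cuda")
# PerRow granularity keeps one scale per output row; float8 matmul with
# rowwise scales needs SM 9.0 (H100), hence the capability guard in the test.
quantize_(linear, float8_dynamic_activation_float8_weight(granularity=PerRow()))
x = torch.randn(128, 1024, dtype=torch.bfloat16, device="cuda")
y = linear(x)  # activations are quantized dynamically on each call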

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 24 additions & 4 deletions
@@ -1107,12 +1107,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         elif func is aten.slice.Tensor:
             self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
             if dim == 0:
+                # TODO: scale replication should be dependent on block size
+                if self.scale.ndim == 1:
+                    print("slice for dim 0, scale is 1")
+                    return return_and_correct_aliasing(
+                        func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                    )
+                else:
+                    print("slice for dim 0, scale != 1")
+                    return return_and_correct_aliasing(
+                        func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
+                    )
+            elif dim == 1:
+                print("slice for dim 1")
                 return return_and_correct_aliasing(
-                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                    func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
                 )
-            elif dim == 1:
-                assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
-                return Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
             else:
                 raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported")
         else:
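
Editor's note on what this dispatch change enables (my reading, not text from the commit): tensor-parallel sharding slices the quantized weight along dim 0 (colwise shards) or dim 1 (rowwise shards). A rank-1 scale holds one value per row as the ndim check reads it, so a dim-0 slice can apply the same slice to both float8_data and scale via _apply_fn_to_data; otherwise the scale is carried over unchanged. Roughly, with stand-in names that are not from this commit:

import torch
w8 = torch.randn(2048, 1024).to(torch.float8_e4m3fn)  # stand-in for float8_data
scale = torch.rand(2048)                               # per-row scale, ndim == 1
rows = slice(0, 1024)
shard_data, shard_scale = w8[rows, :], scale[rows]     # dim-0 slice: slice both
cols = slice(0, 512)
shard_data2, shard_scale2 = w8[:, cols], scale         # dim-1 slice: scale unchanged
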
@@ -1644,6 +1654,11 @@ def _linear_fp8_act_fp8_weight_impl(
     # Preprocess data
     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
 
+
+    print(f"out_shape: {out_shape}")
+    print(f"input_tensor: {input_tensor.shape}, weight_tensor: {weight_tensor.shape}")
+    print(f"inpt_data: {inpt_data.shape}, w_data: {w_data.shape}")
+
     # Perform the computation
     return addmm_float8_unwrapped_inference(
         inpt_data,
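
Editor's note: these debug prints trace the shard shapes that flow into the fp8 GEMM. For orientation, a rough sketch of the low-level matmul that addmm_float8_unwrapped_inference ultimately drives, assuming PyTorch's torch._scaled_mm; its keyword signature has shifted across releases, so treat this as approximate rather than as this codebase's API:

import torch

# Hypothetical standalone fp8 GEMM; shapes and unit scales are illustrative.
a = torch.randn(128, 1024, device="cuda").to(torch.float8_e4m3fn)        # activations, row-major
b = torch.randn(2048, 1024, device="cuda").to(torch.float8_e4m3fn).t()   # weight.T, column-major
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")
out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([128, 2048])
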
@@ -1858,12 +1873,17 @@ def _(func, types, args, kwargs):
         end = self.shape[dim]
     shape = list(self.shape)
     shape[dim] = end - start
+    print(f"Shape: {self.shape} -> {shape}")
+    print(f"Block size: {self.block_size} -> {self.block_size}")
+    print(f"end: {end}, start: {start}")
     block_size = self.block_size
     assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}"
     # with slice, some shape dimension might be smaller than block_size dimension, so
     # we need to make sure there is no overflow
     block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1]))
     new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())
+    print(f"slice (Outer tensor shape): {self.shape} -> {new.shape}")
+    print(f"slice (Inner data shape): {self.tensor_impl.float8_data.shape} -> {new.tensor_impl.float8_data.shape}")
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 # this is needed for DTensor.from_local() and for flattening tensor
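
Editor's note: the block_size clamp above is easiest to see with numbers. A worked example, assuming a per-tensor-quantized 2048x1024 weight whose single block spans the whole tensor (values chosen for illustration):

# Per-tensor quantization: one block covers the full (2048, 1024) weight.
shape, block_size = [2048, 1024], (2048, 1024)
# A colwise TP shard takes rows 0:1024, so the sliced shape is (1024, 1024).
shape[0] = 1024
# Without the clamp, block_size[0] == 2048 would overflow the sliced dim 0.
block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1]))
assert block_size == (1024, 1024)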

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 22 additions & 0 deletions
@@ -124,6 +124,7 @@ def _(func, types, args, kwargs):
         return func(bias, aqt, original_weight_tensor)
     else:
         # aten.mm.default
+        print('Args: ', args[0].shape, args[1].shape, type(args[0]), type(args[1]))
         assert args[0].shape[-1] == args[1].shape[0], (
             f"need mat1 shape: {args[0].shape} final dim"
             f"to match mat2 shape: {args[1].shape} first dim"
@@ -165,6 +166,27 @@ def _(func, types, args, kwargs):
         func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
     )
 
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    print('Input quant func: ', args[0].input_quant_func)
+    x = return_and_correct_aliasing(
+        func, args, kwargs, LinearActivationQuantizedTensor(
+            func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func)
+    )
+    print(f'Linear act Post slice: {x.original_weight_tensor.shape} {x.original_weight_tensor.tensor_impl.float8_data.shape}')
+    return x
+
+# this is needed for DTensor.from_local() and for flattening tensor
+@implements(aten.view.default)
+def _(func, types, args, kwargs):
+    print('Linear view args:', args[1:])
+    print('Device: ', args[0].original_weight_tensor.device)
+    x = return_and_correct_aliasing(
+        func, args, kwargs, LinearActivationQuantizedTensor(func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func)
+    )
+    print(f'Linear act Post view: {x.original_weight_tensor.shape} {x.original_weight_tensor.tensor_impl.float8_data.shape}')
+    return x
+
 to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float
 
 if TORCH_VERSION_AT_LEAST_2_5:
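
Editor's note: the new debug prints reach into x.original_weight_tensor.tensor_impl.float8_data, so these paths assume the wrapped weight is a float8 AffineQuantizedTensor, as in the fp8 configs under test. Conceptually, both ops unwrap, forward the op to the inner weight, and re-wrap with the same input_quant_func, e.g. for a rowwise shard:

# Conceptual sketch; `w` is assumed to be a LinearActivationQuantizedTensor.
# local = w[:, rank * n_local_cols : (rank + 1) * n_local_cols]
# -> typically dispatches aten.slice.Tensor per sliced dim, each time building
#    LinearActivationQuantizedTensor(func(w.original_weight_tensor, ...), w.input_quant_func),
#    so activation quantization stays lazy on the shard.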

torchao/testing/utils.py

Lines changed: 9 additions & 2 deletions
@@ -250,6 +250,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         m.linear.weight = torch.nn.Parameter(
             dtensor, requires_grad=False
         )
+        print('colwise shard shape of m.linear.weight: ', m.linear.weight.shape)
         return m
 
     @staticmethod
@@ -264,11 +265,15 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         rank = mesh.get_local_rank()
         local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
         # Construct DTensor from local shard
-        dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+        dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
+        print(f'dtensor shape: {dtensor.shape}')
+        print(f'Other dtensor values: {local_shard.original_weight_tensor.tensor_impl.float8_data.shape}, {mesh}, {[Shard(1)]}')
         # Replace parameter in module
         m.linear.weight = torch.nn.Parameter(
             dtensor, requires_grad=False
         )
+        print('rowwise shard shape of m.linear.weight: ', m.linear.weight.shape)
+
         return m
 
     def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
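
Editor's note: for orientation, the rowwise sharding arithmetic with plain tensors; a 2-rank mesh is assumed and the shapes are illustrative, while the quantized subclass goes through exactly this slice:

import torch

orig_weight = torch.randn(1024, 2048)   # (out_features, in_features)
num_ranks, rank = 2, 0                  # assumed 2-GPU mesh, viewed from rank 0
n_local_cols = orig_weight.size(1) // num_ranks
# Each rank keeps a contiguous block of columns; [Shard(1)] then tells
# DTensor.from_local that the logical weight is split along dim 1 across the mesh.
local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
print(local_shard.shape)  # torch.Size([1024, 1024])
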
@@ -302,11 +307,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         proj_dn = M(2048, 1024).to(device).to(dtype)
         example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
         y = proj_dn(proj_up(example_input))
-
+        print('Run before y')
         # Quantize the model
         up_quant = self.quantize(proj_up)
         dn_quant = self.quantize(proj_dn)
+        print('Run before y_q')
         y_q = dn_quant(up_quant(example_input))
+        print('Executed y_q')
 
         mesh = self.build_device_mesh()
         mesh.device_type = "cuda"
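
Editor's note: to summarize the shapes this test drives end to end, a plain-tensor stand-in with quantization and DTensor wiring omitted; the Megatron-style colwise-then-rowwise sharding reproduces the same logical shapes per rank:

import torch

x = torch.randn(128, 1024)
up = torch.randn(2048, 1024)   # proj_up weight: colwise-sharded (dim 0) under TP
dn = torch.randn(1024, 2048)   # proj_dn weight: rowwise-sharded (dim 1) under TP
y = (x @ up.t()) @ dn.t()
assert y.shape == (128, 1024)  # matches y / y_q in the test above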
