
Commit a2d77ce

Float8 dynamic autoquant

1 parent 1198fa5
2 files changed: +109 -0 lines changed

test/integration/test_integration.py

Lines changed: 8 additions & 0 deletions
@@ -73,6 +73,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -753,6 +754,13 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
             AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8DynamicallyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
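The added test drives the new subclass through _test_lin_weight_subclass_impl, which swaps the quantized weight into a linear op and checks agreement against the float baseline. A minimal standalone sketch of the same round trip follows; the shapes, device, and tolerance here are illustrative assumptions rather than values from the test suite, and the float8 path generally needs recent CUDA hardware:

import torch
from torchao.quantization.autoquant import AQFloat8DynamicallyQuantizedLinearWeight

# hypothetical smoke test: quantize a weight, then run F.linear through the subclass
weight = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")
act = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")

ref = torch.nn.functional.linear(act, weight)   # float baseline
qweight = AQFloat8DynamicallyQuantizedLinearWeight.from_float(weight)
out = torch.nn.functional.linear(act, qweight)  # fp8 weight, per-token fp8 activations

# loose tolerance: quantizing both operands to float8 e4m3 is lossy
torch.testing.assert_close(ref, out, atol=1e-1, rtol=1e-1)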

torchao/quantization/autoquant.py

Lines changed: 101 additions & 0 deletions
@@ -25,6 +25,7 @@
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "OTHER_AUTOQUANT_CLASS_LIST",
 ]
 
 
@@ -492,6 +493,105 @@ def from_float(cls, weight):
         block_size = (1, weight.shape[1])
         return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
 
+class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
+    """
+    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight
+    """
+    @classmethod
+    def from_float(cls, weight):
+        # TODO test if this is valid
+        # in_features = weight.shape[1]
+        # int8 dynamic quantization only has benefit when in_features > 16
+        # if in_features <= 16:
+        #     return weight
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_floatx
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            # one scale per output row of the weight
+            return (1, x.shape[1])
+        target_dtype = torch.float8_e4m3fn
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.float32
+
+        # input settings
+        def get_per_token_block_size(x):
+            # one scale per token: collapse every dim except the last
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.float8_e4m3fn
+        input_eps = 1e-5
+        input_quant_min = torch.finfo(input_target_dtype).min
+        input_quant_max = torch.finfo(input_target_dtype).max
+        layout_type = Float8LayoutType()
+        # quantize each incoming activation on the fly at linear() call time
+        input_quant_func = lambda x: to_affine_quantized_floatx(
+            input_float=x,
+            block_size=get_per_token_block_size(x),
+            target_dtype=input_target_dtype,
+            layout_type=layout_type
+        )
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            layout_type=layout_type
+        )
+        weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
+    # @classmethod
+    # def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+    #     """
+    #     Tests and benchmarks the autoquantization process with special handling for interpolate mode.
+
+    #     Args:
+    #         act_mat (torch.Tensor): The activation matrix.
+    #         weight (torch.Tensor): The weight tensor.
+    #         bias (torch.Tensor or None): The bias tensor.
+    #         best_time (float): The best time to beat for the quantization process.
+    #         mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
+    #             (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
+
+    #     Returns:
+    #         float: The benchmarked time for the autoquantization process.
+    #     """
+    #     if not _is_interpolate_mode(mode):
+    #         return super()._autoquant_test(act_mat, weight, bias, best_time, mode)
+
+    #     # SAM best is between .8 and 1, SDXL also performs best in this range
+    #     INTERPOLATION_CONSTANT = mode[1]
+    #     w_qtensor = cls.from_float(weight)
+    #     x_vals_int8, x_scales = quantize_activation_per_token_absmax(
+    #         act_mat.reshape(-1, act_mat.shape[-1])
+    #     )
+    #     quantized_matmul = (
+    #         lambda x_vals_int8, x_scales, w_vals_int8:
+    #             safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
+    #     )
+    #     q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
+    #     with torch.no_grad():
+    #         w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
+    #         res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1, 1), w_vals_int8)
+    #     print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
+
+    #     # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
+    #     if res_matmul >= best_time:
+    #         return res_matmul
+
+    #     # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
+    #     to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
+    #     res = super()._autoquant_test(act_mat, weight, bias, to_beat)
+    #     max_int_const_win = (best_time - res_matmul) / (res - res_matmul)
+    #     res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
+    #     print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
+    #     return res_f
+
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -511,6 +611,7 @@ def from_float(cls, weight):
 
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 ]
 
 
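Why the two block sizes in from_float: get_per_token_block_size collapses every activation dimension except the last to 1, so each token gets its own float8 scale, while the weight block size (1, shape[1]) gives one scale per output row. A quick sketch of what the two granularities come out to, with shapes invented for illustration:

import torch

def get_per_token_block_size(x):
    # same logic as the helper in the diff: collapse all dims except the last
    block_size = list(x.shape)
    for i in range(len(block_size) - 1):
        block_size[i] = 1
    return block_size

acts = torch.randn(4, 128, 512)           # (batch, seq_len, hidden)
print(get_per_token_block_size(acts))     # [1, 1, 512] -> one scale per token
weight = torch.randn(1024, 512)           # (out_features, in_features)
print((1, weight.shape[1]))               # (1, 512) -> one scale per weight row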

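Since the commit appends the new class to OTHER_AUTOQUANT_CLASS_LIST, it becomes a float8 candidate whenever that list is handed to autoquant. A hedged usage sketch, assuming the qtensor_class_list keyword through which the default lists are normally passed; the toy model, shapes, and device are placeholders:

import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

model = torch.nn.Sequential(torch.nn.Linear(512, 512)).to("cuda", torch.bfloat16)
example_input = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)

# benchmark both float8 variants (weight-only and dynamic) per layer, keeping the faster one
model = torchao.autoquant(model, qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
model(example_input)  # running the model drives shape recording and benchmarking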