From 58fe60dd408e79b6d6469b40cdaf21b7da530c48 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 01/14] float8 dynamic autoquant --- test/integration/test_integration.py | 9 +++ torchao/kernel/intmm.py | 9 ++- torchao/quantization/autoquant.py | 88 +++++++++++++++++++++++++++- 3 files changed, 102 insertions(+), 4 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 8e047985c5..5446609004 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -73,6 +73,7 @@ AQInt8WeightOnlyQuantizedLinearWeight3, AutoQuantizableLinearWeight, AQFloat8WeightOnlyQuantizedLinearWeight, + AQFloat8DynamicallyQuantizedLinearWeight, ) from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx import os @@ -753,6 +754,14 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype): AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype ) + @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") + @unittest.skipIf(not is_H100, "Need H100 to run") + def test_aq_float8_dynamic_quant_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQFloat8DynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype + ) + @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 81e7b19b15..8851432393 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -37,8 +37,12 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: """ # torch.compile path if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - + try: + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + except Exception: + # fallback path, would run on H100 for float8 dtypes + # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' + return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32) # error checking for cublas path assert ( mat2.device == input.device @@ -53,7 +57,6 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: and j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 ) - if device_cpu or bad_dimensions_for_cublas: # fallback path return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 089add1d87..19780088ee 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -221,7 +221,6 @@ def do_autoquant_bench(op, *args, **kwargs): stream.synchronize() torch.cuda.current_stream().wait_stream(stream) torch.cuda.synchronize() - graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) @@ -492,6 +491,92 @@ def from_float(cls, weight): block_size = (1, weight.shape[1]) return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType()) +class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor): + """ + AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight + """ + @classmethod + def from_float(cls, weight): + + # avoid circular dep + from torchao.dtypes import to_affine_quantized_floatx + # weight settings + def get_weight_block_size(x): + return (1, x.shape[1]) + target_dtype = torch.float8_e4m3fn + + # input settings + def get_per_token_block_size(x): + block_size = list(x.shape) + for i in range(len(block_size)-1): + block_size[i] = 1 + return block_size + + input_target_dtype = torch.float8_e4m3fn + layout_type = Float8LayoutType() + input_quant_func = lambda x: to_affine_quantized_floatx( + input_float=x, + block_size=get_per_token_block_size(x), + target_dtype=input_target_dtype, + layout_type=layout_type + ) + block_size = get_weight_block_size(weight) + weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=target_dtype, + layout_type=layout_type + ) + weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func) + return weight + + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + """ + Tests and benchmarks the autoquantization process with special handling for interpolate mode. + + Args: + act_mat (torch.Tensor): The activation matrix. + weight (torch.Tensor): The weight tensor. + bias (torch.Tensor or None): The bias tensor. + best_time (float): The best time to beat for the quantization process. + mode (list, optional): A list containing mode settings for quantization. The first element is the mode type + (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None]. + + Returns: + float: The benchmarked time for the autoquantization process. + """ + if not _is_interpolate_mode(mode): + return super()._autoquant_test(act_mat, weight, bias, best_time, mode) + + # SAM best is between .8 and 1, SDXL also performs best in this range + INTERPOLATION_CONSTANT = mode[1] + w_qtensor = cls.from_float(weight) + x_vals_float8, x_scales = quantize_activation_per_token_absmax( + act_mat.reshape(-1, act_mat.shape[-1]) + ) + quantized_matmul = ( + lambda x_vals_float8, x_scales, w_vals_float8: + safe_int_mm(x_vals_float8, w_vals_float8) * x_scales + ) + q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") + with torch.no_grad(): + w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t() + res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1,1), w_vals_float8) + print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") + + # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op + if res_matmul>=best_time: + return res_matmul + + # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT + to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul) + res = super()._autoquant_test(act_mat, weight, bias, to_beat) + max_float_const_win = (best_time-res_matmul)/(res-res_matmul) + res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul + print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_float_const_win:0.2f}") + return res_f + # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison DEFAULT_AUTOQUANT_CLASS_LIST = [ @@ -511,6 +596,7 @@ def from_float(cls, weight): OTHER_AUTOQUANT_CLASS_LIST = [ AQFloat8WeightOnlyQuantizedLinearWeight, + AQFloat8DynamicallyQuantizedLinearWeight, ] From f6465196a2dcf379d1c66cc57a4505b9082f5121 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 02/14] float8 dynamic autoquant --- torchao/quantization/autoquant.py | 2 +- torchao/quantization/utils.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 19780088ee..bde7028f23 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -553,7 +553,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): INTERPOLATION_CONSTANT = mode[1] w_qtensor = cls.from_float(weight) x_vals_float8, x_scales = quantize_activation_per_token_absmax( - act_mat.reshape(-1, act_mat.shape[-1]) + act_mat.reshape(-1, act_mat.shape[-1]), dtype=torch.float8_e4m3fn ) quantized_matmul = ( lambda x_vals_float8, x_scales, w_vals_float8: diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 0df6174d0f..a9632aa90f 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -139,13 +139,12 @@ def _get_per_token_block_size(x: torch.Tensor) -> List[int]: # taken from # https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26 # and slightly modified -def quantize_activation_per_token_absmax(t): +def quantize_activation_per_token_absmax(t, dtype=torch.int8): # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1] mapping_type = MappingType.SYMMETRIC block_size = list(t.shape) for i in range(len(block_size) - 1): block_size[i] = 1 - dtype = torch.int8 eps = 1e-5 # Note: the original smoothquant does not clamp to qmin/qmax here, # but some of the tests with bfloat16 ended up with a flipped sign From bfe1eee78e31967419315cdabb3b768ddde7748b Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 03/14] float8 dynamic autoquant --- torchao/quantization/autoquant.py | 54 +++---------------------------- torchao/quantization/utils.py | 2 +- 2 files changed, 6 insertions(+), 50 deletions(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index bde7028f23..2a17cc1b77 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -17,6 +17,7 @@ ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 from torchao.quantization.utils import quantize_activation_per_token_absmax +from torchao.float8.inference import addmm_float8_unwrapped_inference import torch.nn.functional as F @@ -518,65 +519,20 @@ def get_per_token_block_size(x): input_float=x, block_size=get_per_token_block_size(x), target_dtype=input_target_dtype, - layout_type=layout_type + layout_type=layout_type, + scale_dtype=torch.float32, ) block_size = get_weight_block_size(weight) weight = to_affine_quantized_floatx( input_float=weight, block_size=block_size, target_dtype=target_dtype, - layout_type=layout_type + layout_type=layout_type, + scale_dtype=torch.float32, ) weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func) return weight - @classmethod - def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): - """ - Tests and benchmarks the autoquantization process with special handling for interpolate mode. - - Args: - act_mat (torch.Tensor): The activation matrix. - weight (torch.Tensor): The weight tensor. - bias (torch.Tensor or None): The bias tensor. - best_time (float): The best time to beat for the quantization process. - mode (list, optional): A list containing mode settings for quantization. The first element is the mode type - (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None]. - - Returns: - float: The benchmarked time for the autoquantization process. - """ - if not _is_interpolate_mode(mode): - return super()._autoquant_test(act_mat, weight, bias, best_time, mode) - - # SAM best is between .8 and 1, SDXL also performs best in this range - INTERPOLATION_CONSTANT = mode[1] - w_qtensor = cls.from_float(weight) - x_vals_float8, x_scales = quantize_activation_per_token_absmax( - act_mat.reshape(-1, act_mat.shape[-1]), dtype=torch.float8_e4m3fn - ) - quantized_matmul = ( - lambda x_vals_float8, x_scales, w_vals_float8: - safe_int_mm(x_vals_float8, w_vals_float8) * x_scales - ) - q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") - with torch.no_grad(): - w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t() - res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1,1), w_vals_float8) - print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") - - # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op - if res_matmul>=best_time: - return res_matmul - - # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT - to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul) - res = super()._autoquant_test(act_mat, weight, bias, to_beat) - max_float_const_win = (best_time-res_matmul)/(res-res_matmul) - res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul - print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_float_const_win:0.2f}") - return res_f - # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison DEFAULT_AUTOQUANT_CLASS_LIST = [ diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index a9632aa90f..3cb531d196 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -139,7 +139,7 @@ def _get_per_token_block_size(x: torch.Tensor) -> List[int]: # taken from # https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26 # and slightly modified -def quantize_activation_per_token_absmax(t, dtype=torch.int8): +def quantize_activation_per_token_absmax(t): # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1] mapping_type = MappingType.SYMMETRIC block_size = list(t.shape) From c2aedfddede1eaf6388f401acf84e6253857f190 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 04/14] float8 dynamic autoquant --- torchao/quantization/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 3cb531d196..0df6174d0f 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -145,6 +145,7 @@ def quantize_activation_per_token_absmax(t): block_size = list(t.shape) for i in range(len(block_size) - 1): block_size[i] = 1 + dtype = torch.int8 eps = 1e-5 # Note: the original smoothquant does not clamp to qmin/qmax here, # but some of the tests with bfloat16 ended up with a flipped sign From f4392567c128f3d24b7829a6996439daa29900a5 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 05/14] float8 dynamic autoquant --- torchao/quantization/autoquant.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 2a17cc1b77..d3e7a9b0bf 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -17,7 +17,8 @@ ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 from torchao.quantization.utils import quantize_activation_per_token_absmax -from torchao.float8.inference import addmm_float8_unwrapped_inference +from torchao.quantization.quant_api import _input_activation_quant_func_fp8 +from torchao.quantization.observer import PerRow import torch.nn.functional as F @@ -515,12 +516,10 @@ def get_per_token_block_size(x): input_target_dtype = torch.float8_e4m3fn layout_type = Float8LayoutType() - input_quant_func = lambda x: to_affine_quantized_floatx( - input_float=x, - block_size=get_per_token_block_size(x), - target_dtype=input_target_dtype, - layout_type=layout_type, - scale_dtype=torch.float32, + input_quant_func = lambda x: _input_activation_quant_func_fp8( + x=x + activation_granularity=PerRow, + activation_dtype=input_target_dtype, ) block_size = get_weight_block_size(weight) weight = to_affine_quantized_floatx( From 432bef403dc544dafce24e18cb15b1424a43ccd4 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 06/14] float8 dynamic autoquant --- torchao/kernel/intmm.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 8851432393..c2aa36bfe9 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -37,12 +37,8 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: """ # torch.compile path if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): - try: - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - except Exception: - # fallback path, would run on H100 for float8 dtypes - # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' - return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32) + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + # error checking for cublas path assert ( mat2.device == input.device From 0897f14b33726a43fe5d3e7aa3c43647e18f3ed0 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 07/14] float8 dynamic autoquant --- torchao/kernel/intmm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index c2aa36bfe9..7d076a6e81 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -53,6 +53,7 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: and j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 ) + if device_cpu or bad_dimensions_for_cublas: # fallback path return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( From 81bbeb19a17522c0f7f0de2c97172a85f94b621c Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 08/14] float8 dynamic autoquant --- torchao/quantization/autoquant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index d3e7a9b0bf..26ca413590 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -517,7 +517,7 @@ def get_per_token_block_size(x): input_target_dtype = torch.float8_e4m3fn layout_type = Float8LayoutType() input_quant_func = lambda x: _input_activation_quant_func_fp8( - x=x + x=x, activation_granularity=PerRow, activation_dtype=input_target_dtype, ) From 257d8cfaba5fc921155abe4501eb2130df28777c Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 09/14] float8 dynamic autoquant --- test/integration/test_integration.py | 3 +++ torchao/_models/llama/generate.py | 4 +++- torchao/quantization/autoquant.py | 9 +++++---- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 5446609004..ebfd3eb405 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -681,6 +681,7 @@ def _test_lin_weight_subclass_impl( m, k, n = test_shape x = torch.randn(m, k, device=test_device, dtype=test_dtype) lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype) + lin.bias.requires_grad = False ref_f = lin(x) lin.weight = torch.nn.Parameter( @@ -758,6 +759,8 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") @unittest.skipIf(not is_H100, "Need H100 to run") def test_aq_float8_dynamic_quant_subclass(self, device, dtype): + if dtype != torch.bfloat16: + self.skipTest("Fails for {dtype}") self._test_lin_weight_subclass_impl( AQFloat8DynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype ) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 5fb905dbf9..19e42e7cd5 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -246,6 +246,8 @@ def main( if "autoquant" in quantization: if "autoquant-int4" == quantization: model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) + elif "autoquant-float8" == quantization: + model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST) else: model = autoquant(model, manual=True) @@ -415,7 +417,7 @@ def callback(x): parser.add_argument('-q', '--quantization', type=str, help=( 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, ' - +'autoquant-int4, uintx--, uintx---hqq, sparse-marlin' + +'autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin' ) ) parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 26ca413590..38e4826ec7 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -17,8 +17,8 @@ ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 from torchao.quantization.utils import quantize_activation_per_token_absmax -from torchao.quantization.quant_api import _input_activation_quant_func_fp8 -from torchao.quantization.observer import PerRow +from torchao.quantization.observer import PerAxis, PerTensor, PerRow +from torchao.float8.inference import Float8MMConfig import torch.nn.functional as F @@ -502,6 +502,7 @@ def from_float(cls, weight): # avoid circular dep from torchao.dtypes import to_affine_quantized_floatx + from torchao.quantization.quant_api import _input_activation_quant_func_fp8 # weight settings def get_weight_block_size(x): return (1, x.shape[1]) @@ -515,10 +516,10 @@ def get_per_token_block_size(x): return block_size input_target_dtype = torch.float8_e4m3fn - layout_type = Float8LayoutType() + layout_type = Float8LayoutType(mm_config=Float8MMConfig(use_fast_accum=True)) input_quant_func = lambda x: _input_activation_quant_func_fp8( x=x, - activation_granularity=PerRow, + activation_granularity=PerRow(), activation_dtype=input_target_dtype, ) block_size = get_weight_block_size(weight) From c762e777ca996a90d4063a0f70d7954dc4ff731f Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 10/14] float8 dynamic autoquant --- test/integration/test_integration.py | 44 ++++++++++++++-------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index ebfd3eb405..1001807487 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -678,28 +678,28 @@ def _test_lin_weight_subclass_impl( ): if not "cuda" in test_device: self.skipTest("test requires cuda") - m, k, n = test_shape - x = torch.randn(m, k, device=test_device, dtype=test_dtype) - lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype) - lin.bias.requires_grad = False - ref_f = lin(x) - - lin.weight = torch.nn.Parameter( - test_subclass_from_float(lin.weight), requires_grad=False - ) - test = lin(x) - self.assertGreater( - SQNR(ref_f, test), - min_sqnr, - f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}" - ) - lin_comp = torch.compile(lin, mode='max-autotune') - test_comp = lin_comp(x) - self.assertGreater( - SQNR(ref_f, test_comp), - min_sqnr, - f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}" - ) + with torch.no_grad(): + m, k, n = test_shape + x = torch.randn(m, k, device=test_device, dtype=test_dtype) + lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype) + ref_f = lin(x) + + lin.weight = torch.nn.Parameter( + test_subclass_from_float(lin.weight), requires_grad=False + ) + test = lin(x) + self.assertGreater( + SQNR(ref_f, test), + min_sqnr, + f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}" + ) + lin_comp = torch.compile(lin, mode='max-autotune') + test_comp = lin_comp(x) + self.assertGreater( + SQNR(ref_f, test_comp), + min_sqnr, + f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}" + ) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen") From 6c6e46bb204c88f927866c588c65603b299744ae Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 11/14] float8 dynamic autoquant --- torchao/quantization/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 05c55b255d..9eb312dd6c 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -23,6 +23,7 @@ "autoquant", "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "OTHER_AUTOQUANT_CLASS_LIST", "get_scale", "SmoothFakeDynQuantMixin", "SmoothFakeDynamicallyQuantizedLinear", From 35903de3d94f3d566bf401037eb219b656c9de43 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 12/14] float8 dynamic autoquant --- torchao/quantization/autoquant.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 38e4826ec7..368e7a219a 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -27,6 +27,7 @@ "autoquant", "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "OTHER_AUTOQUANT_CLASS_LIST", ] From 17a1c56ab81933d3598db30569240032fcac80ae Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 13/14] float8 dynamic autoquant --- test/integration/test_integration.py | 2 +- torchao/quantization/autoquant.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 1001807487..0b69882bba 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -762,7 +762,7 @@ def test_aq_float8_dynamic_quant_subclass(self, device, dtype): if dtype != torch.bfloat16: self.skipTest("Fails for {dtype}") self._test_lin_weight_subclass_impl( - AQFloat8DynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype + AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 368e7a219a..a5568c4e17 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -494,10 +494,11 @@ def from_float(cls, weight): block_size = (1, weight.shape[1]) return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType()) -class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor): +class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor): """ - AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight + AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling """ + activation_granularity: str = PerRow() @classmethod def from_float(cls, weight): @@ -520,7 +521,7 @@ def get_per_token_block_size(x): layout_type = Float8LayoutType(mm_config=Float8MMConfig(use_fast_accum=True)) input_quant_func = lambda x: _input_activation_quant_func_fp8( x=x, - activation_granularity=PerRow(), + activation_granularity=cls.activation_granularity, activation_dtype=input_target_dtype, ) block_size = get_weight_block_size(weight) @@ -531,7 +532,7 @@ def get_per_token_block_size(x): layout_type=layout_type, scale_dtype=torch.float32, ) - weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func) + weight = super(AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func) return weight @@ -553,7 +554,7 @@ def get_per_token_block_size(x): OTHER_AUTOQUANT_CLASS_LIST = [ AQFloat8WeightOnlyQuantizedLinearWeight, - AQFloat8DynamicallyQuantizedLinearWeight, + AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, ] @@ -681,7 +682,7 @@ def autoquant( if set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST: + if qtensor_class_list is OTHER_AUTOQUANT_CLASS_LIST: assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9" # perform initial swap from linear weights From ae18023de1966bd40cea644873804808908f70da Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 25 Sep 2024 13:52:07 -0700 Subject: [PATCH 14/14] float8 dynamic autoquant --- test/integration/test_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 0b69882bba..7246bf02f0 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -73,7 +73,7 @@ AQInt8WeightOnlyQuantizedLinearWeight3, AutoQuantizableLinearWeight, AQFloat8WeightOnlyQuantizedLinearWeight, - AQFloat8DynamicallyQuantizedLinearWeight, + AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, ) from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx import os