2525 "autoquant" ,
2626 "DEFAULT_AUTOQUANT_CLASS_LIST" ,
2727 "DEFAULT_INT4_AUTOQUANT_CLASS_LIST" ,
28+ "OTHER_AUTOQUANT_CLASS_LIST" ,
2829]
2930
3031
@@ -492,6 +493,105 @@ def from_float(cls, weight):
         block_size = (1, weight.shape[1])
         return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
 
+class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
+    """
+    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight
+    """
+    @classmethod
+    def from_float(cls, weight):
+        # TODO test if this is valid
+        # in_features = weight.shape[1]
+        # int8 dynamic quantization only has benefit when in_feature > 16
+        # if in_features <= 16:
+        #     return weight
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_floatx
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.float8_e4m3fn
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.float32
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size) - 1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.float8_e4m3fn
+        input_eps = 1e-5
+        input_quant_min = torch.finfo(input_target_dtype).min
+        input_quant_max = torch.finfo(input_target_dtype).max
+        layout_type = Float8LayoutType()
+        # defer activation quantization to call time: the input tensor is not
+        # available here, so wrap the conversion in a per-token quant function
+        input_quant_func = lambda x: to_affine_quantized_floatx(
+            input_float=x,
+            block_size=get_per_token_block_size(x),
+            target_dtype=input_target_dtype,
+            layout_type=layout_type
+        )
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            layout_type=layout_type
+        )
+        weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
+    # @classmethod
+    # def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+    #     """
+    #     Tests and benchmarks the autoquantization process with special handling for interpolate mode.
+
+    #     Args:
+    #         act_mat (torch.Tensor): The activation matrix.
+    #         weight (torch.Tensor): The weight tensor.
+    #         bias (torch.Tensor or None): The bias tensor.
+    #         best_time (float): The best time to beat for the quantization process.
+    #         mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
+    #             (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
+
+    #     Returns:
+    #         float: The benchmarked time for the autoquantization process.
+    #     """
+    #     if not _is_interpolate_mode(mode):
+    #         return super()._autoquant_test(act_mat, weight, bias, best_time, mode)
+
+    #     # SAM best is between .8 and 1, SDXL also performs best in this range
+    #     INTERPOLATION_CONSTANT = mode[1]
+    #     w_qtensor = cls.from_float(weight)
+    #     x_vals_int8, x_scales = quantize_activation_per_token_absmax(
+    #         act_mat.reshape(-1, act_mat.shape[-1])
+    #     )
+    #     quantized_matmul = (
+    #         lambda x_vals_int8, x_scales, w_vals_int8:
+    #             safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
+    #     )
+    #     q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
+    #     with torch.no_grad():
+    #         w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
+    #         res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1, 1), w_vals_int8)
+    #     print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
+
+    #     # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
+    #     if res_matmul >= best_time:
+    #         return res_matmul
+
+    #     # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
+    #     to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
+    #     res = super()._autoquant_test(act_mat, weight, bias, to_beat)
+    #     max_int_const_win = (best_time - res_matmul) / (res - res_matmul)
+    #     res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
+    #     print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
+    #     return res_f
+
 
 # here we don't include int4 quantization since int8 tends to be a better apples-to-apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -511,6 +611,7 @@ def from_float(cls, weight):
 
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 ]
 
 
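Usage sketch (not part of the diff): since the new class is only reachable through `OTHER_AUTOQUANT_CLASS_LIST`, one way to exercise it is to pass that list to `torchao.autoquant` via its `qtensor_class_list` argument. The toy model, dtype, and device below are illustrative assumptions, and the float8 candidates need a GPU with float8 support.

```python
import torch
import torchao
from torchao.quantization import OTHER_AUTOQUANT_CLASS_LIST

# illustrative toy model; any module containing nn.Linear layers works
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).to(device="cuda", dtype=torch.bfloat16)
example_input = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

# restrict autoquant to the float8 candidates (AQFloat8WeightOnlyQuantizedLinearWeight
# and the new AQFloat8DynamicallyQuantizedLinearWeight); running the model then
# records shapes, benchmarks the candidates, and keeps the fastest option per layer
model = torchao.autoquant(
    torch.compile(model, mode="max-autotune"),
    qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST,
)
out = model(example_input)
```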