@@ -492,6 +492,93 @@ def from_float(cls, weight):
         block_size = (1, weight.shape[1])
         return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
 
+class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
+    """
+    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight
+    """
+    @classmethod
+    def from_float(cls, weight):
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_floatx
+        # weight settings
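+        # block size (1, x.shape[1]) gives one scale per weight row, i.e. per output
+        # channel, the same granularity used by the weight-only float8 class above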
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.float8_e4m3fn
+
+        # input settings
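+        # per-token granularity: every leading dim collapses to 1 and only the last
+        # (feature) dim is kept, e.g. a (B, S, K) activation gets block_size [1, 1, K]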
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size) - 1):
+                block_size[i] = 1
+            return block_size
+
+        input_target_dtype = torch.float8_e4m3fn
+        layout_type = Float8LayoutType()
+        input_quant_func = lambda x: to_affine_quantized_floatx(
+            input_float=x,
+            block_size=get_per_token_block_size(x),
+            target_dtype=input_target_dtype,
+            layout_type=layout_type
+        )
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            layout_type=layout_type
+        )
+        weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+        """
+        Tests and benchmarks the autoquantization process with special handling for interpolate mode.
+
+        Args:
+            act_mat (torch.Tensor): The activation matrix.
+            weight (torch.Tensor): The weight tensor.
+            bias (torch.Tensor or None): The bias tensor.
+            best_time (float): The best time to beat for the quantization process.
+            mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
+                (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
+
+        Returns:
+            float: The benchmarked time for the autoquantization process.
+        """
+        if not _is_interpolate_mode(mode):
+            return super()._autoquant_test(act_mat, weight, bias, best_time, mode)
+
+        # SAM best is between .8 and 1, SDXL also performs best in this range
+        INTERPOLATION_CONSTANT = mode[1]
+        w_qtensor = cls.from_float(weight)
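+        # despite the float8-style names below, this quick check reuses the int8 per-token
+        # absmax quantization and safe_int_mm as a cheap proxy for the matmul cost,
+        # before deciding whether to benchmark the full op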
+        x_vals_float8, x_scales = quantize_activation_per_token_absmax(
+            act_mat.reshape(-1, act_mat.shape[-1])
+        )
+        quantized_matmul = (
+            lambda x_vals_float8, x_scales, w_vals_float8:
+                safe_int_mm(x_vals_float8, w_vals_float8) * x_scales
+        )
+        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
+        with torch.no_grad():
+            w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t()
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1, 1), w_vals_float8)
+        print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
+
+        # if the (much faster) matmul kernel is already beaten, don't bother benchmarking the full op
+        if res_matmul >= best_time:
+            return res_matmul
+
+        # calculate what time the full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
+        to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
+        res = super()._autoquant_test(act_mat, weight, bias, to_beat)
+        max_float_const_win = (best_time - res_matmul) / (res - res_matmul)
+        res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
+        print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_float_const_win:0.2f}")
+        return res_f
+
 
 # here we don't include int4 quantization since int8 tends to be a better apples-to-apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -511,6 +598,7 @@ def from_float(cls, weight):
 
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 ]
 
 
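A minimal usage sketch of trying the float8 classes via autoquant, assuming this file is torchao/quantization/autoquant.py and that torchao.autoquant accepts a qtensor_class_list argument:

    import torch
    import torchao
    from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

    # toy model; autoquant benchmarks each candidate class per linear layer and keeps the fastest
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", dtype=torch.bfloat16)
    example_input = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

    # pass the float8 list so AQFloat8DynamicallyQuantizedLinearWeight is among the candidates
    model = torchao.autoquant(model, qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
    model(example_input)  # a forward pass lets autoquant record shapes and benchmark candidates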