66from torchao .dtypes import (
77 AffineQuantizedTensor ,
88 Float8Layout ,
9+ MarlinSparseLayout ,
910 PlainLayout ,
11+ SemiSparseLayout ,
1012 TensorCoreTiledLayout ,
1113)
14+ from torchao .dtypes .utils import Layout
1215from torchao .float8 .inference import Float8MMConfig
1316from torchao .kernel import safe_int_mm
1417from torchao .quantization .linear_activation_quantized_tensor import (
4649 "DEFAULT_AUTOQUANT_CLASS_LIST" ,
4750 "DEFAULT_INT4_AUTOQUANT_CLASS_LIST" ,
4851 "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST" ,
52+ "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST" ,
4953 "OTHER_AUTOQUANT_CLASS_LIST" ,
5054 "ALL_AUTOQUANT_CLASS_LIST" ,
5155]
@@ -406,6 +410,8 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedT
406410 AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
407411 """
408412
413+ layout : Layout = PlainLayout ()
414+
409415 @classmethod
410416 def from_float (cls , weight ):
411417 # TODO test if this is valid
@@ -414,6 +420,9 @@ def from_float(cls, weight):
414420 # if in_features <= 16:
415421 # return weight
416422
423+ if weight .dim () != 2 :
424+ return weight
425+
417426 # avoid circular dep
418427 from torchao .dtypes import to_affine_quantized_intx
419428
@@ -439,7 +448,7 @@ def get_per_token_block_size(x):
439448 input_eps = 1e-5
440449 input_quant_min = - 127
441450 input_quant_max = 127
442- _layout = PlainLayout ()
451+ _layout = cls . layout
443452 input_quant_func = lambda x : to_affine_quantized_intx (
444453 x ,
445454 input_mapping_type ,
@@ -526,6 +535,16 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
526535 return res_f
527536
528537
class AQInt8DynamicallyQuantizedSemiSparseLinearWeight(
    AQInt8DynamicallyQuantizedLinearWeight
):
    """Int8 dynamic-quantization autoquant candidate backed by a 2:4 semi-sparse layout.

    Identical to :class:`AQInt8DynamicallyQuantizedLinearWeight` except that the
    weight is packed with ``SemiSparseLayout`` instead of ``PlainLayout``.
    """

    # Overrides the base class's PlainLayout; picked up by from_float via cls.layout.
    layout: Layout = SemiSparseLayout()

    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=("relu", None)):
        """Benchmark this candidate, always with ``mode=None``.

        The ``mode`` argument is accepted for signature compatibility with the
        base class but deliberately ignored: the relu-interpolated timing path
        is skipped for the semi-sparse kernel, so ``None`` is forwarded instead.
        (Default is an immutable tuple rather than a shared mutable list.)
        """
        return super()._autoquant_test(act_mat, weight, bias, best_time, None)
546+
547+
529548class AQInt8WeightOnlyQuantizedLinearWeight (AffineQuantizedTensor , AQMixin ):
530549 """
531550 AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
@@ -613,14 +632,16 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
613632 """
614633
615634 group_size : int = 32
635+ layout : Layout = TensorCoreTiledLayout (inner_k_tiles = 8 )
616636
617637 @classmethod
618638 def from_float (cls , weight ):
619639 group_size = cls .group_size
620- _layout = TensorCoreTiledLayout ( inner_k_tiles = 8 )
640+ _layout = cls . layout
621641
622642 if weight .shape [- 1 ] % group_size != 0 :
623643 return weight
644+
624645 use_hqq = True
625646 mapping_type = MappingType .ASYMMETRIC
626647 block_size = (1 , group_size )
@@ -631,6 +652,13 @@ def from_float(cls, weight):
631652 preserve_zero = False
632653 zero_point_dtype = torch .bfloat16
633654 zero_point_domain = ZeroPointDomain .FLOAT
655+
656+ if isinstance (_layout , MarlinSparseLayout ):
657+ mapping_type = MappingType .SYMMETRIC
658+ preserve_zero = True
659+ zero_point_domain = ZeroPointDomain .INT
660+ use_hqq = False
661+
634662 return super (AQInt4G32WeightOnlyQuantizedLinearWeight , cls ).from_hp_to_intx (
635663 weight ,
636664 mapping_type ,
@@ -665,6 +693,13 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
665693 group_size : int = 256
666694
667695
class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight(
    AQInt4G32WeightOnlyQuantizedLinearWeight
):
    """Int4 weight-only autoquant candidate using group size 128 and the
    Marlin 2:4 sparse layout.

    Inherits ``from_float`` from :class:`AQInt4G32WeightOnlyQuantizedLinearWeight`;
    that method reads ``cls.group_size`` and ``cls.layout``, and switches to
    symmetric mapping / integer zero-point domain when it sees a
    ``MarlinSparseLayout``, so only these two attributes need overriding here.
    """

    # Quantization group size along the last weight dimension.
    group_size: int = 128
    # Marlin sparse packing; also flips from_float into the symmetric/no-hqq path.
    layout: Layout = MarlinSparseLayout()
701+
702+
668703class AQDefaultLinearWeight (torch .Tensor , AQMixin ):
669704 """
670705 A class to be used in concert with AutoQuantizableLinearWeight to provide a
@@ -949,16 +984,24 @@ def get_weight_block_size(x):
949984]
950985
# Float8-oriented candidates (plus the unquantized baseline) that are tried
# only when explicitly requested rather than in the default sweep.
OTHER_AUTOQUANT_CLASS_LIST = [
    AQDefaultLinearWeight,
    AQFloat8WeightOnlyQuantizedLinearWeight,
    AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
    AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
]

# Sparse-layout candidates: unquantized baseline, int4 Marlin 2:4 sparse,
# and int8 dynamic with a semi-sparse (2:4) weight layout.
DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST = [
    AQDefaultLinearWeight,
    AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
    AQInt8DynamicallyQuantizedSemiSparseLinearWeight,
]
998+
# Union of every autoquant candidate across the per-mode lists.
# dict.fromkeys deduplicates while preserving first-seen insertion order
# (guaranteed for dict since Python 3.7), so the resulting list is
# deterministic across runs — a plain set() would yield an arbitrary order.
ALL_AUTOQUANT_CLASS_LIST = list(
    dict.fromkeys(
        DEFAULT_AUTOQUANT_CLASS_LIST
        + DEFAULT_INT4_AUTOQUANT_CLASS_LIST
        + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
        + DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST
    )
)
9641007if is_sm_at_least_89 ():
0 commit comments