99 Int8WeightOnlyQuantizedLinearWeight ,
1010 QuantizedLinearWeightBase ,
1111)
12- from torchao .dtypes import AffineQuantizedTensor , PlainLayoutType , TensorCoreTiledLayoutType
12+ from torchao .dtypes import AffineQuantizedTensor , PlainLayoutType , TensorCoreTiledLayoutType , Float8LayoutType
1313from torchao .quantization .linear_activation_quantized_tensor import LinearActivationQuantizedTensor
1414from torch .utils ._python_dispatch import return_and_correct_aliasing
1515from .quant_primitives import (
@@ -477,6 +477,22 @@ def _quantized_linear_op(act_mat, w_qtensor, bias):
    def from_float(cls, weight):
        # Identity pass-through: this variant performs no quantization and
        # returns the weight unchanged, serving as the unquantized baseline
        # that autoquant compares the quantized options against.
        # NOTE(review): presumably decorated with @classmethod just above
        # this diff hunk — confirm against the full file.
        return weight
479479
class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
    """
    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for
    target_dtype=torch.float8_e4m3fn.

    The weight is quantized per output row (block_size = (1, in_features)),
    i.e. one scale per output channel. The linear op is weight-only: it
    dequantizes the float8 weight back to high precision and runs a regular
    ``F.linear`` rather than a fused float8 matmul.
    """

    # dtype the weight tensor is quantized to.
    target_dtype: torch.dtype = torch.float8_e4m3fn

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        # Weight-only scheme: activations stay in their original dtype;
        # only the weight round-trips through float8.
        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)

    @classmethod
    def from_float(cls, weight):
        """Quantize a high-precision 2D ``weight`` to float8 with per-row scales."""
        block_size = (1, weight.shape[1])  # one scale block per output row
        # Zero-argument super() is the Python 3 idiom; the explicit
        # super(AQFloat8WeightOnlyQuantizedLinearWeight, cls) form is redundant.
        return super().from_hp_to_floatx(
            weight,
            block_size,
            target_dtype=cls.target_dtype,
            layout_type=Float8LayoutType(),
        )
495+
480496# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
481497DEFAULT_AUTOQUANT_CLASS_LIST = [
482498 AQFloatLinearWeight ,
@@ -493,6 +509,11 @@ def from_float(cls, weight):
493509 AQInt4G64WeightOnlyQuantizedLinearWeight
494510]
495511
# Quantization options NOT included in the default autoquant search space;
# currently only the float8 weight-only scheme, which requires CUDA compute
# capability >= 8.9 (checked at autoquant entry) and is opted into explicitly.
OTHER_AUTOQUANT_CLASS_LIST = [
    AQFloat8WeightOnlyQuantizedLinearWeight,
]
515+
516+
496517def _change_linears_to_autoquantizable (model , ** kwargs ):
497518 """
498519 Converts all linear weight tensors to the
@@ -617,6 +638,8 @@ def autoquant(
617638 if set_inductor_config :
618639 torchao .quantization .utils .recommended_inductor_config_setter ()
619640
641+ if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST :
642+ assert torch .cuda .is_available () and torch .cuda .get_device_capability () >= (8 , 9 ), "float8 requires CUDA arch >= 8.9"
620643
621644 # perform initial swap from linear weights
622645 # to AutoQuantizableLinearWeight
0 commit comments