
 # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
+from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
+from torchao.experimental.quant_api import EmbeddingQuantizer
+from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
     int4_weight_only,
     Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightQuantizer,
+    Int8DynamicActivationIntxWeightConfig,
+    MappingType,
     quantize_,
 )
 from torchao.utils import unwrap_tensor_subclass
     state_dict_device,
     use_et_backend,
 )
-from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
-    PackedLinearInt8DynamicActivationIntxWeightLayout,
-)
-from torchao.experimental.quant_api import (
-    int8_dynamic_activation_intx_weight,
-    IntxWeightEmbeddingQuantizer,
-)
-from torchao.quantization.granularity import (
-    PerGroup,
-    PerRow,
-)
-from torchao.dtypes import PlainLayout


 # Flag for whether the a8wxdq quantizer is available.
@@ -87,7 +80,7 @@ def get_named_parameters(func: Callable) -> List[str]:
     return named_params

 def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
-    for key in q_kwargs.keys():
+    for key in list(q_kwargs.keys()):
         if key not in named_params:
             print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
             del q_kwargs[key]
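
For context, the `list(...)` wrapper matters because the loop deletes keys from the dict it iterates. A minimal standalone illustration (toy values, not from this patch) of why iterating the live key view fails and why a snapshot is safe:

```python
q_kwargs = {"groupsize": 32, "bogus": 1}
named_params = ["groupsize"]

# Deleting entries while iterating the live key view raises
# "RuntimeError: dictionary changed size during iteration".
try:
    for key in q_kwargs.keys():
        if key not in named_params:
            del q_kwargs[key]
except RuntimeError as err:
    print(err)

# Snapshotting the keys with list(...) first makes deletion safe.
for key in list(q_kwargs.keys()):
    if key not in named_params:
        del q_kwargs[key]

print(q_kwargs)  # {'groupsize': 32}
```
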
@@ -137,29 +130,34 @@ def quantize_model(
             group_size = q_kwargs["groupsize"]
             bit_width = q_kwargs["bitwidth"]
             has_weight_zeros = q_kwargs["has_weight_zeros"]
-            granularity = PerRow() if group_size == -1 else PerGroup(group_size)
+            granularity = PerAxis() if group_size == -1 else PerGroup(group_size)
             weight_dtype = getattr(torch, f"int{bit_width}")
+            weight_mapping_type = (
+                MappingType.ASYMMETRIC
+                if has_weight_zeros
+                else MappingType.SYMMETRIC
+            )

             try:
                 quantize_(
-                    model,
-                    int8_dynamic_activation_intx_weight(
+                    model,
+                    Int8DynamicActivationIntxWeightConfig(
                         weight_dtype=weight_dtype,
-                        granularity=granularity,
-                        has_weight_zeros=has_weight_zeros,
+                        weight_granularity=granularity,
+                        weight_mapping_type=weight_mapping_type,
                         layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                     ),
                 )
             except Exception as e:
                 print(f"Encountered error during quantization: {e}")
-                print("Trying with PlainLayout")
+                print("Trying with QDQLayout")
                 quantize_(
-                    model,
-                    int8_dynamic_activation_intx_weight(
+                    model,
+                    Int8DynamicActivationIntxWeightConfig(
                         weight_dtype=weight_dtype,
-                        granularity=granularity,
-                        has_weight_zeros=has_weight_zeros,
-                        layout=PlainLayout(),
+                        weight_granularity=granularity,
+                        weight_mapping_type=weight_mapping_type,
+                        layout=QDQLayout(),
                     ),
                 )

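For context, a minimal sketch of the call path this hunk switches to, using only the names imported at the top of the diff; the toy module, bit width, and group size are illustrative and assume the torch/torchao versions targeted by this PR:

```python
import torch
import torch.nn as nn

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)

# Toy model standing in for the torchchat model being quantized.
toy = nn.Sequential(nn.Linear(256, 256))

# Mirrors the config built in this hunk: low-bit weights, groupwise scales,
# asymmetric mapping when weight zero points are requested.
config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    weight_mapping_type=MappingType.ASYMMETRIC,
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
)
quantize_(toy, config)
```
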
@@ -174,6 +172,22 @@ def quantize_model(
                 print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
                 set_precision(torch.float32)

+            group_size = q_kwargs["groupsize"]
+            bit_width = q_kwargs["bitwidth"]
+            has_weight_zeros = q_kwargs.get("has_weight_zeros", True)
+            q_kwargs["granularity"] = (
+                PerAxis() if group_size == -1 else PerGroup(group_size)
+            )
+            q_kwargs["weight_dtype"] = getattr(torch, f"int{bit_width}")
+            q_kwargs["mapping_type"] = (
+                MappingType.ASYMMETRIC
+                if has_weight_zeros
+                else MappingType.SYMMETRIC
+            )
+            q_kwargs["use_fallback"] = False
+            del q_kwargs["groupsize"]
+            del q_kwargs["bitwidth"]
+
         if quantizer == "linear:afpwx" and device != "mps":
             raise RuntimeError("linear:afpwx quantization can only run on mps device!")

@@ -188,7 +202,10 @@ def quantize_model(
         # Handle tokenizer for scenarios where the quantizer needs the tokenizer to tokenize sample inputs
         if "tokenizer" in named_params:
             q_kwargs["tokenizer"] = tokenizer
-        quant_handler = q(device=device, precision=precision, **q_kwargs)
+        if quantizer == "embedding:wx":
+            quant_handler = q(**q_kwargs)
+        else:
+            quant_handler = q(device=device, precision=precision, **q_kwargs)

         # quantize model
         model = quant_handler.quantize(model)
@@ -939,7 +956,7 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
-    "embedding:wx": IntxWeightEmbeddingQuantizer,
+    "embedding:wx": EmbeddingQuantizer,
     "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
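
For illustration, a sketch of the resulting `embedding:wx` dispatch: the keyword names (`weight_dtype`, `granularity`, `mapping_type`, `use_fallback`) are taken from the `q_kwargs` translation added above, and the sketch assumes `EmbeddingQuantizer` accepts exactly those arguments; the toy module and values are not from this patch:

```python
import torch
import torch.nn as nn

from torchao.experimental.quant_api import EmbeddingQuantizer
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import MappingType

# Toy module standing in for a model with a token embedding table.
toy = nn.Sequential(nn.Embedding(1000, 64))

# Mirrors quant_handler = q(**q_kwargs) followed by quant_handler.quantize(model)
# for the "embedding:wx" entry in quantizer_class_dict.
quantizer = EmbeddingQuantizer(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    mapping_type=MappingType.SYMMETRIC,
    use_fallback=False,
)
toy = quantizer.quantize(toy)
```
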
@@ -979,5 +996,19 @@ def quantized_model(self) -> nn.Module:
     except Exception as e:
         print("Unable to load torchao mps ops library.")

+    torchao_experimental_mps_op_lib_spec = importlib.util.spec_from_file_location(
+        "torchao_experimental_mps_op_lib",
+        f"{torchao_build_path}/src/ao/torchao/experimental/ops/mps/mps_op_lib.py",
+    )
+    torchao_experimental_mps_op_lib = importlib.util.module_from_spec(
+        torchao_experimental_mps_op_lib_spec
+    )
+    sys.modules["torchao_experimental_mps_op_lib"] = torchao_experimental_mps_op_lib
+    torchao_experimental_mps_op_lib_spec.loader.exec_module(
+        torchao_experimental_mps_op_lib
+    )
+    from torchao_experimental_mps_op_lib import *
+
+
 except Exception as e:
     print("Unable to import torchao experimental quant_api with error: ", e)