1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from abc import abstractmethod
15+
1516import torch
1617from neural_compressor .torch .algorithms .fp8_quant ._quant_common .quant_config import get_hqt_config
17- from .scale_method_factory import ScaleMethodFactory , QuantTensorName
18- from .scales_method import QuantTensorType
1918from ..common import ModuleConfig
19+ from .scale_method_factory import QuantTensorName , ScaleMethodFactory
20+ from .scales_method import QuantTensorType
2021from ..quant_dequant import DequantOutput , QuantDequant , QuantDequantNone , QuantInput , QuantDynamicInput
21- from ..fp_utils import scale_fcn
2222
2323
2424class BaseOpQuantizer :
@@ -58,16 +58,17 @@ def init_scales_from_module_config(self, module):
5858
5959 def calc_input_scales (self , num_of_inputs ):
6060 input_scales = []
61- for i in range (num_of_inputs ):
61+ for i in range (num_of_inputs ):
6262 input_measurement = self .measurement .inputs [i ] if self .measurement is not None else []
6363 input_scales .append (
64- self .inputs_scales_creators [i ].calc_scales (input_measurement , QuantTensorType .MEASUREMENTS ))
64+ self .inputs_scales_creators [i ].calc_scales (input_measurement , QuantTensorType .MEASUREMENTS )
65+ )
6566 return input_scales
6667
6768 def calc_output_scales (self ):
6869 output_measurement = self .measurement .outputs [0 ] if self .measurement is not None else []
6970 output_scales = self .output_scales_creators [0 ].calc_scales (output_measurement , QuantTensorType .MEASUREMENTS )
70- return (output_scales , )
71+ return (output_scales ,)
7172
7273 def init_input_config (self , scales_inv , lp_dtype , hp_dtype , scale_format , use_qdq , fake_quant ):
7374 if use_qdq or fake_quant :
@@ -104,22 +105,27 @@ def get_scales_module_config(self):
104105 weight_scales_in_ch = self .weight_ich_scale_calc .calc_scales (input_scales [0 ], QuantTensorType .CONST )
105106 rescaled_weight = torch .div (self .mod .weight , weight_scales_in_ch .reshape ([1 , - 1 ]))
106107 weights_scales_out_ch = self .weight_och_scale_calc .calc_scales (rescaled_weight , QuantTensorType .CONST )
107- params_config = {"weight" : weights_scales_out_ch } if (
108- self .weight_ich_scale_calc is None ) \
109- else {"weight" : {0 : weights_scales_out_ch , 1 : weight_scales_in_ch }}
110- output_scales = self .output_scales_creators [0 ].calc_scales (output_measurement , QuantTensorType .MEASUREMENTS ,
111- input0 = weights_scales_out_ch , input1 = input_scales [0 ])
108+ params_config = (
109+ {"weight" : weights_scales_out_ch }
110+ if (self .weight_ich_scale_calc is None )
111+ else {"weight" : {0 : weights_scales_out_ch , 1 : weight_scales_in_ch }}
112+ )
113+ output_scales = self .output_scales_creators [0 ].calc_scales (
114+ output_measurement , QuantTensorType .MEASUREMENTS , input0 = weights_scales_out_ch , input1 = input_scales [0 ]
115+ )
112116 return ModuleConfig (
113117 input_scales ,
114118 (output_scales ,),
115- params_config ,
119+ params_config ,
116120 )
117121
118122 def init_weight_config (self , scales , scales_inv , lp_dtype , hp_dtype , scale_format , use_qdq , fake_quant ):
119123 if use_qdq :
120124 # to ensure the weights to be loaded to the device in fp8
121- weight_config = [QuantInput (scales_inv , lp_dtype , hp_dtype , scale_format = scale_format , use_qdq = use_qdq ),
122- DequantOutput (scales , lp_dtype , hp_dtype , scale_format = scale_format , use_qdq = use_qdq )]
125+ weight_config = [
126+ QuantInput (scales_inv , lp_dtype , hp_dtype , scale_format = scale_format , use_qdq = use_qdq ),
127+ DequantOutput (scales , lp_dtype , hp_dtype , scale_format = scale_format , use_qdq = use_qdq ),
128+ ]
123129 elif fake_quant :
124130 weight_config = [QuantDequant (scales_inv , lp_dtype , hp_dtype , scale_format = scale_format )]
125131 else :
@@ -137,7 +143,14 @@ def scales_module_config_to_q_and_dq(self, module):
137143 self .init_scales_from_module_config (module )
138144 self .init_weights_from_module (module .params ["weight" ])
139145 scale_format , use_qdq , fake_quant , lp_dtype , hp_dtype = self .get_module_configuration ()
140- input_config = super ().init_input_config ((self .inputs_scales_creators [0 ].calc_invert_scales (),), lp_dtype , hp_dtype , scale_format , use_qdq , fake_quant )
146+ input_config = super ().init_input_config (
147+ (self .inputs_scales_creators [0 ].calc_invert_scales (),),
148+ lp_dtype ,
149+ hp_dtype ,
150+ scale_format ,
151+ use_qdq ,
152+ fake_quant ,
153+ )
141154 # outputs as bf16, and descaled in gemm under PatchedLinear, so no need to work here
142155 output_config = [QuantDequantNone (lp_dtype , hp_dtype , scale_format = scale_format )]
143156 weight_config = self .init_weight_config (
@@ -172,70 +185,82 @@ def get_scales_module_config(self):
172185 input_scales = self .calc_input_scales (num_of_inputs = 2 )
173186
174187 output_scales = input_scales [0 ] * input_scales [1 ]
175- return ModuleConfig (
176- input_scales ,
177- (output_scales ,),
178- {}
179- )
188+ return ModuleConfig (input_scales , (output_scales ,), {})
180189
181190 def scales_module_config_to_q_and_dq (self , module ):
182191 self .init_scales_from_module_config (module )
183192 scale_format , use_qdq , fake_quant , lp_dtype , hp_dtype = super ().get_module_configuration ()
184- input_config = super ().init_input_config ((self .inputs_scales_creators [0 ].calc_invert_scales (),
185- self .inputs_scales_creators [1 ].calc_invert_scales ()),
186- lp_dtype , hp_dtype , scale_format , use_qdq , fake_quant )
193+ input_config = super ().init_input_config (
194+ (self .inputs_scales_creators [0 ].calc_invert_scales (), self .inputs_scales_creators [1 ].calc_invert_scales ()),
195+ lp_dtype ,
196+ hp_dtype ,
197+ scale_format ,
198+ use_qdq ,
199+ fake_quant ,
200+ )
187201 # outputs as bf16, and descaled in gemm under PatchedLinear, so no need to work here
188202 output_config = [QuantDequantNone (lp_dtype , hp_dtype , scale_format = scale_format )]
189203 return ModuleConfig (input_config , output_config )
190204
205+
class SoftmaxOpQuantizer(BaseOpQuantizer):
    """Quantizer for softmax ops: quantizes the output only, inputs untouched."""

    def __init__(self, config, mod, measurement, params, module_type):
        super().__init__(config, mod, measurement, params, module_type)
        # Softmax only needs an output scale; no input scale creators are added.
        self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT))

    def get_scales_module_config(self):
        """Return a ModuleConfig with output scales only (no input scales)."""
        output_scales = self.calc_output_scales()
        return ModuleConfig((), output_scales)

    def scales_module_config_to_q_and_dq(self, module):
        """Build op configs: no input quantizers, a single output dequantizer."""
        self.init_scales_from_module_config(module)
        scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration()
        output_config = [
            DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format)
        ]
        return ModuleConfig([], output_config, {})
207224
225+
class FsdpaOpQuantizer(BaseOpQuantizer):
    """Quantizer for fused scaled-dot-product-attention (FSDPA) modules.

    Uses four input scale creators; the first three inputs are scaled from
    input measurements and the fourth from the second output measurement.
    """

    def __init__(self, config, mod, measurement, params, module_type):
        super().__init__(config, mod, measurement, params, module_type)
        self.num_of_inputs = 4
        self.inputs_scales_creators = [
            self.scales_method_factory.get_scale_method(QuantTensorName.INPUT) for i in range(self.num_of_inputs)
        ]
        self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT))

    def get_scales_module_config(self):
        """Compute scales for all four inputs and the output."""
        # 3 inputs calcs from input measurement
        input_scales = self.calc_input_scales(num_of_inputs=self.num_of_inputs - 1)
        # one input calcs from output measurement
        output1_measurement = self.measurement.outputs[1] if self.measurement is not None else []
        input_scales.append(
            self.inputs_scales_creators[self.num_of_inputs - 1].calc_scales(
                output1_measurement, QuantTensorType.MEASUREMENTS
            )
        )
        output_scales = self.calc_output_scales()
        return ModuleConfig(input_scales, output_scales, {})

    def scales_module_config_to_q_and_dq(self, module):
        """Build quantizers for every input and a dequantizer for the output."""
        self.init_scales_from_module_config(module)
        scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration()
        input_scales_inv = [
            self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators))
        ]
        input_config = super().init_input_config(
            input_scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant
        )
        output_config = [
            DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format)
        ]
        return ModuleConfig(input_config, output_config, {})
238262
263+
239264class KVCacheOpQuantizer (BaseOpQuantizer ):
240265
241266 def __init__ (self , config , mod , measurement , params , module_type ):
@@ -247,20 +272,20 @@ def get_scales_module_config(self):
247272 input_scales = self .calc_input_scales (num_of_inputs = 1 )
248273 self .output_scales_creators [0 ].scale = self .inputs_scales_creators [0 ].scale
249274 output_scales = [self .output_scales_creators [0 ].scale ]
250- return ModuleConfig (
251- input_scales ,
252- output_scales ,
253- {}
254- )
275+ return ModuleConfig (input_scales , output_scales , {})
255276
256277 def scales_module_config_to_q_and_dq (self , module ):
257278 self .init_scales_from_module_config (module )
258279 scale_format , use_qdq , fake_quant , lp_dtype , hp_dtype = super ().get_module_configuration ()
259- input_scales_inv = [self .inputs_scales_creators [i ].calc_invert_scales () for i in range (len (self .inputs_scales_creators ))]
280+ input_scales_inv = [
281+ self .inputs_scales_creators [i ].calc_invert_scales () for i in range (len (self .inputs_scales_creators ))
282+ ]
260283 input_config = super ().init_input_config (
261- input_scales_inv
262- , lp_dtype , hp_dtype , scale_format , use_qdq , fake_quant )
263- output_config = [DequantOutput (self .output_scales_creators [0 ].scale , lp_dtype , hp_dtype , scale_format = scale_format )]
284+ input_scales_inv , lp_dtype , hp_dtype , scale_format , use_qdq , fake_quant
285+ )
286+ output_config = [
287+ DequantOutput (self .output_scales_creators [0 ].scale , lp_dtype , hp_dtype , scale_format = scale_format )
288+ ]
264289 return ModuleConfig (input_config , output_config )
265290
266291
@@ -280,25 +305,26 @@ def get_scales_module_config(self):
280305 num_of_inputs = len (self .measurement .inputs ) if self .measurement is not None else 1
281306 num_of_experts = self .mod .num_experts if self .mod .num_experts is not None else 8
282307 input_scales = self .calc_input_scales (num_of_inputs = num_of_inputs )
283- for i in range (num_of_experts ):
284- output_measurement = self .measurement .outputs [i + 1 ] if self .measurement is not None else []
308+ for i in range (num_of_experts ):
309+ output_measurement = self .measurement .outputs [i + 1 ] if self .measurement is not None else []
285310 input_scales .append (
286- self .inputs_scales_creators [num_of_inputs + i ].calc_scales (output_measurement , QuantTensorType .MEASUREMENTS ))
287- output_scales = self .calc_output_scales ()
288- return ModuleConfig (
289- input_scales ,
290- output_scales ,
291- {}
311+ self .inputs_scales_creators [num_of_inputs + i ].calc_scales (
312+ output_measurement , QuantTensorType .MEASUREMENTS
313+ )
292314 )
315+ output_scales = self .calc_output_scales ()
316+ return ModuleConfig (input_scales , output_scales , {})
293317
294318 def scales_module_config_to_q_and_dq (self , module ):
295319 self .init_scales_from_module_config (module )
296320 scale_format , use_qdq , fake_quant , lp_dtype , hp_dtype = super ().get_module_configuration ()
297- input_scales_inv = [self .inputs_scales_creators [i ].calc_invert_scales () for i in range (len (self .inputs_scales_creators ))]
321+ input_scales_inv = [
322+ self .inputs_scales_creators [i ].calc_invert_scales () for i in range (len (self .inputs_scales_creators ))
323+ ]
298324 input_config = super ().init_input_config (
299- input_scales_inv
300- , lp_dtype , hp_dtype , scale_format , use_qdq , fake_quant )
301- output_config = [QuantDequantNone (lp_dtype , hp_dtype , scale_format = scale_format )]
325+ input_scales_inv , lp_dtype , hp_dtype , scale_format , use_qdq , fake_quant
326+ )
327+ output_config = [QuantDequantNone (lp_dtype , hp_dtype , scale_format = scale_format )]
302328 return ModuleConfig (input_config , output_config )
303329
304330
0 commit comments