
Commit 708cfb2

linoybu authored and XuehaoSun committed
[SW-218197] fix bug in Mixtral unitscale (#139)
Signed-off-by: Xin He <[email protected]>
1 parent de7c531 commit 708cfb2

File tree

1 file changed: +85 -59 lines changed

  • neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py

Lines changed: 85 additions & 59 deletions
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import abstractmethod
+
 import torch
 from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import get_hqt_config
-from .scale_method_factory import ScaleMethodFactory, QuantTensorName
-from .scales_method import QuantTensorType
 from ..common import ModuleConfig
+from .scale_method_factory import QuantTensorName, ScaleMethodFactory
+from .scales_method import QuantTensorType
 from ..quant_dequant import DequantOutput, QuantDequant, QuantDequantNone, QuantInput, QuantDynamicInput
-from ..fp_utils import scale_fcn


 class BaseOpQuantizer:
@@ -58,16 +58,17 @@ def init_scales_from_module_config(self, module):

     def calc_input_scales(self, num_of_inputs):
         input_scales = []
-        for i in range(num_of_inputs):
+        for i in range(num_of_inputs):
             input_measurement = self.measurement.inputs[i] if self.measurement is not None else []
             input_scales.append(
-                self.inputs_scales_creators[i].calc_scales(input_measurement, QuantTensorType.MEASUREMENTS))
+                self.inputs_scales_creators[i].calc_scales(input_measurement, QuantTensorType.MEASUREMENTS)
+            )
         return input_scales

     def calc_output_scales(self):
         output_measurement = self.measurement.outputs[0] if self.measurement is not None else []
         output_scales = self.output_scales_creators[0].calc_scales(output_measurement, QuantTensorType.MEASUREMENTS)
-        return (output_scales, )
+        return (output_scales,)

     def init_input_config(self, scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant):
         if use_qdq or fake_quant:
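Note on the pattern above: calc_input_scales asks one scale creator per input to turn that input's calibration measurement into a scale, and calc_output_scales returns its result as a one-element tuple. A minimal self-contained sketch of that flow, assuming a toy max-abs rule and the standard OCP E4M3 maximum of 448 (the real creators come from ScaleMethodFactory and are config-driven):

import torch

FP8_E4M3_MAX = 448.0  # assumption: standard OCP E4M3 max; the actual bound is config/device dependent

class ToyScaleCreator:
    # stand-in for the entries of self.inputs_scales_creators / self.output_scales_creators
    def calc_scales(self, measurement: torch.Tensor) -> torch.Tensor:
        return measurement.abs().max() / FP8_E4M3_MAX

creators = [ToyScaleCreator(), ToyScaleCreator()]
measurements = [torch.randn(16), torch.randn(16)]  # one calibration tensor per input
input_scales = [c.calc_scales(m) for c, m in zip(creators, measurements)]
output_scales = (ToyScaleCreator().calc_scales(torch.randn(16)),)  # one-element tuple, as in calc_output_scales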
@@ -104,22 +105,27 @@ def get_scales_module_config(self):
             weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST)
             rescaled_weight = torch.div(self.mod.weight, weight_scales_in_ch.reshape([1, -1]))
             weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(rescaled_weight, QuantTensorType.CONST)
-        params_config = {"weight": weights_scales_out_ch} if (
-                self.weight_ich_scale_calc is None) \
-            else {"weight": {0: weights_scales_out_ch, 1: weight_scales_in_ch}}
-        output_scales = self.output_scales_creators[0].calc_scales(output_measurement, QuantTensorType.MEASUREMENTS,
-                                                                   input0=weights_scales_out_ch, input1=input_scales[0])
+        params_config = (
+            {"weight": weights_scales_out_ch}
+            if (self.weight_ich_scale_calc is None)
+            else {"weight": {0: weights_scales_out_ch, 1: weight_scales_in_ch}}
+        )
+        output_scales = self.output_scales_creators[0].calc_scales(
+            output_measurement, QuantTensorType.MEASUREMENTS, input0=weights_scales_out_ch, input1=input_scales[0]
+        )
         return ModuleConfig(
             input_scales,
             (output_scales,),
-                params_config,
+            params_config,
         )

     def init_weight_config(self, scales, scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant):
         if use_qdq:
             # to ensure the weights to be loaded to the device in fp8
-            weight_config = [QuantInput(scales_inv, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq),
-                             DequantOutput(scales, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq)]
+            weight_config = [
+                QuantInput(scales_inv, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq),
+                DequantOutput(scales, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq),
+            ]
         elif fake_quant:
             weight_config = [QuantDequant(scales_inv, lp_dtype, hp_dtype, scale_format=scale_format)]
         else:
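Note: the weight handling above is a two-stage per-channel scheme: divide the weight by per-input-channel scales, then derive per-output-channel scales from the rescaled tensor, which is why params_config stores {0: out_ch, 1: in_ch} when an input-channel calculator exists. A standalone sketch under an assumed [out_channels, in_channels] layout and a toy max-abs rule:

import torch

weight = torch.randn(4, 8)                 # assumed layout: [out_channels, in_channels]
weight_scales_in_ch = torch.rand(8) + 0.5  # toy per-input-channel scales

# divide each input channel by its scale, as in the hunk above
rescaled_weight = torch.div(weight, weight_scales_in_ch.reshape([1, -1]))
# derive per-output-channel scales from the rescaled weight (toy max-abs rule)
weights_scales_out_ch = rescaled_weight.abs().amax(dim=1)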
@@ -137,7 +143,14 @@ def scales_module_config_to_q_and_dq(self, module):
         self.init_scales_from_module_config(module)
         self.init_weights_from_module(module.params["weight"])
         scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = self.get_module_configuration()
-        input_config = super().init_input_config((self.inputs_scales_creators[0].calc_invert_scales(),), lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant)
+        input_config = super().init_input_config(
+            (self.inputs_scales_creators[0].calc_invert_scales(),),
+            lp_dtype,
+            hp_dtype,
+            scale_format,
+            use_qdq,
+            fake_quant,
+        )
         # outputs as bf16, and descaled in gemm under PatchedLinear, so no need to work here
         output_config = [QuantDequantNone(lp_dtype, hp_dtype, scale_format=scale_format)]
         weight_config = self.init_weight_config(
@@ -172,70 +185,82 @@ def get_scales_module_config(self):
         input_scales = self.calc_input_scales(num_of_inputs=2)

         output_scales = input_scales[0] * input_scales[1]
-        return ModuleConfig(
-            input_scales,
-            (output_scales,),
-            {}
-        )
+        return ModuleConfig(input_scales, (output_scales,), {})

     def scales_module_config_to_q_and_dq(self, module):
         self.init_scales_from_module_config(module)
         scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration()
-        input_config = super().init_input_config((self.inputs_scales_creators[0].calc_invert_scales(),
-                                                  self.inputs_scales_creators[1].calc_invert_scales()),
-                                                 lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant)
+        input_config = super().init_input_config(
+            (self.inputs_scales_creators[0].calc_invert_scales(), self.inputs_scales_creators[1].calc_invert_scales()),
+            lp_dtype,
+            hp_dtype,
+            scale_format,
+            use_qdq,
+            fake_quant,
+        )
         # outputs as bf16, and descaled in gemm under PatchedLinear, so no need to work here
         output_config = [QuantDequantNone(lp_dtype, hp_dtype, scale_format=scale_format)]
         return ModuleConfig(input_config, output_config)

+
 class SoftmaxOpQuantizer(BaseOpQuantizer):

     def __init__(self, config, mod, measurement, params, module_type):
-        super().__init__( config, mod, measurement, params, module_type)
+        super().__init__(config, mod, measurement, params, module_type)
         self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT))

     def get_scales_module_config(self):
         output_scales = self.calc_output_scales()

-        return ModuleConfig((),output_scales)
+        return ModuleConfig((), output_scales)

     def scales_module_config_to_q_and_dq(self, module):
         self.init_scales_from_module_config(module)
         scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration()
-        output_config = [DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format)]
+        output_config = [
+            DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format)
+        ]
         return ModuleConfig([], output_config, {})

+
 class FsdpaOpQuantizer(BaseOpQuantizer):

     def __init__(self, config, mod, measurement, params, module_type):
         super().__init__(config, mod, measurement, params, module_type)
         self.num_of_inputs = 4
-        self.inputs_scales_creators = [self.scales_method_factory.get_scale_method(QuantTensorName.INPUT)
-                                       for i in range(self.num_of_inputs)]
+        self.inputs_scales_creators = [
+            self.scales_method_factory.get_scale_method(QuantTensorName.INPUT) for i in range(self.num_of_inputs)
+        ]
         self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT))

     def get_scales_module_config(self):
         # 3 inputs calcs from input measurement
         input_scales = self.calc_input_scales(num_of_inputs=self.num_of_inputs - 1)
         # one input calcs from output measurement
         output1_measurement = self.measurement.outputs[1] if self.measurement is not None else []
-        input_scales.append(self.inputs_scales_creators[self.num_of_inputs-1].calc_scales(output1_measurement, QuantTensorType.MEASUREMENTS))
-        output_scales = self.calc_output_scales()
-        return ModuleConfig(
-            input_scales,
-            output_scales,
-            {}
+        input_scales.append(
+            self.inputs_scales_creators[self.num_of_inputs - 1].calc_scales(
+                output1_measurement, QuantTensorType.MEASUREMENTS
+            )
         )
+        output_scales = self.calc_output_scales()
+        return ModuleConfig(input_scales, output_scales, {})
+
     def scales_module_config_to_q_and_dq(self, module):
         self.init_scales_from_module_config(module)
         scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration()
-        input_scales_inv = [self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators))]
+        input_scales_inv = [
+            self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators))
+        ]
         input_config = super().init_input_config(
-            input_scales_inv
-            , lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant)
-        output_config = [DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format)]
+            input_scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant
+        )
+        output_config = [
+            DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format)
+        ]
         return ModuleConfig(input_config, output_config, {})

+
 class KVCacheOpQuantizer(BaseOpQuantizer):

     def __init__(self, config, mod, measurement, params, module_type):
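Note: per the comments inside FsdpaOpQuantizer above, three of its four input scales come from input measurements and the fourth comes from the second output measurement (outputs[1]). A toy sketch of that layout, treating measurements as plain lists and max() as the scale rule (both assumptions):

num_of_inputs = 4
input_measurements = [[1.5], [2.0], [0.5]]  # the three directly measured inputs
output_measurements = [[4.0], [0.9]]        # [0]: final output, [1]: intermediate result

input_scales = [max(m) for m in input_measurements]  # first 3 scales
input_scales.append(max(output_measurements[1]))     # 4th input scale comes from outputs[1]
output_scales = (max(output_measurements[0]),)
assert len(input_scales) == num_of_inputs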
@@ -247,20 +272,20 @@ def get_scales_module_config(self):
         input_scales = self.calc_input_scales(num_of_inputs=1)
         self.output_scales_creators[0].scale = self.inputs_scales_creators[0].scale
         output_scales = [self.output_scales_creators[0].scale]
-        return ModuleConfig(
-            input_scales,
-            output_scales,
-            {}
-        )
+        return ModuleConfig(input_scales, output_scales, {})

     def scales_module_config_to_q_and_dq(self, module):
         self.init_scales_from_module_config(module)
         scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration()
-        input_scales_inv = [self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators))]
+        input_scales_inv = [
+            self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators))
+        ]
         input_config = super().init_input_config(
-            input_scales_inv
-            , lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant)
-        output_config = [DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format)]
+            input_scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant
+        )
+        output_config = [
+            DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format)
+        ]
         return ModuleConfig(input_config, output_config)

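Note: KVCacheOpQuantizer above ties its output scale to its input scale (output_scales_creators[0].scale = inputs_scales_creators[0].scale), so whatever is quantized into the cache is dequantized with the identical scale. A minimal sketch of why the tying is lossless up to rounding, with round() standing in for the FP8 cast:

import torch

input_scale = torch.tensor(0.02)
output_scale = input_scale            # tied, as in the hunk above

x = torch.randn(8)
q = torch.round(x / input_scale)      # quantize on the way into the cache
x_hat = q * output_scale              # dequantize on the way out with the same scale
assert torch.allclose(x, x_hat, atol=0.011)  # error bounded by half a quantization step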
@@ -280,25 +305,26 @@ def get_scales_module_config(self):
         num_of_inputs = len(self.measurement.inputs) if self.measurement is not None else 1
         num_of_experts = self.mod.num_experts if self.mod.num_experts is not None else 8
         input_scales = self.calc_input_scales(num_of_inputs=num_of_inputs)
-        for i in range(num_of_experts):
-            output_measurement = self.measurement.outputs[i+1] if self.measurement is not None else []
+        for i in range(num_of_experts):
+            output_measurement = self.measurement.outputs[i + 1] if self.measurement is not None else []
             input_scales.append(
-                self.inputs_scales_creators[num_of_inputs + i].calc_scales(output_measurement, QuantTensorType.MEASUREMENTS))
-        output_scales = self.calc_output_scales()
-        return ModuleConfig(
-            input_scales,
-            output_scales,
-            {}
+                self.inputs_scales_creators[num_of_inputs + i].calc_scales(
+                    output_measurement, QuantTensorType.MEASUREMENTS
+                )
             )
+        output_scales = self.calc_output_scales()
+        return ModuleConfig(input_scales, output_scales, {})

     def scales_module_config_to_q_and_dq(self, module):
         self.init_scales_from_module_config(module)
         scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration()
-        input_scales_inv = [self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators))]
+        input_scales_inv = [
+            self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators))
+        ]
         input_config = super().init_input_config(
-            input_scales_inv
-            , lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant)
-        output_config = [QuantDequantNone(lp_dtype, hp_dtype, scale_format=scale_format)]
+            input_scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant
+        )
+        output_config = [QuantDequantNone(lp_dtype, hp_dtype, scale_format=scale_format)]
         return ModuleConfig(input_config, output_config)

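Note: in the MoE hunk above, the scale creators are laid out as num_of_inputs input entries followed by one entry per expert, and expert i reads output measurement i + 1 (outputs[0] is the final output). A toy sketch of that indexing, again with list measurements and max() as stand-ins:

num_of_inputs = 3
num_of_experts = 8
inputs = [[1.0], [2.0], [0.5]]
outputs = [[9.0]] + [[float(i + 1)] for i in range(num_of_experts)]  # [final, expert 0..7]

input_scales = [max(m) for m in inputs]
for i in range(num_of_experts):
    output_measurement = outputs[i + 1]            # expert i's intermediate output
    input_scales.append(max(output_measurement))   # lands at index num_of_inputs + i
assert len(input_scales) == num_of_inputs + num_of_experts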