
Commit 20f931c

Tiefen-boop authored and XuehaoSun committed
[SW-224874] Implement support for hp/lp dtypes in KV-cache QDQ (#222)
1 parent 0347186 commit 20f931c

2 files changed: +32 -4 lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py

Lines changed: 14 additions & 2 deletions
@@ -333,6 +333,16 @@ def __init__(self, config, mod, measurement, params, module_type):
         self.inputs_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.INPUT))
         self.output_scales_creators.append(self.inputs_scales_creators[0])
 
+    # TODO: Remove after implementing lp_dtype in OHF.
+    def init_input_config(self, scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant):
+        input_config = super().init_input_config(scales_inv, lp_dtype, hp_dtype, scale_format, False, fake_quant)
+        if use_qdq:
+            input_config.extend([
+                QuantDequant(s_inv, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq)
+                for s_inv in scales_inv
+            ])
+        return input_config
+
     def get_scales_module_config(self):
         input_scales = self.calc_input_scales(num_of_inputs=1)
         self.output_scales_creators[0].scale = self.inputs_scales_creators[0].scale
@@ -345,11 +355,13 @@ def scales_module_config_to_q_and_dq(self, module):
         input_scales_inv = [
             self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators))
         ]
-        input_config = super().init_input_config(
+        # TODO: After implementing lp_dtype in OHF, this can call:
+        # `super().init_input_config(scales_inv, lp_dtype, hp_dtype, scale_format, False, fake_quant)`
+        input_config = self.init_input_config(
             input_scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant
         )
         output_config = [
-            DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format)
+            DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=False)
         ]
         return ModuleConfig(input_config, output_config)
 
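For context, the QuantDequant entries appended by the new init_input_config override are fake-quant ops: they round the cache input through the low-precision dtype and hand it back in the high-precision dtype, so a bf16 cache sees fp8-rounded values. The following is a minimal sketch of that round trip, not the patch's QuantDequant class; the qdq helper, the unit inverse scale, and the dtype choices are illustrative assumptions, and it needs a PyTorch build with fp8 dtypes.

import torch

def qdq(x_hp, scale_inv, lp_dtype=torch.float8_e4m3fn, hp_dtype=torch.bfloat16):
    # quantize: apply the inverse scale and round into the low-precision dtype
    x_lp = (x_hp * scale_inv).to(lp_dtype)
    # dequantize: cast back to the high-precision dtype and undo the scaling
    return x_lp.to(hp_dtype) / scale_inv

x = torch.randn(4, 8, dtype=torch.bfloat16)
x_qdq = qdq(x, scale_inv=torch.tensor(1.0, dtype=torch.bfloat16))  # fp8-rounded, still bf16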

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 18 additions & 2 deletions
@@ -869,6 +869,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         self.quant_input = self._mod_extra_config.inputs[0]
         self.dequant_output = self._mod_extra_config.outputs[0]
         if self.use_qdq:
+            self.qdq_input = self._mod_extra_config.inputs[1]
             self.update = self.update_qdq
             mod.update = self.update_qdq
         else:
@@ -885,8 +886,23 @@ def allocate(self, inp_seq_len, dtype, device, shape):
 
     # overwrite update function of original module to force quant and dequant of cache input and output
     def update_qdq(self, prev, cur, dim, idx, inp_seq_len):
-        qinput = self.quant_input(cur)
-        output = self.org_update(prev, qinput, dim, idx, inp_seq_len)
+        """
+        To optimize index_copy so it runs in fp8 instead of bf16, the tensors must
+        already be in fp8 before index_copy is called. `prev` and `cur` also need to
+        share a dtype, and quantizing both from bf16 gains nothing, so the best option
+        is to initialize `prev` as an fp8 tensor from the start. Since `prev` is
+        initialized in OHF (where this is not implemented yet), both paths must be
+        supported for now; once OHF implements it, the bf16 `prev` branch (the `else`
+        below) can be removed.
+        """
+        if prev.dtype == torch.float8_e4m3fn:
+            qcurr = self.quant_input(cur)
+            qoutput = self.org_update(prev, qcurr, dim, idx, inp_seq_len)
+            output = self.dequant_output(qoutput)
+        # TODO: remove the `else` part once lp_dtype is implemented in OHF
+        else:
+            curr = self.qdq_input(cur)
+            output = self.org_update(prev, curr, dim, idx, inp_seq_len)
         return output
 
     # overwrite update function of original module to force quant and dequant of cache input and output
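As the docstring explains, the fp8 path quantizes the incoming slice once and lets the cache update run entirely in fp8, while the bf16 fallback only fake-quantizes the input. The sketch below mirrors the two branches under stated assumptions: the update_sketch helper and plain slice assignment stand in for the module's org_update, a unit scale is assumed, and it requires a PyTorch build with fp8 dtypes.

import torch

LP, HP = torch.float8_e4m3fn, torch.bfloat16

def update_sketch(prev, cur, pos):
    # mirrors the two branches of update_qdq (illustrative, not the module's API)
    if prev.dtype == LP:
        # fp8 `prev`: quantize `cur` once, write it in fp8, dequantize the result for the consumer
        prev[pos:pos + cur.shape[0]] = cur.to(LP)
        return prev.to(HP)
    else:
        # bf16 `prev`: quant-dequant `cur` so values match fp8 rounding; the update stays in bf16
        prev[pos:pos + cur.shape[0]] = cur.to(LP).to(HP)
        return prev

cache_fp8 = torch.zeros(16, 8, dtype=HP).to(LP)   # fp8 cache allocated up front
cache_bf16 = torch.zeros(16, 8, dtype=HP)         # current bf16 fallback
new_kv = torch.randn(2, 8, dtype=HP)
out_fp8 = update_sketch(cache_fp8, new_kv, pos=3)
out_bf16 = update_sketch(cache_bf16, new_kv, pos=3)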
