Commit 744bf4c

mengniwang95 (Mengni Wang) and a co-author authored
[SW-233731] Use torchao op for CPU QDQ and abstract QDQ calling (#264)
Abstract QDQ calling.
Fix QDQ model print issue.
Use torchao op for CPU QDQ (HPU doesn't have this accuracy issue).

---------

Signed-off-by: Mengni Wang <[email protected]>
Co-authored-by: Mengni Wang <[email protected]>
1 parent 8bb9758 commit 744bf4c

File tree

10 files changed (+245, -54 lines)
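
Taken together, the change replaces direct imports of the fp_utils QDQ helpers with a per-device wrapper lookup keyed by an op type. Below is a minimal, self-contained sketch of that calling pattern; the names mirror the diff, but this standalone registry is illustrative only and is not the project's actual factory.

from enum import Enum, auto


class OP_TYPE(Enum):
    # Members added by this commit, alongside the existing CAST_* types.
    QUANT = auto()
    DEQUANT = auto()
    QUANT_PC = auto()
    DEQUANT_PC = auto()
    CAST_TO_FP8 = auto()
    CAST_FROM_FP8 = auto()


# Toy stand-in for QuantizedFuncWrapperFactory: each backend (hpu / xpu / cpu)
# registers its own OP_TYPE -> wrapper-class mapping at init time, and callers
# only ever ask for an op type plus a scale format.
_WRAPPER_CLASSES = {}


def initialize(op_type_to_cls):
    _WRAPPER_CLASSES.update(op_type_to_cls)


def get_quantized_func_wrapper(op_type, scale_format):
    # The real factory caches wrapper instances; this sketch constructs one per call.
    return _WRAPPER_CLASSES[op_type](scale_format)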


neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py

Lines changed: 22 additions & 29 deletions

@@ -25,13 +25,6 @@
from .._core.scale_handler import add_scale_registry, get_scale_dtype
from .._quant_common.quant_config import ScaleFormat
from .common import QuantTensorType
-from .fp_utils import (
-    quantize_per_tensor_to_fp8,
-    dequantize_per_tensor_from_fp8,
-    quantize_per_channel_to_fp8,
-    dequantize_per_channel_from_fp8,
-    invert_scale,
-)
from .scale_handler import create_scale_tensor


@@ -93,17 +86,13 @@ def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs):
        self.register_scale("scale_inv", scale_inv, self.scale_format)
        if self.use_qdq:
            self.register_scale("scale", 1 / self.scale_inv, self.scale_format)
-            self.quantize_op = (
-                quantize_per_channel_to_fp8
-                if self.scale_format == ScaleFormat.CONST and self.scale.numel() > 1
-                else quantize_per_tensor_to_fp8
-            )
-
+            op_type = OP_TYPE.QUANT_PC if self.scale_format == ScaleFormat.CONST and self.scale.numel() > 1 else OP_TYPE.QUANT
        else:
-            self.cast_to_op = get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, self.scale_format)
+            op_type = OP_TYPE.CAST_TO_FP8
+        self.quantize_op = get_quantized_func_wrapper(op_type, self.scale_format)

    def forward(self, x):
-        return self.cast_to_op(x, self.scale_inv, False, False, self.lp_dtype)
+        return self.quantize_op(x, self.scale_inv, False, False, self.lp_dtype)

    def forward_qdq(self, x):
        return self.quantize_op(

@@ -153,16 +142,13 @@ def __init__(self, scale, lp_dtype, hp_dtype, *args, **kwargs):
        super(DequantOutput, self).__init__(lp_dtype, hp_dtype, *args, **kwargs)
        self.register_scale("scale", scale, self.scale_format)
        if self.use_qdq:
-            self.dequantize_op = (
-                dequantize_per_channel_from_fp8
-                if self.scale_format == ScaleFormat.CONST and self.scale.numel() > 1
-                else dequantize_per_tensor_from_fp8
-            )
+            op_type = OP_TYPE.DEQUANT_PC if self.scale_format == ScaleFormat.CONST and self.scale.numel() > 1 else OP_TYPE.DEQUANT
        else:
-            self.cast_from_op = get_quantized_func_wrapper(OP_TYPE.CAST_FROM_FP8, self.scale_format)
+            op_type = OP_TYPE.CAST_FROM_FP8
+        self.dequantize_op = get_quantized_func_wrapper(op_type, self.scale_format)

    def forward(self, x):
-        return self.cast_from_op(x, self.scale, self.hp_dtype)
+        return self.dequantize_op(x, self.scale, self.hp_dtype)

    def forward_qdq(self, x):
        return self.dequantize_op(

@@ -187,30 +173,37 @@ def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs):
        super(QuantDequant, self).__init__(lp_dtype, hp_dtype, *args, **kwargs)
        self.register_scale("scale_inv", scale_inv, self.scale_format)
        self.register_scale("scale", 1 / scale_inv, self.scale_format)
-        if not self.use_qdq:
-            self.cast_to_op = get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, self.scale_format)
-            self.cast_from_op = get_quantized_func_wrapper(OP_TYPE.CAST_FROM_FP8, self.scale_format)
+        self.quantize_op = (
+            get_quantized_func_wrapper(OP_TYPE.QUANT, self.scale_format)
+            if self.use_qdq
+            else get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, self.scale_format)
+        )
+        self.dequantize_op = (
+            get_quantized_func_wrapper(OP_TYPE.DEQUANT, self.scale_format)
+            if self.use_qdq
+            else get_quantized_func_wrapper(OP_TYPE.CAST_FROM_FP8, self.scale_format)
+        )

    def forward(self, x, *args, **kwargs):
-        y = self.cast_to_op(x, self.scale_inv, False, False, self.lp_dtype)
+        y = self.quantize_op(x, self.scale_inv, False, False, self.lp_dtype)
        # mark_step is needed so fuser won't remove 2 consecutive casts.
        # will be removed once SW-196431 is implemented
        # Call cur_accelerator.synchronize() which will call mark_step() as well
        cur_accelerator.synchronize()
-        z = self.cast_from_op(y, self.scale, self.hp_dtype)
+        z = self.dequantize_op(y, self.scale, self.hp_dtype)
        cur_accelerator.synchronize()
        return z

    def forward_qdq(self, x, *args, **kwargs):
-        y = quantize_per_tensor_to_fp8(
+        y = self.quantize_op(
            x,
            scale=self.scale,
            zero_point=self.zero_point,
            quant_min=self.quant_min,
            quant_max=self.quant_max,
            dtype=self.lp_dtype,
        )
-        z = dequantize_per_tensor_from_fp8(
+        z = self.dequantize_op(
            y,
            scale=self.scale,
            zero_point=self.zero_point,
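
The selection logic in QuantInput.__init__ above can be condensed as follows. This is an illustrative standalone restatement, not the module itself; OP_TYPE and ScaleFormat are passed in as parameters to keep the sketch self-contained.

def pick_quant_op_type(use_qdq, scale_format, scale, OP_TYPE, ScaleFormat):
    """Return the op type QuantInput resolves through get_quantized_func_wrapper."""
    if use_qdq:
        # A CONST-format scale with more than one element means per-channel QDQ.
        per_channel = scale_format == ScaleFormat.CONST and scale.numel() > 1
        return OP_TYPE.QUANT_PC if per_channel else OP_TYPE.QUANT
    # Without QDQ the original FP8 cast path is kept.
    return OP_TYPE.CAST_TO_FP8
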
Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/cpu/cpu_quantized_func_wrapper.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..quantized_func_wrapper import QuantizedFuncWrapperBase, OP_TYPE, QuantizedFuncWrapperFactory
+
+import torch
+import torchao
+
+from abc import ABCMeta
+
+
+class QuantizedCPUFuncWrapperBase(QuantizedFuncWrapperBase, metaclass=ABCMeta):
+    """
+    Placeholder for base class for CPU quantized func wrapper.
+    """
+    def __init__(self, scale_format, is_dynamic=False):
+        self._quantized_func_ = self.get_default_quantized_func()
+
+
+class QuantizedCPUQuant(QuantizedCPUFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.torchao.quantize_affine_float8
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn):
+        return self._quantized_func_(tensor=input, scale=scale, float8_dtype=dtype)
+
+
+class QuantizedCPUQuantPC(QuantizedCPUFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.torchao.quantize_affine_float8
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn):
+        return self._quantized_func_(tensor=input, scale=scale.view((-1, 1)), float8_dtype=dtype)
+
+
+class QuantizedCPUDeQuant(QuantizedCPUFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.torchao.dequantize_affine_float8
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn, out_dtype=torch.bfloat16):
+        return self._quantized_func_(tensor=input, scale=scale, output_dtype=out_dtype)
+
+
+class QuantizedCPUDeQuantPC(QuantizedCPUFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.torchao.dequantize_affine_float8
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn, out_dtype=torch.bfloat16):
+        return self._quantized_func_(tensor=input, scale=scale.view((1, -1)), output_dtype=out_dtype)
+
+
+_OP_TYPE_CPU_QUANTIZED_WRAPPER_CLASSES = {
+    OP_TYPE.QUANT: QuantizedCPUQuant,
+    OP_TYPE.DEQUANT: QuantizedCPUDeQuant,
+    OP_TYPE.QUANT_PC: QuantizedCPUQuantPC,
+    OP_TYPE.DEQUANT_PC: QuantizedCPUDeQuantPC,
+}
+
+
+def init_cpu_quantized_func_wrapper_factory():
+    QuantizedFuncWrapperFactory.initialize(_OP_TYPE_CPU_QUANTIZED_WRAPPER_CLASSES)
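
The CPU wrappers above delegate to torchao's FP8 affine ops. A minimal round-trip sketch follows, assuming a torchao build that registers torch.ops.torchao.quantize_affine_float8 / dequantize_affine_float8 on import; the tensor shape and scale value are made up for illustration.

import torch
import torchao  # assumed to register the torch.ops.torchao.* custom ops on import

x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = torch.tensor(0.05)  # per-tensor scale; the *_PC wrappers reshape a per-channel scale instead

# Quantize to FP8 on CPU and dequantize back, mirroring QuantizedCPUQuant / QuantizedCPUDeQuant.
q = torch.ops.torchao.quantize_affine_float8(tensor=x, scale=scale, float8_dtype=torch.float8_e4m3fn)
dq = torch.ops.torchao.dequantize_affine_float8(tensor=q, scale=scale, output_dtype=torch.bfloat16)
print(q.dtype, dq.dtype)  # torch.float8_e4m3fn torch.bfloat16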

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/hpu/hpu_quantized_func_wrapper.py

Lines changed: 52 additions & 0 deletions

@@ -174,6 +174,54 @@ def get_dynamic_quantized_func(self):
        return torch.ops.hpu.mixture_of_experts.fp8_fused_weights_dynamic


+class QuantizedHPUQuant(QuantizedHpuFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.quantized_decomposed.quantize_per_tensor
+
+    def get_scalar_quantized_func(self):
+        return self.get_default_quantized_func()
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn):
+        return self._quantized_func_(input, scale, zero_point, quant_min, quant_max, dtype=dtype)
+
+
+class QuantizedHPUDeQuant(QuantizedHpuFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.quantized_decomposed.dequantize_per_tensor
+
+    def get_scalar_quantized_func(self):
+        return self.get_default_quantized_func()
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn, out_dtype=torch.bfloat16):
+        return self._quantized_func_(input, scale, zero_point, quant_min, quant_max, dtype=dtype, out_dtype=out_dtype)
+
+
+class QuantizedHPUQuantPC(QuantizedHpuFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.quantized_decomposed.quantize_per_channel
+
+    def get_scalar_quantized_func(self):
+        return self.get_default_quantized_func()
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn):
+        return self._quantized_func_(input, scale, zero_point, axis, quant_min, quant_max, dtype=dtype)
+
+
+class QuantizedHPUDeQuantPC(QuantizedHpuFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.quantized_decomposed.dequantize_per_channel
+
+    def get_scalar_quantized_func(self):
+        return self.get_default_quantized_func()
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn, out_dtype=torch.bfloat16):
+        return self._quantized_func_(input, scale, zero_point, axis, quant_min, quant_max, dtype=dtype, out_dtype=out_dtype)
+
+
_OP_TYPE_HPU_QUANTIZED_WRAPPER_CLASSES = {OP_TYPE.LINEAR_GEMM : QuantizedHpuMatmul,
                                          OP_TYPE.MATMUL_GEMM: QuantizedHpuMatmul,
                                          OP_TYPE.SOFTMAX : QuantizedHpuSoftmax,

@@ -183,6 +231,10 @@ def get_dynamic_quantized_func(self):
                                          OP_TYPE.CAST_FROM_FP8 : QuantizedHPUCastFromFP8,
                                          OP_TYPE.DYNAMIC_MOE: QuantizedHpuDynamicMoe,
                                          OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS: QuantizedHpuDynamicMoeFusedWeights,
+                                          OP_TYPE.QUANT: QuantizedHPUQuant,
+                                          OP_TYPE.DEQUANT: QuantizedHPUDeQuant,
+                                          OP_TYPE.QUANT_PC: QuantizedHPUQuantPC,
+                                          OP_TYPE.DEQUANT_PC: QuantizedHPUDeQuantPC,
                                          }

def init_hpu_quantized_func_wrapper_factory():

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/quantized_func_wrapper.py

Lines changed: 4 additions & 0 deletions

@@ -27,6 +27,10 @@ class OP_TYPE(Enum):
    CAST_FROM_FP8 = auto()
    DYNAMIC_MOE = auto()
    DYNAMIC_MOE_FUSED_WEIGHTS = auto()
+    QUANT = auto()
+    DEQUANT = auto()
+    QUANT_PC = auto()
+    DEQUANT_PC = auto()


class QuantizedFuncWrapperBase(ABC):

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/quantized_func_wrapper_api.py

Lines changed: 2 additions & 2 deletions

@@ -38,8 +38,8 @@ def init_quantized_func_wrapper_factory():
        from .xpu.xpu_quantized_func_wrapper import init_xpu_quantized_func_wrapper_factory
        init_xpu_quantized_func_wrapper_factory()
    elif device_name == "cpu":
-        # only support QDQ now
-        pass
+        from .cpu.cpu_quantized_func_wrapper import init_cpu_quantized_func_wrapper_factory
+        init_cpu_quantized_func_wrapper_factory()
    else:
        raise ValueError("Unknown device type - {}".format(device_name))
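
With the CPU branch filled in, a CPU-only run now registers the torchao-backed wrappers when the factory is initialized. A hedged usage sketch; the import path follows the file names shown in this commit.

from neural_compressor.torch.algorithms.fp8_quant._core.quantized_func_wrappers.quantized_func_wrapper_api import (
    init_quantized_func_wrapper_factory,
)

# Dispatches on the detected device; "cpu" now maps to init_cpu_quantized_func_wrapper_factory
# instead of the previous no-op pass.
init_quantized_func_wrapper_factory()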

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/xpu/xpu_quantized_func_wrapper.py

Lines changed: 52 additions & 1 deletion

@@ -63,11 +63,62 @@ def get_default_quantized_func(self):
        return torch.ops.torch_ipex.cast_from_fp8


+class QuantizedXPUQuant(QuantizedXpuFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.quantized_decomposed.quantize_per_tensor
+
+    def get_scalar_quantized_func(self):
+        return self.get_default_quantized_func()
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn):
+        return self._quantized_func_(input, scale, zero_point, quant_min, quant_max, dtype=dtype)
+
+
+class QuantizedXPUDeQuant(QuantizedXpuFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.quantized_decomposed.dequantize_per_tensor
+
+    def get_scalar_quantized_func(self):
+        return self.get_default_quantized_func()
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn, out_dtype=torch.bfloat16):
+        return self._quantized_func_(input, scale, zero_point, quant_min, quant_max, dtype=dtype, out_dtype=out_dtype)
+
+
+class QuantizedXPUQuantPC(QuantizedXpuFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.quantized_decomposed.quantize_per_channel
+
+    def get_scalar_quantized_func(self):
+        return self.get_default_quantized_func()
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn):
+        return self._quantized_func_(input, scale, zero_point, axis, quant_min, quant_max, dtype=dtype)
+
+
+class QuantizedXPUDeQuantPC(QuantizedXpuFuncWrapperBase):
+
+    def get_default_quantized_func(self):
+        return torch.ops.quantized_decomposed.dequantize_per_channel
+
+    def get_scalar_quantized_func(self):
+        return self.get_default_quantized_func()
+
+    def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn, out_dtype=torch.bfloat16):
+        return self._quantized_func_(input, scale, zero_point, axis, quant_min, quant_max, dtype=dtype, out_dtype=out_dtype)
+
_OP_TYPE_XPU_QUANTIZED_WRAPPER_CLASSES = {
    OP_TYPE.LINEAR_GEMM : QuantizedXPUMatmul,
    OP_TYPE.MATMUL_GEMM : QuantizedXPUMatmul,
    OP_TYPE.CAST_TO_FP8 : QuantizedXPUCastToFP8Base,
-    OP_TYPE.CAST_FROM_FP8 : QuantizedXPUCastFromFP8Base
+    OP_TYPE.CAST_FROM_FP8 : QuantizedXPUCastFromFP8Base,
+    OP_TYPE.QUANT: QuantizedXPUQuant,
+    OP_TYPE.DEQUANT: QuantizedXPUDeQuant,
+    OP_TYPE.QUANT_PC: QuantizedXPUQuantPC,
+    OP_TYPE.DEQUANT_PC: QuantizedXPUDeQuantPC,
}

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 5 additions & 5 deletions

@@ -138,7 +138,7 @@ def extra_repr(self) -> str:
        return extra_representation(
            self.extra_repr_org(),
            self.class_name_org,
-            get_current_repr(self, "scale_input", "scale_other"),
+            get_current_repr(self, "scale_input", "scale_other") if not self.use_qdq else "",
        )


@@ -227,7 +227,7 @@ def extra_repr(self) -> str:
        return extra_representation(
            self.extra_repr_org(),
            self.class_name_org,
-            get_current_repr(self, "scale_input", "scale_weight"),
+            get_current_repr(self, "scale_input", "scale_weight") if not self.use_qdq else "",
        )


@@ -1135,7 +1135,7 @@ def extra_repr(self) -> str:
        return extra_representation(
            self.extra_repr_org(),
            self.class_name_org,
-            get_current_repr(self, "scale_input", "scale_weight"),
+            get_current_repr(self, "scale_input", "scale_weight") if not self.use_qdq else "",
        )


@@ -1171,7 +1171,7 @@ def extra_repr(self) -> str:
        return extra_representation(
            self.extra_repr_org(),
            self.class_name_org,
-            get_current_repr(self, "scale_input", "scale_output"),
+            get_current_repr(self, "scale_input", "scale_output") if not self.use_qdq else "",
        )


@@ -1252,7 +1252,7 @@ def extra_repr(self) -> str:
        return extra_representation(
            self.extra_repr_org(),
            self.class_name_org,
-            get_current_repr(self, "scale_input", "scale_weight"),
+            get_current_repr(self, "scale_input", "scale_weight") if not self.use_qdq else "",
        )