Commit 5ba9efe
Improve op wise coverage for ORT WOQ (#1270)
* Enhance ORT WOQ
* bug fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update onnxrt.py
* Update test_weight_only_adaptor.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update weight_only.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update onnxrt.py
* Update test_weight_only_adaptor.py
* Update test_weight_only_adaptor.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Signed-off-by: Mengni Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 35f9461 commit 5ba9efe

5 files changed: +195 -173 lines

neural_compressor/adaptor/onnxrt.py

Lines changed: 35 additions & 28 deletions
```diff
@@ -1663,7 +1663,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
 
             enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True)
             enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True)
-            n_blocks = self.recipes.get("awq_args", {}).get("n_blocks", 5)
             calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
             model = awq_quantize(
                 model,
@@ -1672,7 +1671,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                 n_samples=calib_sampling_size,
                 enable_auto_scale=enable_auto_scale,
                 enable_mse_search=enable_mse_search,
-                n_blocks=n_blocks,
             )
         elif "RTN" in algos:
             from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
```
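One functional change hides in the two AWQ hunks: `n_blocks` is no longer read from the recipe or forwarded to `awq_quantize`. A sketch of the `awq_args` recipe as the adaptor now consumes it (the dict itself is illustrative; the key names and defaults come from the diff above):

```python
# Illustrative recipe; only these "awq_args" keys are still read by the
# adaptor after this commit. An "n_blocks" entry would simply be ignored.
recipes = {
    "awq_args": {
        "enable_auto_scale": True,  # adaptor default: True
        "enable_mse_search": True,  # adaptor default: True
    }
}
```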
```diff
@@ -1684,33 +1682,42 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         return model
 
     def _dump_model_op_stats(self, model, tune_cfg):
+        import re
+
+        fp32_op_list = self.query_handler.get_op_types_by_precision(precision="weight_only_integer")
+
         res = {}
-        # collect all dtype info and build empty results with existing op_type
+        for optype in fp32_op_list:
+            res[optype] = {}
+
         dtype_set = set()
-        for op, config in tune_cfg["op"].items():
-            op_type = op[1]
-            if not config["weight"]["dtype"] == "fp32":
-                num_bits = config["weight"]["bits"]
-                group_size = config["weight"]["group_size"]
-                dtype_str = "A32W{}G{}".format(num_bits, group_size)
-                dtype_set.add(dtype_str)
-        dtype_set.add("FP32")
-        dtype_list = list(dtype_set)
-        dtype_list.sort()
-        for op, config in tune_cfg["op"].items():
-            op_type = op[1]
-            if op_type not in res.keys():
-                res[op_type] = {dtype: 0 for dtype in dtype_list}
-
-        # fill in results with op_type and dtype
-        for op, config in tune_cfg["op"].items():
-            if config["weight"]["dtype"] == "fp32":
-                res[op_type]["FP32"] += 1
+        for node in model.nodes():
+            if node.op_type == "MatMulWithQuantWeight":
+                optype = "MatMul"
             else:
-                num_bits = config["weight"]["bits"]
-                group_size = config["weight"]["group_size"]
-                dtype_str = "A32W{}G{}".format(num_bits, group_size)
-                res[op_type][dtype_str] += 1
+                optype = node.op_type
+
+            if optype not in res:
+                continue
+            if re.fullmatch("^.*_Q\d*G\d*", node.input[1]):
+                search_out = re.search("_Q\d*", node.input[1])
+                dtype = "A32W{}G{}".format(
+                    node.input[1][search_out.start() + 2 : search_out.end()], node.input[1][search_out.end() + 1 :]
+                )
+            else:
+                dtype = "FP32"
+            dtype_set.add(dtype)
+
+            if dtype in res[optype]:
+                res[optype][dtype] += 1
+            else:
+                res[optype][dtype] = 1
+
+        dtype_list = list(dtype_set)
+        for dtype in dtype_list:
+            for optype in res.keys():
+                if dtype not in res[optype]:
+                    res[optype][dtype] = 0
 
         # update stats format for dump.
         field_names = ["Op Type", "Total"]
```
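The rewritten `_dump_model_op_stats` no longer derives the stats from `tune_cfg`; it walks the actual graph and infers each op's weight dtype from the initializer name suffix `_Q{bits}G{group_size}` that the WOQ quantizer leaves behind (with `MatMulWithQuantWeight` folded back into `MatMul`). A standalone sketch of that name-parsing step, on hypothetical tensor names:

```python
import re

# Hypothetical initializer names; the WOQ quantizer appends
# "_Q{bits}G{group_size}" to the weights it quantizes.
names = ["fc1.weight_Q4G32", "fc2.weight_Q8G128", "fc3.weight"]

for name in names:
    if re.fullmatch(r"^.*_Q\d*G\d*", name):
        match = re.search(r"_Q\d*", name)
        bits = name[match.start() + 2 : match.end()]  # digits after "_Q"
        group_size = name[match.end() + 1 :]          # digits after the "G"
        print(name, "->", "A32W{}G{}".format(bits, group_size))
    else:
        print(name, "->", "FP32")
# fc1.weight_Q4G32 -> A32W4G32
# fc2.weight_Q8G128 -> A32W8G128
# fc3.weight -> FP32
```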
```diff
@@ -1760,7 +1767,7 @@ def query_fw_capability(self, model):
         precisions = query.get_precisions()
 
         for precision in precisions:
-            if precision != "weight_only_integer":
+            if precision not in ["weight_only_integer", "fp32"]:
                 continue
             # get supported optype for target precision
             optypes = (
@@ -1785,7 +1792,7 @@ def query_fw_capability(self, model):
                     continue
                 else:
                     op_capability = copy.deepcopy(configs[op])
-                    op_capability["activation"]["quant_mode"] = "weight_only"
+                op_capability["activation"]["quant_mode"] = "weight_only"
                 if op not in optype_wise.keys():
                     optype_wise[op] = [op_capability]
                 elif op_capability not in optype_wise[op]:
```
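The dtype label reads as activations/weights/group: `A32W4G32` is 32-bit activations, 4-bit weights, group size 32. The dump table starts from `field_names = ["Op Type", "Total"]` and, going by the collection code above, gains one column per collected dtype. A hypothetical dump for ten MatMuls, eight quantized at 4 bits with group size 32 and two left in fp32:

```
Op Type | Total | A32W4G32 | FP32
--------+-------+----------+-----
MatMul  |    10 |        8 |    2
```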

neural_compressor/adaptor/onnxrt.yaml

Lines changed: 0 additions & 1 deletion
```diff
@@ -30,7 +30,6 @@
       'dtype': ['fp32']
     }
   },
-  'Attention': *cap_weight_only_matmul
 }
 int8: &ref_1_6 {
   'static': &ref_1_6_static {
```
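For context, `*cap_weight_only_matmul` is a YAML alias: it reuses the mapping anchored as `&cap_weight_only_matmul` on the `MatMul` entry earlier in the file, so removing this line drops `Attention` from the weight-only op set without touching the shared capability definition. A minimal demonstration of the anchor/alias mechanism (assumes PyYAML; the keys and values are placeholders, not onnxrt.yaml's real contents):

```python
import yaml  # PyYAML

# Placeholder document: 'MatMul' anchors a capability map and
# 'Attention' aliases it, the same pattern onnxrt.yaml uses.
doc = """
'MatMul': &cap_weight_only_matmul {'weight': {'bits': [4]}}
'Attention': *cap_weight_only_matmul
"""
caps = yaml.safe_load(doc)
assert caps["Attention"] == caps["MatMul"]  # alias resolves to the same map
```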
