Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 35 additions & 28 deletions neural_compressor/adaptor/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):

enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True)
enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True)
n_blocks = self.recipes.get("awq_args", {}).get("n_blocks", 5)
calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
model = awq_quantize(
model,
Expand All @@ -1672,7 +1671,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
n_samples=calib_sampling_size,
enable_auto_scale=enable_auto_scale,
enable_mse_search=enable_mse_search,
n_blocks=n_blocks,
)
elif "RTN" in algos:
from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
Expand All @@ -1684,33 +1682,42 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
return model

def _dump_model_op_stats(self, model, tune_cfg):
import re

fp32_op_list = self.query_handler.get_op_types_by_precision(precision="weight_only_integer")

res = {}
# collect all dtype info and build empty results with existing op_type
for optype in fp32_op_list:
res[optype] = {}

dtype_set = set()
for op, config in tune_cfg["op"].items():
op_type = op[1]
if not config["weight"]["dtype"] == "fp32":
num_bits = config["weight"]["bits"]
group_size = config["weight"]["group_size"]
dtype_str = "A32W{}G{}".format(num_bits, group_size)
dtype_set.add(dtype_str)
dtype_set.add("FP32")
dtype_list = list(dtype_set)
dtype_list.sort()
for op, config in tune_cfg["op"].items():
op_type = op[1]
if op_type not in res.keys():
res[op_type] = {dtype: 0 for dtype in dtype_list}

# fill in results with op_type and dtype
for op, config in tune_cfg["op"].items():
if config["weight"]["dtype"] == "fp32":
res[op_type]["FP32"] += 1
for node in model.nodes():
if node.op_type == "MatMulWithQuantWeight":
optype = "MatMul"
else:
num_bits = config["weight"]["bits"]
group_size = config["weight"]["group_size"]
dtype_str = "A32W{}G{}".format(num_bits, group_size)
res[op_type][dtype_str] += 1
optype = node.op_type

if optype not in res:
continue
if re.fullmatch("^.*_Q\d*G\d*", node.input[1]):
search_out = re.search("_Q\d*", node.input[1])
dtype = "A32W{}G{}".format(
node.input[1][search_out.start() + 2 : search_out.end()], node.input[1][search_out.end() + 1 :]
)
else:
dtype = "FP32"
dtype_set.add(dtype)

if dtype in res[optype]:
res[optype][dtype] += 1
else:
res[optype][dtype] = 1

dtype_list = list(dtype_set)
for dtype in dtype_list:
for optype in res.keys():
if dtype not in res[optype]:
res[optype][dtype] = 0

# update stats format for dump.
field_names = ["Op Type", "Total"]
Expand Down Expand Up @@ -1760,7 +1767,7 @@ def query_fw_capability(self, model):
precisions = query.get_precisions()

for precision in precisions:
if precision != "weight_only_integer":
if precision not in ["weight_only_integer", "fp32"]:
continue
# get supported optype for target precision
optypes = (
Expand All @@ -1785,7 +1792,7 @@ def query_fw_capability(self, model):
continue
else:
op_capability = copy.deepcopy(configs[op])
op_capability["activation"]["quant_mode"] = "weight_only"
op_capability["activation"]["quant_mode"] = "weight_only"
if op not in optype_wise.keys():
optype_wise[op] = [op_capability]
elif op_capability not in optype_wise[op]:
Expand Down
1 change: 0 additions & 1 deletion neural_compressor/adaptor/onnxrt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
'dtype': ['fp32']
}
},
'Attention': *cap_weight_only_matmul
}
int8: &ref_1_6 {
'static': &ref_1_6_static {
Expand Down
Loading