6 changes: 2 additions & 4 deletions neural_compressor/common/base_config.py
@@ -410,11 +410,9 @@ def to_config_mapping(
                 if self.global_config is not None:
                     config_mapping[(op_name, op_type)] = global_config
                 if op_type in op_type_config_dict:
-                    config_mapping[(op_name, op_type)] = op_name_config_dict[op_type]
+                    config_mapping[(op_name, op_type)] = op_type_config_dict[op_type]
                 for op_name_pattern in op_name_config_dict:
-                    if isinstance(op_name, str) and re.match(op_name_pattern, op_name):
-                        config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
-                    elif op_name_pattern == op_name:  # TODO: map ipex opname to stock pt op_name
+                    if re.match(op_name_pattern, op_name):
                         config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
         return config_mapping
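For readers of the hunk above: the corrected lookup means an (op_name, op_type) pair resolves against the global config first, then any op_type entry, then any op_name regex match. A minimal sketch of that precedence, assuming illustrative dict shapes and values rather than the library's actual config objects:

```python
# Illustrative sketch only; mirrors the precedence in to_config_mapping above.
import re


def resolve_config(op_name, op_type, global_config, op_type_config_dict, op_name_config_dict):
    config = global_config  # 1. global default
    if op_type in op_type_config_dict:
        config = op_type_config_dict[op_type]  # 2. per-op-type override (the fixed lookup)
    for pattern, cfg in op_name_config_dict.items():
        if re.match(pattern, op_name):
            config = cfg  # 3. per-op-name pattern override wins last
    return config


# Hypothetical values: "fc1" matches an op_name pattern, so it falls back to fp32.
print(resolve_config("fc1", "Linear", "int8", {"Linear": "int8_per_channel"}, {"fc1": "fp32"}))
```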

neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py
@@ -56,7 +56,7 @@ def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     """
     assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."
 
-    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs)
+    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)
 
     # check smoothquant folding value
     recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
@@ -121,7 +121,7 @@ def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(tune_cfg["op"])
     return model


@@ -185,7 +185,7 @@ def qdq_quantize(
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(tune_cfg["op"])
     return model


neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -51,8 +51,9 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     Returns:
         A quantized model.
     """
-    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs)
-    cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)  # update json file in ipex_config_path
+    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)
+    # update json file in ipex_config_path; map ipex op_name to pt op_name
+    user_cfg = cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
     model.eval()
 
     # Check save_qconf_summary part is a workaround for IPEX bug.
@@ -82,7 +83,7 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(user_cfg)
     return model


38 changes: 32 additions & 6 deletions neural_compressor/torch/algorithms/static_quant/utility.py
@@ -16,6 +16,7 @@
 import json
 import os
 import re
+from collections import OrderedDict
 from typing import Dict, List, Union
 
 import torch
@@ -66,9 +67,10 @@
 def cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name):  # pragma: no cover
     assert cfgs is not None, "No configure for IPEX int8 model..."
     op_infos = copy.deepcopy(op_infos_from_cfgs)
-    cfgs = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
+    cfgs, user_cfg = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
     with open(ipex_config_path, "w") as write_f:
         json.dump(cfgs, write_f, indent=4)
+    return user_cfg


def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name): # pragma: no cover
@@ -83,6 +85,15 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name):
     Returns:
         cfgs (dict): updated configs.
     """
+    tmp_user_cfg = OrderedDict()
+    for op in user_cfg:  # map ipex op_name to pt op_name
+        for i, op_name in enumerate(op):
+            for ops, _ in op_infos_from_cfgs.items():
+                if "fqn" in op_infos_from_cfgs[ops].keys() and op_infos_from_cfgs[ops]["fqn"] == op_name:
+                    ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])
+                    tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
+                    break
+    user_cfg = tmp_user_cfg
     for op_name in user_cfg:
         inc_op_cfg = user_cfg[op_name]
         for i, name in enumerate(op_name[0]):
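To make the intent of the new remapping loop concrete, here is a standalone sketch with deliberately simplified data shapes (the real op_infos_from_cfgs entries carry more fields, and op_type_map stands in for unify_op_type_mapping_ipex):

```python
# Minimal sketch, assuming simplified shapes: re-key a config indexed by PyTorch
# module names (fqn) onto IPEX's internal op keys plus a unified op type.
from collections import OrderedDict


def map_pt_names_to_ipex_keys(user_cfg, op_infos_from_cfgs, op_type_map):
    mapped = OrderedDict()
    for pt_name, cfg in user_cfg.items():  # e.g. "fc1" -> {"weight": {"dtype": "fp32"}}
        for ipex_key, info in op_infos_from_cfgs.items():
            if info.get("fqn") == pt_name:  # match on the recorded fully-qualified name
                mapped[((tuple(ipex_key),), op_type_map[info["op_type"]])] = cfg
                break
    return mapped


# Hypothetical example data.
op_infos = {(" ", "q_op_infos", "0"): {"fqn": "fc1", "op_type": "<class 'torch.nn.modules.linear.Linear'>"}}
op_type_map = {"<class 'torch.nn.modules.linear.Linear'>": "Linear"}
print(map_pt_names_to_ipex_keys({"fc1": {"weight": {"dtype": "fp32"}}}, op_infos, op_type_map))
```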
@@ -142,7 +153,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name):
                 else:
                     pass
             cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg
-    return cfgs
+    return cfgs, user_cfg


def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False): # pragma: no cover
@@ -212,6 +223,7 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
         cfgs (dict): dict of configuration
     """
     quantizable_ops = []
+    op_name_info = []
     # group ops by position for transform-based model
     detector = TransformerBasedModelBlockPatternDetector(model)
     detect_result = detector.detect_block()
@@ -277,17 +289,30 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
             if ipex_op_type in unify_op_type_mapping_ipex:
                 quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type]))
                 map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
+                if "class" in ipex_op_type:  # "<class 'torch.nn.modules.activation.ReLU'>"
+                    op_type = ipex_op_type.split("'")[1]
+                    op_name_info.append((module_fqn, eval(op_type)))
+                elif "method" in ipex_op_type:  # "<method 'add' of 'torch._C._TensorBase' objects>"
+                    method = ipex_op_type.split("'")[1]
+                    op_type = getattr(
+                        torch._C._TensorBase if ipex_ver.release < Version("2.2") else torch._C.TensorBase, method
+                    )
+                    op_name_info.append((module_fqn, op_type))
+                else:
+                    op_name_info.append((module_fqn, op_type))
             else:
                 re_flag = False
                 for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
                     if re.match(pattern, ipex_op_type):
                         re_flag = True
                         quantizable_ops.append((tuple(name), unify_op_type))
                         map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn
+                        op_name_info.append((module_fqn, ipex_op_type))
                         break
                 if not re_flag:
                     quantizable_ops.append((tuple(name), ipex_op_type))
                     map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
+                    op_name_info.append((module_fqn, ipex_op_type))
         else:
             op_type = ""
             for op_name in name:
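The two string forms handled in the added branch are the repr-style names IPEX records for module classes and tensor methods. A hedged, standalone sketch of the same parsing idea, using importlib/getattr instead of eval purely for illustration (the diff itself resolves methods against torch._C._TensorBase or torch._C.TensorBase depending on the IPEX version, and treating torch.Tensor as equivalent here is an assumption of the sketch):

```python
# Sketch only: recover a callable from the repr-style op_type strings quoted in the
# comments above; anything else is passed through unchanged.
import importlib

import torch


def parse_op_type(ipex_op_type):
    if "class" in ipex_op_type:  # "<class 'torch.nn.modules.activation.ReLU'>"
        path = ipex_op_type.split("'")[1]
        module_path, cls_name = path.rsplit(".", 1)
        return getattr(importlib.import_module(module_path), cls_name)
    if "method" in ipex_op_type:  # "<method 'add' of 'torch._C._TensorBase' objects>"
        method = ipex_op_type.split("'")[1]
        return getattr(torch.Tensor, method)  # assumption: Tensor exposes the same method
    return ipex_op_type


print(parse_op_type("<class 'torch.nn.modules.activation.ReLU'>"))
print(parse_op_type("<method 'add' of 'torch._C._TensorBase' objects>"))
```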
@@ ... @@
             _op_cfg_id = name[0][2]
             module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"]
             map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn
+            op_name_info.append((module_fqn, op_type))
 
     logger.debug("Map op name to fqn: ")
     logger.debug(map_op_name_to_fqn)
     logger.info("Attention Blocks : ")
     logger.info(attention_block)
     logger.info("FFN Blocks : ")
     logger.info(ffn_blocks)
-    return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name
+    return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, op_name_info
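The net effect of the additions in this file is the new fifth return value, op_name_info: a list of (module_fqn, op_type) pairs in which op_type is an nn.Module class, a tensor method, or a plain string. The entries below are hypothetical, only to show the shape that get_model_info in config.py now receives:

```python
# Hypothetical op_name_info contents; real names depend on the traced model.
import torch

op_name_info = [
    ("fc1", torch.nn.Linear),  # module op: fqn paired with its nn.Module class
    ("fc2", torch.nn.Linear),
    ("add_1", torch.Tensor.add),  # method op: fqn paired with the tensor method
]

for fqn, op_type in op_name_info:
    print(fqn, op_type)
```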


def simple_inference(q_model, example_inputs, iterations=1):
@@ ... @@
             q_model(example_inputs)


-def dump_model_op_stats(tune_cfg):
+def dump_model_op_stats(user_cfg):
     """This is a function to dump quantizable ops of model to user.
 
     Args:
-        tune_cfg (dict): quantization config
+        user_cfg (dict): quantization config
     Returns:
         None
     """
     res = dict()
-    for k, v in tune_cfg["op"].items():
+    for k, v in user_cfg.items():
         op_type_list = k[-1].split("><")
         op_type = ""
         for op in op_type_list:
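dump_model_op_stats now receives the remapped user_cfg (keyed by ((op_name,), op_type) tuples with per-op dtype settings) rather than the whole tune_cfg. A rough sketch of the kind of per-op-type tally such a dump performs, with made-up key and value shapes:

```python
# Sketch only: count how many ops of each type ended up INT8 vs FP32.
def summarize_op_stats(user_cfg):
    res = {}
    for (_op_name, op_type), cfg in user_cfg.items():
        dtype = cfg.get("weight", {}).get("dtype", "fp32").upper()
        bucket = res.setdefault(op_type, {"INT8": 0, "FP32": 0})
        bucket["INT8" if dtype == "INT8" else "FP32"] += 1
    return res


print(summarize_op_stats({
    (("fc1",), "Linear"): {"weight": {"dtype": "fp32"}},  # fell back via set_local
    (("fc2",), "Linear"): {"weight": {"dtype": "int8"}},
}))
```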
4 changes: 2 additions & 2 deletions neural_compressor/torch/quantization/config.py
@@ -818,7 +818,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
     def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
         from neural_compressor.torch.algorithms.static_quant import get_quantizable_ops_recursively
 
-        model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
+        _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info
 
     @classmethod
@@ -923,7 +923,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
     def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
         from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively
 
-        model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
+        model_info, _, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info
 
     @classmethod
15 changes: 15 additions & 0 deletions test/3x/torch/quantization/test_static_quant.py
@@ -49,6 +49,21 @@ def test_static_quant_default(self):
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
 
+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_static_quant_fallback(self):
+        fp32_model = copy.deepcopy(self.fp32_model)
+        quant_config = get_default_static_config()
+        example_inputs = self.input
+        # fallback by op_type
+        quant_config.set_local(torch.nn.modules.linear.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
+        # fallback by op_name
+        quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     @pytest.mark.parametrize(
         "act_sym, act_algo",