From b81277ce72bb0f10a692d5e1f78a0606bb544ecc Mon Sep 17 00:00:00 2001 From: "Cheng, Zixuan" Date: Thu, 13 Jun 2024 12:47:59 +0800 Subject: [PATCH 1/5] fix 3x ipex static quant regression Signed-off-by: Cheng, Zixuan --- .../torch/algorithms/smooth_quant/utility.py | 55 ++++++++++++++++++- .../torch/algorithms/static_quant/utility.py | 47 +++++----------- 2 files changed, 68 insertions(+), 34 deletions(-) diff --git a/neural_compressor/torch/algorithms/smooth_quant/utility.py b/neural_compressor/torch/algorithms/smooth_quant/utility.py index 7dc647dbc95..e2ce9c97f5b 100644 --- a/neural_compressor/torch/algorithms/smooth_quant/utility.py +++ b/neural_compressor/torch/algorithms/smooth_quant/utility.py @@ -26,8 +26,8 @@ from neural_compressor.torch.algorithms.static_quant import ( CpuInfo, + Statistics, TransformerBasedModelBlockPatternDetector, - dump_model_op_stats, generate_activation_observer, get_quantizable_ops_from_cfgs, ipex_config_path, @@ -251,6 +251,59 @@ def cfg_to_qconfig( return None +def dump_model_op_stats(user_cfg): + """This is a function to dump quantizable ops of model to user. + + Args: + user_cfg (dict): quantization config + Returns: + None + """ + res = dict() + for k, v in user_cfg.items(): + op_type_list = k[-1].split("><") + op_type = "" + for op in op_type_list: + if "class" in op: + op_type = ( + op[op.rfind(".") + 1 : op.rfind("'")] + if op_type == "" + else op_type + "&" + op[op.rfind(".") + 1 : op.rfind("'")] + ) + elif "method" in op: + start = op.find("'") + 1 + if start > 1: + op_type = ( + op[start : op.find("'", start)] + if op_type == "" + else op_type + "&" + op[start : op.find("'", start)] + ) + else: + start = op.find("method") + 7 + op_type = ( + op[start : op.find(" ", start)] + if op_type == "" + else op_type + "&" + op[start : op.find(" ", start)] + ) + else: + op_type = op if op_type == "" else op_type + "&" + op + if op_type not in res.keys(): + res[op_type] = {"INT8": 0, "BF16": 0, "FP32": 0} + if v["weight"]["dtype"] == "int8": + res[op_type]["INT8"] += 1 + elif v["weight"]["dtype"] == "fp32": + res[op_type]["FP32"] += 1 + + output_data = [ + [op_type, sum(res[op_type].values()), res[op_type]["INT8"], res[op_type]["BF16"], res[op_type]["FP32"]] + for op_type in res.keys() + ] + + Statistics( + output_data, header="Mixed Precision Statistics", field_names=["Op Type", "Total", "INT8", "BF16", "FP32"] + ).print_stat() + + def get_parent(node, all_parents=False): # pragma: no cover if node.inputs() is None: return None diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py index f90471539fd..6301db36865 100644 --- a/neural_compressor/torch/algorithms/static_quant/utility.py +++ b/neural_compressor/torch/algorithms/static_quant/utility.py @@ -47,9 +47,14 @@ "": "add", # for IPEX >= 2.2 "": "AdaptiveAvgPool2d", "Linear_Relu": "Linear", + "Linear_add": "Linear", "": "Linear", "": "MaxPool2d", - "re": {"" method = ipex_op_type.split("'")[1] op_name_info.append((module_fqn, method)) - elif "Convolution" in ipex_op_type: # "Convolution_Relu" - op_name_info.append((module_fqn, "Conv2d")) + elif "_" in ipex_op_type: # "Convolution_Relu", "Linear_Relu" + op_name_info.append((module_fqn, ipex_op_type.split("_")[0])) else: re_flag = False for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): @@ -394,32 +400,7 @@ def dump_model_op_stats(user_cfg): """ res = dict() for k, v in user_cfg.items(): - op_type_list = k[-1].split("><") - op_type = "" - for op in 
op_type_list: - if "class" in op: - op_type = ( - op[op.rfind(".") + 1 : op.rfind("'")] - if op_type == "" - else op_type + "&" + op[op.rfind(".") + 1 : op.rfind("'")] - ) - elif "method" in op: - start = op.find("'") + 1 - if start > 1: - op_type = ( - op[start : op.find("'", start)] - if op_type == "" - else op_type + "&" + op[start : op.find("'", start)] - ) - else: - start = op.find("method") + 7 - op_type = ( - op[start : op.find(" ", start)] - if op_type == "" - else op_type + "&" + op[start : op.find(" ", start)] - ) - else: - op_type = op if op_type == "" else op_type + "&" + op + op_type = k[1] if op_type not in res.keys(): res[op_type] = {"INT8": 0, "BF16": 0, "FP32": 0} if v["weight"]["dtype"] == "int8": From ee6bd67c172c867b54d7206852b273cfde62973a Mon Sep 17 00:00:00 2001 From: "Cheng, Zixuan" Date: Thu, 13 Jun 2024 12:49:55 +0800 Subject: [PATCH 2/5] add ut Signed-off-by: Cheng, Zixuan --- .../torch/quantization/test_static_quant.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py index 60d0b205371..be7ad2659d5 100644 --- a/test/3x/torch/quantization/test_static_quant.py +++ b/test/3x/torch/quantization/test_static_quant.py @@ -22,13 +22,18 @@ class Model(torch.nn.Module): def __init__(self): super(Model, self).__init__() self.fc1 = torch.nn.Linear(30, 50) - self.fc2 = torch.nn.Linear(50, 30) - self.fc3 = torch.nn.Linear(30, 5) + self.fc2 = torch.nn.Linear(50, 50) + self.fc3 = torch.nn.Linear(50, 30) + self.fc4 = torch.nn.Linear(30, 5) + self.relu = torch.nn.ReLU() def forward(self, x): out = self.fc1(x) out = self.fc2(out) + out = self.relu(out) out = self.fc3(out) + out = out + x + out = self.fc4(out) return out model = Model() @@ -52,6 +57,7 @@ def teardown_class(self): def test_static_quant_default(self): fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() + quant_config.excluded_precisions = ["bf16"] example_inputs = self.input prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(prepared_model) @@ -69,6 +75,7 @@ def test_static_quant_default(self): def test_static_quant_fallback(self): fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() + quant_config.excluded_precisions = ["bf16"] example_inputs = self.input # fallback by op_type quant_config.set_local(torch.nn.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32")) @@ -78,21 +85,23 @@ def test_static_quant_fallback(self): assert q_model is not None, "Quantization failed!" for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items(): - if op_info["op_type"] == "": + if op_info["op_type"] == "Linear": dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"] assert dtype == "torch.float32", "Failed to fallback linear op, please check!" # fallback by op_name - quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32")) + quant_config = get_default_static_config() + quant_config.excluded_precisions = ["bf16"] + quant_config.set_local("fc2", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32")) prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(prepared_model) q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" 
for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items(): - if op_info["fqn"] == "fc1": + if op_info["fqn"] == "fc2": dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"] - assert dtype == "torch.float32", "Failed to fallback fc1 layer, please check!" + assert dtype == "torch.float32", "Failed to fallback fc2 layer, please check!" @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") @pytest.mark.parametrize( @@ -106,7 +115,7 @@ def test_static_quant_fallback(self): ) def test_static_quant_params(self, act_sym, act_algo): fp32_model = copy.deepcopy(self.fp32_model) - quant_config = StaticQuantConfig(act_sym=act_sym, act_algo=act_algo) + quant_config = StaticQuantConfig(act_sym=act_sym, act_algo=act_algo, excluded_precisions=["bf16"]) example_inputs = self.input prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(prepared_model) @@ -133,7 +142,7 @@ def run_fn(model): fp32_model = copy.deepcopy(model) fp32_model.linear.weight = torch.nn.Parameter(torch.tensor([[0.0, 1.0], [1.0, 0.0]])) example_inputs = torch.zeros(3, 2) - quant_config = StaticQuantConfig(act_sym=True, act_algo="kl") + quant_config = StaticQuantConfig(act_sym=True, act_algo="kl", excluded_precisions=["bf16"]) prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(prepared_model) q_model = convert(prepared_model) @@ -175,6 +184,7 @@ def run_fn(model): fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() + quant_config.excluded_precisions = ["bf16"] prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(prepared_model) q_model = convert(prepared_model) @@ -195,6 +205,7 @@ def test_static_quant_with_quantize_API(self): # quantize API fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() + quant_config.excluded_precisions = ["bf16"] example_inputs = self.input q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) assert q_model is not None, "Quantization failed!" 
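The unit-test patch above reworks the toy model (four Linear layers, a ReLU after fc2, and a residual add before fc4) so that IPEX produces fused op types such as "Linear_Relu" and "Linear_add", and it moves the op_name fallback from fc1 to fc2. For reference, a minimal sketch of that fallback flow outside pytest is shown below; it reuses the prepare/convert/set_local API that appears in the diff, while the calibration call standing in for the test's run_fn is illustrative only.

import copy

import torch

from neural_compressor.torch.quantization import (
    StaticQuantConfig,
    convert,
    get_default_static_config,
    prepare,
)


class Model(torch.nn.Module):
    """Toy model mirroring the updated test: fc2 is followed by ReLU, fc3 feeds a residual add."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(30, 50)
        self.fc2 = torch.nn.Linear(50, 50)
        self.fc3 = torch.nn.Linear(50, 30)
        self.fc4 = torch.nn.Linear(30, 5)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(self.fc2(out))  # Linear + ReLU, expected to fuse as "Linear_Relu" in IPEX
        out = self.fc3(out) + x         # Linear + residual add, expected to fuse as "Linear_add"
        return self.fc4(out)


fp32_model = Model()
example_inputs = torch.randn(1, 30)

# Fall back a single layer by op_name: fc2 stays in fp32 while the rest is quantized.
quant_config = get_default_static_config()
quant_config.set_local("fc2", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))

prepared_model = prepare(copy.deepcopy(fp32_model), quant_config=quant_config, example_inputs=example_inputs)
prepared_model(example_inputs)  # calibration pass, standing in for the test's run_fn
q_model = convert(prepared_model)
assert q_model is not None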
From 0a4468e6d5db04db33f3ea9fe8d9a07c09c22b54 Mon Sep 17 00:00:00 2001 From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com> Date: Thu, 13 Jun 2024 12:52:06 +0800 Subject: [PATCH 3/5] Update test_static_quant.py --- test/3x/torch/quantization/test_static_quant.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py index be7ad2659d5..a64d64dfd20 100644 --- a/test/3x/torch/quantization/test_static_quant.py +++ b/test/3x/torch/quantization/test_static_quant.py @@ -57,7 +57,6 @@ def teardown_class(self): def test_static_quant_default(self): fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() - quant_config.excluded_precisions = ["bf16"] example_inputs = self.input prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(prepared_model) @@ -75,7 +74,6 @@ def test_static_quant_default(self): def test_static_quant_fallback(self): fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() - quant_config.excluded_precisions = ["bf16"] example_inputs = self.input # fallback by op_type quant_config.set_local(torch.nn.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32")) @@ -91,7 +89,6 @@ def test_static_quant_fallback(self): # fallback by op_name quant_config = get_default_static_config() - quant_config.excluded_precisions = ["bf16"] quant_config.set_local("fc2", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32")) prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(prepared_model) @@ -184,7 +181,6 @@ def run_fn(model): fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() - quant_config.excluded_precisions = ["bf16"] prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(prepared_model) q_model = convert(prepared_model) @@ -205,7 +201,6 @@ def test_static_quant_with_quantize_API(self): # quantize API fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() - quant_config.excluded_precisions = ["bf16"] example_inputs = self.input q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) assert q_model is not None, "Quantization failed!" 
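Patch 3 above undoes the explicit bf16 exclusion that patch 2 added to the tests built on get_default_static_config(), leaving the default precision handling in place. For context, excluded_precisions can be spelled either way, and both forms appear in this series; a minimal sketch, with the values taken from the diffs:

from neural_compressor.torch.quantization import StaticQuantConfig, get_default_static_config

# Form 1: pass excluded_precisions when constructing the config
# (added to the parametrized tests by patch 2, reverted by patch 4).
cfg_from_ctor = StaticQuantConfig(act_sym=True, act_algo="kl", excluded_precisions=["bf16"])

# Form 2: start from the default static config and set the attribute afterwards
# (added by patch 2, removed again by patch 3).
cfg_from_default = get_default_static_config()
cfg_from_default.excluded_precisions = ["bf16"]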
From 50066609788a1272e4a73e149783fe5904c62bfc Mon Sep 17 00:00:00 2001
From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com>
Date: Thu, 13 Jun 2024 12:52:41 +0800
Subject: [PATCH 4/5] Update test_static_quant.py

---
 test/3x/torch/quantization/test_static_quant.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py
index a64d64dfd20..46e791aa52f 100644
--- a/test/3x/torch/quantization/test_static_quant.py
+++ b/test/3x/torch/quantization/test_static_quant.py
@@ -112,7 +112,7 @@ def test_static_quant_fallback(self):
     )
     def test_static_quant_params(self, act_sym, act_algo):
         fp32_model = copy.deepcopy(self.fp32_model)
-        quant_config = StaticQuantConfig(act_sym=act_sym, act_algo=act_algo, excluded_precisions=["bf16"])
+        quant_config = StaticQuantConfig(act_sym=act_sym, act_algo=act_algo)
         example_inputs = self.input
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
@@ -139,7 +139,7 @@ def run_fn(model):
         fp32_model = copy.deepcopy(model)
         fp32_model.linear.weight = torch.nn.Parameter(torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
         example_inputs = torch.zeros(3, 2)
-        quant_config = StaticQuantConfig(act_sym=True, act_algo="kl", excluded_precisions=["bf16"])
+        quant_config = StaticQuantConfig(act_sym=True, act_algo="kl")
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)

From ae8eaf67ecb60b70971aa6757ed4084e432a9d4e Mon Sep 17 00:00:00 2001
From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com>
Date: Thu, 13 Jun 2024 14:46:38 +0800
Subject: [PATCH 5/5] Update utility.py

---
 neural_compressor/torch/algorithms/static_quant/utility.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py
index 6301db36865..81133557b3e 100644
--- a/neural_compressor/torch/algorithms/static_quant/utility.py
+++ b/neural_compressor/torch/algorithms/static_quant/utility.py
@@ -43,6 +43,7 @@
     "<class 'torch.nn.modules.conv.Conv2d'>": "Conv2d",
     "<class 'torch.nn.modules.conv.Conv3d'>": "Conv3d",
     "<class 'torch.nn.modules.activation.ReLU'>": "ReLU",
+    "<class 'torch.nn.modules.sparse.EmbeddingBag'>": "EmbeddingBag",
     "<method 'add' of 'torch._C._TensorBase' objects>": "add",  # for IPEX < 2.2
     "<method 'add' of 'torch._C.TensorBase' objects>": "add",  # for IPEX >= 2.2
     "<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>": "AdaptiveAvgPool2d",
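Patch 5 above extends unify_op_type_mapping_ipex with an EmbeddingBag entry, complementing the patch 1 change that reduces any fused IPEX op-type string to its leading token instead of special-casing only "Convolution_*". A rough illustrative sketch of that reduction (base_op_type is my own helper name, not the library function):

def base_op_type(ipex_op_type: str) -> str:
    """Illustrative only: mirror the branch changed in patch 1, which keeps the text
    before the first underscore of a fused op type such as "Convolution_Relu",
    "Linear_Relu", or "Linear_add"."""
    if "_" in ipex_op_type:
        return ipex_op_type.split("_")[0]
    return ipex_op_type


assert base_op_type("Convolution_Relu") == "Convolution"
assert base_op_type("Linear_add") == "Linear"
assert base_op_type("Linear") == "Linear"

Plain, unfused module types are instead looked up in the unify_op_type_mapping_ipex table that this patch extends.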