From c036ee04af7a22551b088b1f12965c3e2a81c13c Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Fri, 14 Jun 2024 20:00:49 +0800
Subject: [PATCH 1/7] support set_local for pt2e

Signed-off-by: yiliu30
---
 .../pt2e_quant/half_precision_rewriter.py     |  3 +-
 .../torch/algorithms/pt2e_quant/utility.py    | 24 +++++++++-
 neural_compressor/torch/utils/environ.py      |  3 ++
 test/3x/torch/quantization/test_pt2e_quant.py | 44 ++++++++++++++++---
 4 files changed, 67 insertions(+), 7 deletions(-)

diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
index cc91ba70bec..62c97c3aaef 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
@@ -137,7 +137,8 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):
     for node in gm.graph.nodes:
         if meta := getattr(node, "meta"):
             if quantization_annotation := meta.get(xiq.QUANT_ANNOTATION_KEY):
-                if quantization_annotation._annotated:
+                none_annotation = xiq._X86InductorQuantizationAnnotation(_annotated=True)
+                if quantization_annotation != none_annotation:
                     continue
         unquantized_node_set.add(node)
     return unquantized_node_set

diff --git a/neural_compressor/torch/algorithms/pt2e_quant/utility.py b/neural_compressor/torch/algorithms/pt2e_quant/utility.py
index 92635db1f70..f49e0671149 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/utility.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/utility.py
@@ -20,6 +20,8 @@ from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer

+from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2
+

 def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
     dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
@@ -53,6 +55,9 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=Fals

 def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
+    NOT_QUANT_DTYPES = ["fp32", "fp16", "bf16"]
+    if inc_config.act_dtype in NOT_QUANT_DTYPES and inc_config.w_dtype in NOT_QUANT_DTYPES:
+        return None
     default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
     input_act_quant_spec = create_quant_spec_from_config(
         inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo, is_dynamic=is_dynamic
@@ -75,5 +80,22 @@ def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86Induct
     # set global
     global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
     quantizer.set_global(global_config)
-    # Skip the local config for now (need torch 2.4)
+    # Skip the local config for now (need torch >= 2.3.2)
+    if GT_TORCH_VERSION_2_3_2:
+        op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
+        if op_type_config_dict:
+            for op_type, config in op_type_config_dict.items():
+                _nn_module_type = getattr(torch.nn, op_type, None)
+                if _nn_module_type:
+                    quantizer.set_module_type_qconfig(
+                        _nn_module_type, _map_inc_config_to_torch_quant_config(config, is_dynamic)
+                    )
+                _nn_func_type = getattr(torch.nn.functional, op_type, None)
+                if _nn_func_type:
+                    quantizer.set_function_type_qconfig(
+                        _nn_func_type, _map_inc_config_to_torch_quant_config(config, is_dynamic)
+                    )
+        if op_name_config_dict:
+            for op_name, config in op_name_config_dict.items():
+                quantizer.set_module_name_qconfig(op_name, _map_inc_config_to_torch_quant_config(config, is_dynamic))
     return quantizer

diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py
index 3091aa83d88..0697979996d 100644
--- a/neural_compressor/torch/utils/environ.py
+++ b/neural_compressor/torch/utils/environ.py
@@ -91,6 +91,9 @@ def get_torch_version():
     return version


+GT_TORCH_VERSION_2_3_2 = get_torch_version() > Version("2.3.2")
+
+
 def get_accelerator(device_name="auto"):
     global accelerator  # update the global accelerator when calling this func
     from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py
index 3857832598a..710b0d28008 100644
--- a/test/3x/torch/quantization/test_pt2e_quant.py
+++ b/test/3x/torch/quantization/test_pt2e_quant.py
@@ -17,7 +17,7 @@
     prepare,
     quantize,
 )
-from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
+from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2, TORCH_VERSION_2_2_2, get_torch_version

 torch.manual_seed(0)

@@ -119,6 +119,42 @@ def calib_fn(model):
             logger.warning("out shape is %s", out.shape)
             assert out is not None

+    @pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>2.3.2")
+    def test_quantize_simple_model_with_set_local(self, force_not_import_ipex):
+        model, example_inputs = self.build_simple_torch_model_and_example_inputs()
+        float_model_output = model(*example_inputs)
+        quant_config = None
+
+        def calib_fn(model):
+            for i in range(4):
+                model(*example_inputs)
+
+        quant_config = get_default_static_config()
+        quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)
+
+        # check the quantized nodes
+        expected_node_occurrence = {
+            # Only quantize the `fc2`
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+        }
+        expected_node_occurrence = {
+            torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
+        }
+        node_in_graph = self.get_node_in_graph(q_model)
+        for node, cnt in expected_node_occurrence.items():
+            assert node_in_graph.get(node, 0) == cnt, f"Node {node} should occur {cnt} times, but {node_in_graph.get(node, 0)}"
+
+        from torch._inductor import config
+
+        config.freezing = True
+        q_model_out = q_model(*example_inputs)
+        assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!"
+        opt_model = torch.compile(q_model)
+        out = opt_model(*example_inputs)
+        assert out is not None
+
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
     @pytest.mark.parametrize("is_dynamic", [False, True])
     def test_prepare_and_convert_on_simple_model(self, is_dynamic, force_not_import_ipex):
@@ -193,7 +229,7 @@ def get_node_in_graph(graph_module):
                 nodes_in_graph[n] += 1
             else:
                 nodes_in_graph[n] = 1
-        return
+        return nodes_in_graph

     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
     def test_mixed_fp16_and_int8(self, force_not_import_ipex):
@@ -221,9 +257,7 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
         }
         node_in_graph = self.get_node_in_graph(converted_model)
         for node, cnt in expected_node_occurrence.items():
-            assert (
-                expected_node_occurrence.get(node, 0) == cnt
-            ), f"Node {node} should occur {cnt} times, but {node_in_graph[node]}"
+            assert node_in_graph.get(node, 0) == cnt, f"Node {node} should occur {cnt} times, but {node_in_graph.get(node, 0)}"

         # inference
         from torch._inductor import config

From b9ca5e9ee951f8ab3ef3648d6c8af11c0f857d71 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Mon, 17 Jun 2024 09:36:23 +0800
Subject: [PATCH 2/7] fixed ut

Signed-off-by: yiliu30
---
 test/3x/torch/quantization/test_pt2e_quant.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py
index 710b0d28008..e2c643f07c6 100644
--- a/test/3x/torch/quantization/test_pt2e_quant.py
+++ b/test/3x/torch/quantization/test_pt2e_quant.py
@@ -231,7 +231,7 @@ def get_node_in_graph(graph_module):
                 nodes_in_graph[n] = 1
         return nodes_in_graph

-    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
+    @pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>2.3.2")
     def test_mixed_fp16_and_int8(self, force_not_import_ipex):
         model, example_inputs = self.build_model_include_conv_and_linear()
         model = export(model, example_inputs=example_inputs)

From f82b9c90f927f6c93df4d4381310e43c65bbeca2 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Mon, 17 Jun 2024 10:35:16 +0800
Subject: [PATCH 3/7] use np 1.26

Signed-off-by: yiliu30
---
 requirements_pt.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements_pt.txt b/requirements_pt.txt
index 6a012a75b5a..16ddea835d4 100644
--- a/requirements_pt.txt
+++ b/requirements_pt.txt
@@ -1,4 +1,4 @@
-numpy
+numpy==1.26.4
 peft==0.10.0
 psutil
 py-cpuinfo

From 0f3618a2cfa880e533f0ac7c0ae05818647742e7 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Tue, 18 Jun 2024 14:01:38 +0800
Subject: [PATCH 4/7] fix coverage

Signed-off-by: yiliu30
---
 .coverage                                     | Bin 53248 -> 53248 bytes
 .../pt2e_quant/half_precision_rewriter.py     |   4 ++--
 .../torch/algorithms/pt2e_quant/utility.py    |   4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/.coverage b/.coverage
index 02b5b52790be5bd5bebab44b6f435921b28a3e22..36885e6d20c557bf2ac1b0c2aa4e030d551bce95 100644
GIT binary patch
delta 1106
[binary delta data omitted]

delta 796
[binary delta data omitted]
diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
index 62c97c3aaef..4122610592f 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
@@ -164,11 +164,11 @@ def _parse_node_candidate_set_from_user_config(config, gm):
     op_name_filters = []
     for op_type_name, config in op_type_configs.items():
         op_type = getattr(torch.nn, op_type_name)
-        if config.act_dtype == "fp16":
+        if config.act_dtype == "fp16":  # pragma: no cover
             filter = xpq._get_module_type_filter(op_type)
             op_type_filters.append(filter)
     for op_name, config in op_name_configs.items():
-        if config.act_dtype == "fp16":
+        if config.act_dtype == "fp16":  # pragma: no cover
             filter = xpq._get_module_name_filter(op_name)
             op_name_filters.append(filter)
     node_set_from_user_config = set()

diff --git a/neural_compressor/torch/algorithms/pt2e_quant/utility.py b/neural_compressor/torch/algorithms/pt2e_quant/utility.py
index f49e0671149..82a49fa55ad 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/utility.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/utility.py
@@ -80,8 +80,8 @@ def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86Induct
     # set global
     global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
     quantizer.set_global(global_config)
-    # Skip the local config for now (need torch >= 2.3.2)
-    if GT_TORCH_VERSION_2_3_2:
+    # need torch >= 2.3.2
+    if GT_TORCH_VERSION_2_3_2:  # pragma: no cover
         op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
         if op_type_config_dict:
             for op_type, config in op_type_config_dict.items():

From d8021e2f3f45cafbaa6f30c3c38dbd51a062e726 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Tue, 18 Jun 2024 14:38:52 +0800
Subject: [PATCH 5/7] remove .coverage

Signed-off-by: yiliu30
---
 .coverage | Bin 53248 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 .coverage

diff --git a/.coverage b/.coverage
deleted file mode 100644
index 36885e6d20c557bf2ac1b0c2aa4e030d551bce95..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 53248
[binary literal data omitted]
z{ALulU;Y?3Po8_)`AtsH{+!uUsIfRWb??{wFmz~I#|P->*x{ zOXdcOaX6$jolwMixOY~1p?!{Una?q|FdFA6n_PI9KSrG{>tD3tx|UZ>=PRCule5y; z&z+OgNd^&$C`PGMhSWvxI%js$^eB&5+4(0MWWr~4t!mQUxl-1gaj_=vIbAXvy<$Z& za$G4mNSv9tVJ%vrumed4X)YZb)$Z2}xVFqGUh`G*HvTHt8d3A5ef!i_GwqDYIAex1 zsk>H|?et@=qmkTou-#sc(tu0fcGk8Nwp+u&ACnY#a7${ik|<=M((gKbV|;v*~5yRsZcX5DF)7yaN&*q|rrlpBG6TG~-aPmBrd+WvCY zx5KjI*=0{Q1FLS*f9E#>8HRq4geY6}IX`fsnj4lI(SB*pH!UwJM>34)Kke%}C3SLh zlWIl|JuYRLQhQgU)K)e(wVG9*HA%pB_^|_KAQ!3dGN9C*+~%{=Gk+u{<*JmFxw!>3 zO~-R0)1>71{(n#LvSNH}j1GQLygYdD;cw)~Ap{@*0SG_<0uX=z1Rwwb2q?;Fy>(pO z+^fg*`oG83TPHU;srL21zM{8QHaWHK*Z&BOg@t1MM_=YyHKmY;|fB*y_009U<00Izz00bcL5D4thb+xcZSI&NQ`_QlE zog3fY@Sl5JqkYe7%Gq1}H@AORy}bO^&5PGh-1+36U(~-Z9lf15a&%ne6y@rnA1{6V z!!V?7%y7kZZ|8?VvV%#$RG=8TIED(SI1Rwwb2tWV=5P$##AOHafY*B$;y8Ksl zJ$}=)_$`;y=~n>p{eNZq7A+RC4FL#100Izz00bZa0SG_<0uX?}h6MQjKd%2bM1?>Q zfB*y_009U<00Izz00bZafh{Y*|Nq}({LFvDHPKmY;|fB*y_009U<00Izz a00bbgH3hUDde*cB= From 256fc0ad367b864c76646ba06eb4302b7bfb6adb Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 19 Jun 2024 10:24:49 +0800 Subject: [PATCH 6/7] disable coverage check Signed-off-by: yiliu30 --- .../torch/algorithms/pt2e_quant/half_precision_rewriter.py | 6 +++--- neural_compressor/torch/algorithms/pt2e_quant/utility.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py index 4122610592f..b00f33b25cc 100644 --- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py +++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py @@ -106,7 +106,7 @@ def get_filter_fn(node_list, fn): def is_target_node_in_candidate_list(match, original_graph, pattern_graph): """Filter the node with target operator in match and check if it is in `node_list`.""" target_node = None - for node in pattern_graph.nodes: + for node in pattern_graph.nodes: # pragma: no cover if node.target == target_op: target_node = node break @@ -162,7 +162,7 @@ def _parse_node_candidate_set_from_user_config(config, gm): op_type_configs, op_name_configs = config._get_op_name_op_type_config() op_type_filters = [] op_name_filters = [] - for op_type_name, config in op_type_configs.items(): + for op_type_name, config in op_type_configs.items(): # pragma: no cover op_type = getattr(torch.nn, op_type_name) if config.act_dtype == "fp16": # pragma: no cover filter = xpq._get_module_type_filter(op_type) @@ -173,7 +173,7 @@ def _parse_node_candidate_set_from_user_config(config, gm): op_name_filters.append(filter) node_set_from_user_config = set() all_filters = op_type_filters + op_name_filters - for node in gm.graph.nodes: + for node in gm.graph.nodes: # pragma: no cover if any([filter(node) for filter in all_filters]): node_set_from_user_config.add(node) return node_set_from_user_config diff --git a/neural_compressor/torch/algorithms/pt2e_quant/utility.py b/neural_compressor/torch/algorithms/pt2e_quant/utility.py index 82a49fa55ad..e4efd62271e 100644 --- a/neural_compressor/torch/algorithms/pt2e_quant/utility.py +++ b/neural_compressor/torch/algorithms/pt2e_quant/utility.py @@ -56,7 +56,7 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=Fals def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig: NOT_QUANT_DTYPES = ["fp32", "fp16", "bf16"] - if inc_config.act_dtype in NOT_QUANT_DTYPES and inc_config.w_dtype in NOT_QUANT_DTYPES: + if inc_config.act_dtype in NOT_QUANT_DTYPES 
         return None
     default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
     input_act_quant_spec = create_quant_spec_from_config(

From 4b303f54d85d4275a546753f627c710b09f7c1a6 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Thu, 20 Jun 2024 20:31:42 +0800
Subject: [PATCH 7/7] disable some check

Signed-off-by: yiliu30
---
 .../torch/algorithms/pt2e_quant/half_precision_rewriter.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
index b00f33b25cc..bd1865e674c 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
@@ -110,7 +110,7 @@ def is_target_node_in_candidate_list(match, original_graph, pattern_graph):
             if node.target == target_op:
                 target_node = node
                 break
-        if target_node is None:
+        if target_node is None:  # pragma: no cover
            return False
         matched_node = match.nodes_map[target_node]
         return matched_node in node_list
@@ -138,7 +138,7 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):
         if meta := getattr(node, "meta"):
             if quantization_annotation := meta.get(xiq.QUANT_ANNOTATION_KEY):
                 none_annotation = xiq._X86InductorQuantizationAnnotation(_annotated=True)
-                if quantization_annotation != none_annotation:
+                if quantization_annotation != none_annotation:  # pragma: no cover
                     continue
         unquantized_node_set.add(node)
     return unquantized_node_set
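
A minimal end-to-end sketch of the set_local flow this series enables, mirroring
test_quantize_simple_model_with_set_local above. It assumes torch > 2.3.2 (the gate
added in PATCH 1/7); the import paths and the toy two-layer model are illustrative
assumptions, not code taken from the patches.

# Sketch: keep one layer in fp32 via set_local while quantizing the rest with PT2E.
import torch
from neural_compressor.torch.export import export  # path assumed from the test file's usage
from neural_compressor.torch.quantization import (
    StaticQuantConfig,
    get_default_static_config,
    quantize,
)


class ToyModel(torch.nn.Module):  # hypothetical two-layer model
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = ToyModel()
example_inputs = (torch.randn(2, 8),)
exported_model = export(model, example_inputs=example_inputs)

quant_config = get_default_static_config()
# fp32 act/weight dtypes map to a None torch quant config, so `fc1` stays
# unquantized (see _map_inc_config_to_torch_quant_config in PATCH 1/7).
quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))


def calib_fn(model):
    # A few calibration passes over representative inputs.
    for _ in range(4):
        model(*example_inputs)


q_model = quantize(model=exported_model, quant_config=quant_config, run_fn=calib_fn)

from torch._inductor import config as inductor_config

inductor_config.freezing = True  # the tests enable freezing before inference
opt_model = torch.compile(q_model)
print(opt_model(*example_inputs).shape)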