From 6bf424d83ff12bcb15624ffc399770e708bc1ab4 Mon Sep 17 00:00:00 2001
From: "Cheng, Zixuan"
Date: Fri, 19 Apr 2024 14:54:08 +0800
Subject: [PATCH 1/6] map ipex op_name w/ pt op_name

Signed-off-by: Cheng, Zixuan
---
 neural_compressor/common/base_config.py      |  2 +-
 .../algorithms/static_quant/static_quant.py  |  4 +--
 .../torch/algorithms/static_quant/utility.py | 33 +++++++++++++++++--
 .../torch/quantization/config.py             |  4 +--
 .../torch/quantization/test_static_quant.py  | 15 +++++++++
 5 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py
index 0b7749a5c48..d953eaf47d0 100644
--- a/neural_compressor/common/base_config.py
+++ b/neural_compressor/common/base_config.py
@@ -410,7 +410,7 @@ def to_config_mapping(
             if self.global_config is not None:
                 config_mapping[(op_name, op_type)] = global_config
             if op_type in op_type_config_dict:
-                config_mapping[(op_name, op_type)] = op_name_config_dict[op_type]
+                config_mapping[(op_name, op_type)] = op_type_config_dict[op_type]
             for op_name_pattern in op_name_config_dict:
                 if isinstance(op_name, str) and re.match(op_name_pattern, op_name):
                     config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py
index 626d0f60a2e..a9169689924 100644
--- a/neural_compressor/torch/algorithms/static_quant/static_quant.py
+++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -51,7 +51,7 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     Returns:
         A quantized model.
     """
-    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs)
+    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)
     cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)  # update json file in ipex_config_path
     model.eval()
 
@@ -82,7 +82,7 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    # dump_model_op_stats(tune_cfg)
     return model
 
 
diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py
index dd073f50aab..76edea7b69c 100644
--- a/neural_compressor/torch/algorithms/static_quant/utility.py
+++ b/neural_compressor/torch/algorithms/static_quant/utility.py
@@ -83,6 +83,18 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     Returns:
         cfgs (dict): updated configs.
""" + tmp_user_cfg = {} + for op in user_cfg: + for i, op_name in enumerate(op): + for ops, _ in op_infos_from_cfgs.items(): + if "fqn" in op_infos_from_cfgs[ops].keys() and op_infos_from_cfgs[ops]["fqn"] == op_name: + tmp_user_cfg[(tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])] = ( + user_cfg[op_name] + ) + break + else: + continue + user_cfg = tmp_user_cfg for op_name in user_cfg: inc_op_cfg = user_cfg[op_name] for i, name in enumerate(op_name[0]): @@ -212,6 +224,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover cfgs (dict): dict of configuration """ quantizable_ops = [] + op_name_info = [] # group ops by position for transform-based model detector = TransformerBasedModelBlockPatternDetector(model) detect_result = detector.detect_block() @@ -277,17 +290,30 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover if ipex_op_type in unify_op_type_mapping_ipex: quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type])) map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn + if "class" in ipex_op_type: # "" + op_type = ipex_op_type.split("'")[1] + op_name_info.append((module_fqn, eval(op_type))) + elif "method" in ipex_op_type: # "" + method = ipex_op_type.split("'")[1] + op_type = getattr( + torch._C._TensorBase if ipex_ver.release < Version("2.2") else torch._C.TensorBase, method + ) + else: + pass + op_name_info.append((module_fqn, op_type)) else: re_flag = False for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): if re.match(pattern, ipex_op_type): re_flag = True - quantizable_ops.append((tuple(name), unify_op_type)) + quantizable_ops.append(((tuple(name), unify_op_type))) map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn + op_name_info.append((module_fqn, ipex_op_type)) break if not re_flag: - quantizable_ops.append((tuple(name), ipex_op_type)) + quantizable_ops.append(((tuple(name), ipex_op_type))) map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn + op_name_info.append((module_fqn, ipex_op_type)) else: op_type = "" for op_name in name: @@ -302,6 +328,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover _op_cfg_id = name[0][2] module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"] map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn + op_name_info.append((module_fqn, ipex_op_type)) logger.debug("Map op name to fqn: ") logger.debug(map_op_name_to_fqn) @@ -309,7 +336,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover logger.info(attention_block) logger.info("FFN Blocks : ") logger.info(ffn_blocks) - return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name + return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, op_name_info def simple_inference(q_model, example_inputs, iterations=1): diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 29a8177e9be..69ef9d36cc2 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -818,7 +818,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]: def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]: from neural_compressor.torch.algorithms.static_quant import get_quantizable_ops_recursively - model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs) + _, _, _, _, model_info = 
         return model_info
 
     @classmethod
@@ -923,7 +923,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
     def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
         from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively
 
-        model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
+        _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info
 
     @classmethod
diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py
index 518e2240470..493191cae04 100644
--- a/test/3x/torch/quantization/test_static_quant.py
+++ b/test/3x/torch/quantization/test_static_quant.py
@@ -49,6 +49,21 @@ def test_static_quant_default(self):
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
 
+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_static_quant_fallback(self):
+        fp32_model = copy.deepcopy(self.fp32_model)
+        quant_config = get_default_static_config()
+        example_inputs = self.input
+        # fallback by op_type
+        quant_config.set_local(torch.nn.modules.linear.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
+        # fallback by op_name
+        quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     @pytest.mark.parametrize(
         "act_sym, act_algo",

From c943b2e43d336924c75b2ce64f9076ec6ffd2359 Mon Sep 17 00:00:00 2001
From: "Cheng, Zixuan"
Date: Fri, 19 Apr 2024 14:59:33 +0800
Subject: [PATCH 2/6] minor fix

Signed-off-by: Cheng, Zixuan
---
 .../torch/algorithms/static_quant/static_quant.py          | 2 +-
 neural_compressor/torch/algorithms/static_quant/utility.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py
index a9169689924..cbc6b69e067 100644
--- a/neural_compressor/torch/algorithms/static_quant/static_quant.py
+++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -82,7 +82,7 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    # dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(tune_cfg)
     return model
 
 
diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py
index 76edea7b69c..243f7f80afd 100644
--- a/neural_compressor/torch/algorithms/static_quant/utility.py
+++ b/neural_compressor/torch/algorithms/static_quant/utility.py
@@ -306,12 +306,12 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
                 for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
                     if re.match(pattern, ipex_op_type):
                         re_flag = True
-                        quantizable_ops.append(((tuple(name), unify_op_type)))
+                        quantizable_ops.append((tuple(name), unify_op_type))
                         map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn
                         op_name_info.append((module_fqn, ipex_op_type))
                         break
                 if not re_flag:
-                    quantizable_ops.append(((tuple(name), ipex_op_type)))
+                    quantizable_ops.append((tuple(name), ipex_op_type))
                     map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
                     op_name_info.append((module_fqn, ipex_op_type))
         else:

From e587d850d4907dde8cc4282c3675379dc7bd77b4 Mon Sep 17 00:00:00 2001
From: "Cheng, Zixuan"
Date: Fri, 19 Apr 2024 16:22:21 +0800
Subject: [PATCH 3/6] minor fix

---
 .../torch/algorithms/static_quant/utility.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py
index 243f7f80afd..7d7224d91f7 100644
--- a/neural_compressor/torch/algorithms/static_quant/utility.py
+++ b/neural_compressor/torch/algorithms/static_quant/utility.py
@@ -89,11 +89,9 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
             for ops, _ in op_infos_from_cfgs.items():
                 if "fqn" in op_infos_from_cfgs[ops].keys() and op_infos_from_cfgs[ops]["fqn"] == op_name:
                     tmp_user_cfg[(tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])] = (
-                        user_cfg[op_name]
+                        user_cfg[op]
                     )
                     break
-            else:
-                continue
     user_cfg = tmp_user_cfg
     for op_name in user_cfg:
         inc_op_cfg = user_cfg[op_name]
@@ -298,9 +296,9 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
                 op_type = getattr(
                     torch._C._TensorBase if ipex_ver.release < Version("2.2") else torch._C.TensorBase, method
                 )
+                op_name_info.append((module_fqn, op_type))
             else:
-                pass
-            op_name_info.append((module_fqn, op_type))
+                op_name_info.append((module_fqn, op_type))
         else:
             re_flag = False
             for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
From 32e63c847447a7a6bdea387be8b4cc8391877e16 Mon Sep 17 00:00:00 2001
From: "Cheng, Zixuan"
Date: Sun, 21 Apr 2024 17:08:40 +0800
Subject: [PATCH 4/6] fix dump op stats

Signed-off-by: Cheng, Zixuan
---
 neural_compressor/common/base_config.py       |  4 +---
 .../algorithms/smooth_quant/smooth_quant.py   |  7 +++---
 .../algorithms/static_quant/static_quant.py   |  5 ++--
 .../torch/algorithms/static_quant/utility.py  | 23 ++++++++++---------
 .../torch/quantization/config.py              |  2 +-
 5 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py
index d953eaf47d0..05e26d8b05d 100644
--- a/neural_compressor/common/base_config.py
+++ b/neural_compressor/common/base_config.py
@@ -412,9 +412,7 @@ def to_config_mapping(
             if op_type in op_type_config_dict:
                 config_mapping[(op_name, op_type)] = op_type_config_dict[op_type]
             for op_name_pattern in op_name_config_dict:
-                if isinstance(op_name, str) and re.match(op_name_pattern, op_name):
-                    config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
-                elif op_name_pattern == op_name:  # TODO: map ipex opname to stock pt op_name
+                if re.match(op_name_pattern, op_name):
                     config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
         return config_mapping
 
diff --git a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py
index bd26dcdfc3b..c9534b42cd5 100644
--- a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py
+++ b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py
@@ -56,7 +56,7 @@ def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     """
     assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."
 
-    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs)
+    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)
 
     # check smoothquant folding value
     recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
@@ -121,7 +121,7 @@ def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(tune_cfg["op"])
     return model
 
 
@@ -161,6 +161,7 @@ def qdq_quantize(
     # The load_qconf_summary will overwrite the scales used in model but only work in the first call.
     # Here, we use INC collected scale for Linear and set normal observer instead of SQObserver \
     # to make sure calibration works for other ops, like add, bmm.
+    # update json file in ipex_config_path; map ipex op_name to pt op_name
     cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, smooth_quant=True)
     update_sq_scale(ipex_config_path, smoothquant_scale_info)
     model.load_qconf_summary(qconf_summary=ipex_config_path)
@@ -185,7 +186,7 @@ def qdq_quantize(
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(tune_cfg["op"])
     return model
 
 
diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py
index cbc6b69e067..2f4ed042e24 100644
--- a/neural_compressor/torch/algorithms/static_quant/static_quant.py
+++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -52,7 +52,8 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
         A quantized model.
     """
     _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)
-    cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)  # update json file in ipex_config_path
+    # update json file in ipex_config_path; map ipex op_name to pt op_name
+    user_cfg = cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
     model.eval()
 
     # Check save_qconf_summary part is a workaround for IPEX bug.
@@ -82,7 +83,7 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(user_cfg)
     return model
 
 
diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py
index 7d7224d91f7..4657abd46d7 100644
--- a/neural_compressor/torch/algorithms/static_quant/utility.py
+++ b/neural_compressor/torch/algorithms/static_quant/utility.py
@@ -16,6 +16,7 @@
 import json
 import os
 import re
+from collections import OrderedDict
 from typing import Dict, List, Union
 
 import torch
@@ -66,9 +67,10 @@
 def cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name):  # pragma: no cover
     assert cfgs is not None, "No configure for IPEX int8 model..."
     op_infos = copy.deepcopy(op_infos_from_cfgs)
-    cfgs = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
+    cfgs, user_cfg = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
     with open(ipex_config_path, "w") as write_f:
         json.dump(cfgs, write_f, indent=4)
+    return user_cfg
 
 
 def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name):  # pragma: no cover
@@ -83,14 +85,13 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     Returns:
         cfgs (dict): updated configs.
""" - tmp_user_cfg = {} - for op in user_cfg: + tmp_user_cfg = OrderedDict() + for op in user_cfg: # map ipex op_name to pt op_name for i, op_name in enumerate(op): for ops, _ in op_infos_from_cfgs.items(): if "fqn" in op_infos_from_cfgs[ops].keys() and op_infos_from_cfgs[ops]["fqn"] == op_name: - tmp_user_cfg[(tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])] = ( - user_cfg[op] - ) + ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]]) + tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op] break user_cfg = tmp_user_cfg for op_name in user_cfg: @@ -152,7 +153,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_ else: pass cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg - return cfgs + return cfgs, user_cfg def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False): # pragma: no cover @@ -326,7 +327,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover _op_cfg_id = name[0][2] module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"] map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn - op_name_info.append((module_fqn, ipex_op_type)) + op_name_info.append((module_fqn, op_type)) logger.debug("Map op name to fqn: ") logger.debug(map_op_name_to_fqn) @@ -348,16 +349,16 @@ def simple_inference(q_model, example_inputs, iterations=1): q_model(example_inputs) -def dump_model_op_stats(tune_cfg): +def dump_model_op_stats(user_cfg): """This is a function to dump quantizable ops of model to user. Args: - tune_cfg (dict): quantization config + user_cfg (dict): quantization config Returns: None """ res = dict() - for k, v in tune_cfg["op"].items(): + for k, v in user_cfg.items(): op_type_list = k[-1].split("><") op_type = "" for op in op_type_list: diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 69ef9d36cc2..9de2ecb0a94 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -923,7 +923,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]: def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]: from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively - _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs) + model_info, _, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs) return model_info @classmethod From 9eff3a1f14118c164d248a2369ab9af9f1192b1c Mon Sep 17 00:00:00 2001 From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com> Date: Sun, 21 Apr 2024 17:10:54 +0800 Subject: [PATCH 5/6] Update smooth_quant.py --- neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py | 1 - 1 file changed, 1 deletion(-) diff --git a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py index c9534b42cd5..e49d1bfbab8 100644 --- a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py +++ b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py @@ -161,7 +161,6 @@ def qdq_quantize( # The load_qconf_summary will overwrite the scales used in model but only work in the first call. # Here, we use INC collected scale for Linear and set normal observer instead of SQObserver \ # to make sure calibration works for other ops, like add, bmm. 
-    # update json file in ipex_config_path; map ipex op_name to pt op_name
     cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, smooth_quant=True)
     update_sq_scale(ipex_config_path, smoothquant_scale_info)
     model.load_qconf_summary(qconf_summary=ipex_config_path)

From 5522bf91e99e772a4fc3e977d2c5509296b08e0b Mon Sep 17 00:00:00 2001
From: "Cheng, Zixuan"
Date: Mon, 22 Apr 2024 16:13:16 +0800
Subject: [PATCH 6/6] update lm_eval version

Signed-off-by: Cheng, Zixuan
---
 .../language-modeling/quantization/llm/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/requirements.txt b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/requirements.txt
index 0fac3f8438f..ebea194b93b 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/requirements.txt
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/requirements.txt
@@ -9,5 +9,5 @@ wandb
 einops
 neural-compressor
 intel-extension-for-transformers
-git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
+lm-eval
 peft