From 4d89941f2e0c59b0bbd1702191177984371ae0e3 Mon Sep 17 00:00:00 2001
From: Mengni Wang
Date: Tue, 4 Jun 2024 05:35:30 +0000
Subject: [PATCH] update mx script

Signed-off-by: Mengni Wang
---
 .../quantization/mx/run_clm_no_trainer.py     | 91 ++++++++-----------
 1 file changed, 38 insertions(+), 53 deletions(-)

diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mx/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mx/run_clm_no_trainer.py
index db5b08882e0..40bf217c72e 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mx/run_clm_no_trainer.py
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mx/run_clm_no_trainer.py
@@ -31,8 +31,9 @@
                     help="For accuracy measurement only.")
 parser.add_argument("--save_accuracy_path", default=None,
                     help="Save accuracy results path.")
-parser.add_argument("--tasks", type=str, default="lambada_openai",
-                    help="tasks list for accuracy validation")
+parser.add_argument("--tasks", nargs="+", default=["lambada_openai"], type=str,
+                    help="tasks list for accuracy validation"
+)
 parser.add_argument("--peft_model_id", type=str, default=None,
                     help="model_name_or_path of peft model")
 args = parser.parse_args()
@@ -54,57 +55,41 @@ def get_user_model():
     return user_model, tokenizer
 
 user_model, tokenizer = get_user_model()
-if args.quantize:
-    from neural_compressor.torch.quantization import MXQuantConfig, quantize
-    quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
-    user_model = quantize(model=user_model, quant_config=quant_config)
+from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert
+quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
+user_model = prepare(model=user_model, quant_config=quant_config)
+user_model = convert(model=user_model)
+user_model.eval()
 
-if args.accuracy:
-    user_model.eval()
-    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
-    args = LMEvalParser(
-        model="hf",
-        user_model=user_model,
-        tokenizer=tokenizer,
-        batch_size=args.batch_size,
-        tasks=args.tasks,
-        device="cpu",
-    )
-    results = evaluate(args)
-    dumped = json.dumps(results, indent=2)
-    if args.save_accuracy_path:
-        with open(args.save_accuracy_path, "w") as f:
-            f.write(dumped)
-    for task_name in args.tasks:
-        if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity"]
-        else:
-            acc = results["results"][task_name]["acc"]
-    print("Accuracy: %.5f" % acc)
-    print('Batch size = %d' % args.batch_size)
+from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+eval_args = LMEvalParser(
+    model="hf",
+    user_model=user_model,
+    tokenizer=tokenizer,
+    batch_size=args.batch_size,
+    tasks=','.join(args.tasks),
+    device="cpu",
+)
 
-if args.performance:
-    user_model.eval()
-    from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
-    import time
-    samples = args.iters * args.batch_size
-    start = time.time()
-    results = evaluate(
-        model="hf",
-        tokenizer=tokenizer,
-        user_model=user_model,
-        batch_size=args.batch_size,
-        tasks=args.tasks,
-        limit=samples,
-    )
-    end = time.time()
-    for task_name in args.tasks:
-        if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity"]
-        else:
-            acc = results["results"][task_name]["acc"]
-    print("Accuracy: %.5f" % acc)
-    print('Throughput: %.3f samples/sec' % (samples / (end - start)))
-    print('Latency: %.3f ms' % ((end - start)*1000 / samples))
-    print('Batch size = %d' % args.batch_size)
+results = evaluate(eval_args)
+dumped = json.dumps(results, indent=2)
+if args.save_accuracy_path:
+    with open(args.save_accuracy_path, "w") as f:
+        f.write(dumped)
+
+eval_acc = 0
+for task_name in args.tasks:
+    if task_name == "wikitext":
+        print("Accuracy for %s is: %s" %
+              (task_name, results["results"][task_name]["word_perplexity,none"]))
+        eval_acc += results["results"][task_name]["word_perplexity,none"]
+    else:
+        print("Accuracy for %s is: %s" %
+              (task_name, results["results"][task_name]["acc,none"]))
+        eval_acc += results["results"][task_name]["acc,none"]
+
+if len(args.tasks) != 0:
+    eval_acc /= len(args.tasks)
+print("Accuracy: %.5f" % eval_acc)
+print('Batch size = %d' % args.batch_size)
\ No newline at end of file