diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py index abd8228354e..02329bd9e15 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py @@ -12,6 +12,7 @@ from torch.nn.functional import pad from torch.utils.data import DataLoader from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer +from neural_compressor.torch.utils import is_hpex_available parser = argparse.ArgumentParser() parser.add_argument( @@ -324,7 +325,11 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args): user_model, _ = get_user_model() tokenizer = AutoTokenizer.from_pretrained(args.model) config = AutoConfig.from_pretrained(args.model) - user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)), user_model) + user_model = load( + os.path.abspath(os.path.expanduser(args.output_dir)), + user_model, + device="hpu" if is_hpex_available() else "cpu", + ) setattr(user_model, "config", config) else: user_model, tokenizer = get_user_model() @@ -332,14 +337,14 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args): if args.accuracy: user_model.eval() - from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser + from neural_compressor.evaluation.lm_eval import evaluate, LMEvalParser eval_args = LMEvalParser( model="hf", user_model=user_model, tokenizer=tokenizer, batch_size=args.batch_size, tasks=args.tasks, - device="cpu", + device="hpu" if is_hpex_available() else "cpu", ) results = evaluate(eval_args) for task_name in args.tasks.split(","): @@ -352,7 +357,7 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args): if args.performance: user_model.eval() - from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser + from neural_compressor.evaluation.lm_eval import evaluate, LMEvalParser import time samples = args.iters * args.batch_size @@ -363,7 +368,7 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args): batch_size=args.batch_size, tasks=args.tasks, limit=samples, - device="cpu", + device="hpu" if is_hpex_available() else "cpu", ) start = time.time() results = evaluate(eval_args) diff --git a/neural_compressor/evaluation/lm_eval/accuracy.py b/neural_compressor/evaluation/lm_eval/accuracy.py index b6d041f539b..403bf48cc03 100644 --- a/neural_compressor/evaluation/lm_eval/accuracy.py +++ b/neural_compressor/evaluation/lm_eval/accuracy.py @@ -36,18 +36,26 @@ from pathlib import Path from typing import Union +import lm_eval import numpy as np -from lm_eval import utils +from lm_eval import evaluator, utils from lm_eval.loggers import WandbLogger from lm_eval.tasks import TaskManager from lm_eval.utils import make_table, simple_parse_args_string -from neural_compressor.evaluation.lm_eval import evaluator -from neural_compressor.evaluation.lm_eval.evaluator import request_caching_arg_to_dict - DEFAULT_RESULTS_FILE = "results.json" +def request_caching_arg_to_dict(cache_requests: str) -> dict: + request_caching_args = { + "cache_requests": cache_requests in {"true", "refresh"}, + "rewrite_requests_cache": cache_requests == "refresh", + "delete_requests_cache": cache_requests == "delete", + } + + return 
request_caching_args + + def _handle_non_serializable(o): if isinstance(o, np.int64) or isinstance(o, np.int32): return int(o) @@ -143,8 +151,57 @@ def cli_evaluate(args) -> None: request_caching_args = request_caching_arg_to_dict(cache_requests=args.cache_requests) + ### update model with user_model ### + if args.model_args is None: + args.model_args = "" + # replace HFLM. + from .models.huggingface import HFLM + + lm_eval.api.registry.MODEL_REGISTRY["hf-auto"] = HFLM + lm_eval.api.registry.MODEL_REGISTRY["hf"] = HFLM + lm_eval.api.registry.MODEL_REGISTRY["huggingface"] = HFLM + + if args.user_model is not None: + # use tiny model to built lm. + print( + "We use 'pretrained=Muennighoff/tiny-random-bert'" + + "to build `LM` instance, the actually run model is user_model you passed." + ) + lm = lm_eval.api.registry.get_model(args.model).create_from_arg_string( + "pretrained=Muennighoff/tiny-random-bert", + { + "batch_size": args.batch_size, + "max_batch_size": args.max_batch_size, + "device": args.device, + }, + ) + lm._model = args.user_model + if args.tokenizer is not None: + lm.tokenizer = args.tokenizer + else: + assert False, "Please provide tokenizer in evaluation function" + elif isinstance(args.model_args, dict): + lm = lm_eval.api.registry.get_model(args.model).create_from_arg_obj( + args.model_args, + { + "batch_size": args.batch_size, + "max_batch_size": args.max_batch_size, + "device": args.device, + }, + ) + else: + lm = lm_eval.api.registry.get_model(args.model).create_from_arg_string( + args.model_args, + { + "batch_size": args.batch_size, + "max_batch_size": args.max_batch_size, + "device": args.device, + }, + ) + lm.pad_to_buckets = args.pad_to_buckets + results = evaluator.simple_evaluate( - model=args.model, + model=lm, model_args=args.model_args, tasks=task_names, num_fewshot=args.num_fewshot, @@ -163,8 +220,6 @@ def cli_evaluate(args) -> None: random_seed=args.seed[0], numpy_random_seed=args.seed[1], torch_random_seed=args.seed[2], - user_model=args.user_model, # to validate the model in memory, - tokenizer=args.tokenizer, # to use tokenizer in mem, **request_caching_args, ) diff --git a/neural_compressor/evaluation/lm_eval/evaluator.py b/neural_compressor/evaluation/lm_eval/evaluator.py deleted file mode 100644 index 78ed86bc233..00000000000 --- a/neural_compressor/evaluation/lm_eval/evaluator.py +++ /dev/null @@ -1,605 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2022 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import itertools -import logging -import os -import random -import time -from collections import defaultdict -from typing import TYPE_CHECKING, List, Optional, Union - -import lm_eval.api.metrics -import lm_eval.api.registry -import lm_eval.models -import numpy as np -import torch -from lm_eval import utils -from lm_eval.caching.cache import delete_cache -from lm_eval.evaluator_utils import ( - consolidate_results, - get_sample_size, - get_task_list, - prepare_print_tasks, - print_writeout, - run_task_tests, -) -from lm_eval.loggers.utils import add_env_info, get_git_commit_hash -from lm_eval.tasks import TaskManager, get_task_dict -from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string - -if TYPE_CHECKING: - from lm_eval.api.model import LM - from lm_eval.tasks import Task - - -@positional_deprecated -def simple_evaluate( - model, - model_args: Optional[Union[str, dict, object]] = None, - tasks: Optional[List[Union[str, dict, object]]] = None, - num_fewshot: Optional[int] = None, - batch_size: Optional[int] = None, - max_batch_size: Optional[int] = None, - device: Optional[str] = None, - use_cache: Optional[str] = None, - cache_requests: bool = False, - rewrite_requests_cache: bool = False, - delete_requests_cache: bool = False, - limit: Optional[Union[int, float]] = None, - bootstrap_iters: int = 100000, - check_integrity: bool = False, - write_out: bool = False, - log_samples: bool = True, - gen_kwargs: Optional[str] = None, - task_manager: Optional[TaskManager] = None, - verbosity: str = "INFO", - predict_only: bool = False, - random_seed: int = 0, - numpy_random_seed: int = 1234, - torch_random_seed: int = 1234, - user_model: Optional[object] = None, - tokenizer: Optional[object] = None, -): - """Instantiate and evaluate a model on a list of tasks. - - :param model: Union[str, LM] - Name of model or LM object, see lm_eval.models.get_model - :param model_args: Optional[str, dict] - String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object. - Ignored if `model` argument is a LM object. - :param tasks: list[Union[str, dict, Task]] - List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME - if defined and type(task).__name__ otherwise. - :param num_fewshot: int - Number of examples in few-shot context - :param batch_size: int or str, optional - Batch size for model - :param max_batch_size: int, optional - Maximal batch size to try with automatic batch size detection - :param device: str, optional - PyTorch device (e.g. "cpu" or "cuda:0") for running models - :param use_cache: str, optional - A path to a sqlite db file for caching model responses. `None` if not caching. - :param cache_requests: bool, optional - Speed up evaluation by caching the building of dataset requests. `None` if not caching. - :param rewrite_requests_cache: bool, optional - Rewrites all of the request cache if set to `True`. `None` if not desired. - :param delete_requests_cache: bool, optional - Deletes all of the request cache if set to `True`. `None` if not desired. - :param limit: int or float, optional - Limit the number of examples per task (only use this for testing), If <1, - limit is a percentage of the total number of examples. 
- :param bootstrap_iters: - Number of iterations for bootstrap statistics - :param check_integrity: bool - Whether to run the relevant part of the test suite for the tasks - :param write_out: bool - If True, write out an example document and model input for checking task integrity - :param log_samples: bool - If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis - :param gen_kwargs: str - String arguments for model generation - Ignored for all tasks with loglikelihood output_type - :param predict_only: bool - If true only model outputs will be generated and returned. Metrics will not be evaluated - :param random_seed: int - Random seed for python's random module. If set to None, the seed will not be set. - :param numpy_random_seed: int - Random seed for numpy. If set to None, the seed will not be set. - :param torch_random_seed: int - Random seed for torch. If set to None, the seed will not be set. - - :return - Dictionary of results - """ - - eval_logger.setLevel(getattr(logging, f"{verbosity}")) - start_date = time.time() - - if delete_requests_cache: - eval_logger.info("Deleting requests cache...") - delete_cache() - - seed_message = [] - if random_seed is not None: - # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412 - seed_message.append(f"Setting random seed to {random_seed}") - random.seed(random_seed) - - if numpy_random_seed is not None: - seed_message.append(f"Setting numpy seed to {numpy_random_seed}") - np.random.seed(numpy_random_seed) - - if torch_random_seed is not None: - seed_message.append(f"Setting torch manual seed to {torch_random_seed}") - torch.manual_seed(torch_random_seed) - - if seed_message: - eval_logger.info(" | ".join(seed_message)) - - if tasks is None: - tasks = [] - if len(tasks) == 0: - raise ValueError("No tasks specified, or no tasks found. Please verify the task names.") - - if gen_kwargs is not None: - gen_kwargs = simple_parse_args_string(gen_kwargs) - eval_logger.warning( - "generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. " - "Ensure 'do_sample=True' for non-greedy decoding!" - ) - if gen_kwargs == "": - gen_kwargs = None - - if isinstance(model, str): - if model_args is None: - model_args = "" - # replace HFLM. - from .models.huggingface import HFLM - - lm_eval.api.registry.MODEL_REGISTRY["hf-auto"] = HFLM - lm_eval.api.registry.MODEL_REGISTRY["hf"] = HFLM - lm_eval.api.registry.MODEL_REGISTRY["huggingface"] = HFLM - - if user_model is not None: - # use tiny model to built lm. - print( - "We use 'pretrained=Muennighoff/tiny-random-bert'" - + "to build `LM` instance, the actually run model is user_model you passed." 
- ) - lm = lm_eval.api.registry.get_model(model).create_from_arg_string( - "pretrained=Muennighoff/tiny-random-bert", - { - "batch_size": batch_size, - "max_batch_size": max_batch_size, - "device": device, - }, - ) - lm._model = user_model - if tokenizer is not None: - lm.tokenizer = tokenizer - else: - assert False, "Please provide tokenizer in evaluation function" - elif isinstance(model_args, dict): - lm = lm_eval.api.registry.get_model(model).create_from_arg_obj( - model_args, - { - "batch_size": batch_size, - "max_batch_size": max_batch_size, - "device": device, - }, - ) - else: - lm = lm_eval.api.registry.get_model(model).create_from_arg_string( - model_args, - { - "batch_size": batch_size, - "max_batch_size": max_batch_size, - "device": device, - }, - ) - else: - if not isinstance(model, lm_eval.api.model.LM): - raise TypeError - lm = model - - if use_cache is not None: - eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}") - lm = lm_eval.api.model.CachingLM( - lm, - use_cache - # each rank receives a different cache db. - # necessary to avoid multiple writes to cache at once - + "_rank" + str(lm.rank) + ".db", - ) - - if task_manager is None: - task_manager = TaskManager(verbosity) - - task_dict = get_task_dict(tasks, task_manager) - for task_name in task_dict.keys(): - task_obj = task_dict[task_name] - if isinstance(task_obj, tuple): - _, task_obj = task_obj - if task_obj is None: - continue - - if task_obj.get_config("output_type") == "generate_until": - if gen_kwargs is not None: - task_obj.set_config(key="generation_kwargs", value=gen_kwargs, update=True) - - if predict_only: - log_samples = True - eval_logger.info(f"Processing {task_name} in output-only mode. Metrics will not be calculated!") - # we have to change the class properties post-hoc. This is pretty hacky. - task_obj.override_metric(metric_name="bypass") - - # override tasks' fewshot values to the provided num_fewshot arg value - # except if tasks have it set to 0 manually in their configs--then we should never overwrite that - if num_fewshot is not None: - if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0: - eval_logger.info( - f"num_fewshot has been set to 0 for {task_name} in its config." - + "Manual configuration will be ignored." 
- ) - else: - eval_logger.warning( - f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" - ) - task_obj.set_config(key="num_fewshot", value=num_fewshot) - else: - # if num_fewshot not provided, and the task does not define a default one, default to 0 - if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None: - task_obj.set_config(key="num_fewshot", value=0) - - if check_integrity: - run_task_tests(task_list=tasks) - - results = evaluate( - lm=lm, - task_dict=task_dict, - limit=limit, - cache_requests=cache_requests, - rewrite_requests_cache=rewrite_requests_cache, - bootstrap_iters=bootstrap_iters, - write_out=write_out, - log_samples=log_samples, - verbosity=verbosity, - ) - - if lm.rank == 0: - if isinstance(model, str): - model_name = model - elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"): - model_name = model.config._name_or_path - else: - model_name = type(model).__name__ - - # add info about the model and few shot config - results["config"] = { - "model": model_name, - "model_args": model_args, - "batch_size": batch_size, - "batch_sizes": (list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []), - "device": device, - "use_cache": use_cache, - "limit": limit, - "bootstrap_iters": bootstrap_iters, - "gen_kwargs": gen_kwargs, - } - results["git_hash"] = get_git_commit_hash() - results["date"] = start_date - try: - add_env_info(results) # additional environment info to results - except: - eval_logger.info("get env info failed.") - return results - else: - return None - - -@positional_deprecated -def evaluate( - lm: "LM", - task_dict, - limit: Optional[int] = None, - cache_requests: bool = False, - rewrite_requests_cache: bool = False, - bootstrap_iters: Optional[int] = 100000, - write_out: bool = False, - log_samples: bool = True, - verbosity: str = "INFO", -): - """Instantiate and evaluate a model on a list of tasks. - - :param lm: obj - Language Model - :param task_dict: dict[str, Task] - Dictionary of tasks. Tasks will be taken to have name type(task).config.task . - :param limit: int, optional - Limit the number of examples per task (only use this for testing) - :param bootstrap_iters: - Number of iterations for bootstrap statistics - :param write_out: bool - If True, write out an example document and model input for checking task integrity - :param log_samples: bool - If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis - :return - Dictionary of results - """ - - eval_logger.setLevel(getattr(logging, f"{verbosity}")) - - # tracks all Instances/requests a model must generate output on. - requests = defaultdict(list) - # stores the amount to pad out reqs per req. 
type so that - # number of fwd passes per distributed rank is equal - padding_requests = defaultdict(int) - - # get lists of group hierarchy and each type of request - task_hierarchy, eval_tasks = get_task_list(task_dict) - if not log_samples: - if not all( - "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys() for task_output in eval_tasks - ): - raise ValueError("log_samples must be True for 'bypass' metric-only tasks") - for task_output in eval_tasks: - task: Task = task_output.task - limit = get_sample_size(task, limit) - task.build_all_requests( - limit=limit, - rank=lm.rank, - world_size=lm.world_size, - cache_requests=cache_requests, - rewrite_requests_cache=rewrite_requests_cache, - ) - eval_logger.debug(f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}") - - if write_out: - print_writeout(task) - # aggregate Instances by LM method requested to get output. - for instance in task.instances: - reqtype = instance.request_type - requests[reqtype].append(instance) - - if lm.world_size > 1: - instances_rnk = torch.tensor(len(task._instances), device=lm.device) - gathered_item = lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist() - # "multiple_choice" task types dispatch (several) "loglikelihood" request types - reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE - # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks) - numpad = max(gathered_item) - gathered_item[lm.rank] - # todo: may not account for padding in cases like SquadV2 which has multiple req types - padding_requests[reqtype] += numpad - - ### Run LM on inputs, get all outputs ### - # execute each type of request - for reqtype, reqs in requests.items(): - eval_logger.info(f"Running {reqtype} requests") - # create `K` copies of each request `req` based off `K = req.repeats` - cloned_reqs = [] - for req in reqs: - cloned_reqs.extend([req] * req.repeats) - - if (lm.world_size > 1) and (padding_requests[reqtype] > 0): - for _ in range(padding_requests[reqtype]): - cloned_reqs.extend([req] * req.repeats) - - # run requests through model - resps = getattr(lm, reqtype)(cloned_reqs) - - # put responses from model into a list of length K for each request. - for x, req in zip(resps, cloned_reqs): - req.resps.append(x) - - if lm.world_size > 1: - lm.accelerator.wait_for_everyone() - - RANK = lm.rank - WORLD_SIZE = lm.world_size - ### Postprocess outputs ### - # TODO: del model here, maybe (idea: allow user to specify device of e.g. 
reward model separately) - for task_output in eval_tasks: - task = task_output.task - task.apply_filters() - - ### Collect values of metrics on all datapoints ### - # # unpack results and sort back in order and return control to Task - # TODO: make it possible to use a different metric per filter - # Pre-process task.instances to group by doc_id - instances_by_doc_id = defaultdict(list) - for instance in task.instances: - instances_by_doc_id[instance.doc_id].append(instance) - # Sort instances within each group - for instances in instances_by_doc_id.values(): - instances.sort(key=lambda x: x.idx) - # iterate over different filters used - for filter_key in task.instances[0].filtered_resps.keys(): - doc_iterator = task.doc_iterator(rank=RANK, limit=limit, world_size=WORLD_SIZE) - for doc_id, doc in doc_iterator: - requests = instances_by_doc_id[doc_id] - metrics = task.process_results(doc, [req.filtered_resps[filter_key] for req in requests]) - if log_samples: - target = task.doc_to_target(doc) - example = { - "doc_id": doc_id, - "doc": doc, - "target": target, - "arguments": [req.args for req in requests], - "resps": [req.resps for req in requests], - "filtered_resps": [req.filtered_resps[filter_key] for req in requests], - } - example.update(metrics) - task_output.logged_samples.append(example) - for metric, value in metrics.items(): - task_output.sample_metrics[(metric, filter_key)].append(value) - - if WORLD_SIZE > 1: - # if multigpu, then gather data across all ranks to rank 0 - # first gather logged samples across all ranks - for task_output in eval_tasks: - if log_samples: - # for task_name, task_samples in list(samples.items()): - full_samples = [None] * WORLD_SIZE if RANK == 0 else None - torch.distributed.gather_object( - obj=task_output.logged_samples, - object_gather_list=full_samples, - dst=0, - ) - - if RANK == 0: - task_output.logged_samples = list(itertools.chain.from_iterable(full_samples)) - - # then collect metrics across all ranks - for metrics in task_output.sample_metrics: - metric_list = [None] * WORLD_SIZE if RANK == 0 else None - torch.distributed.gather_object( - obj=task_output.sample_metrics[metrics], - object_gather_list=metric_list, - dst=0, - ) - if RANK == 0: - task_output.sample_metrics[metrics] = list(itertools.chain.from_iterable(metric_list)) - - if RANK == 0: - ### Aggregate results over all datapoints ### - # aggregate results ; run bootstrap CIs - for task_output in eval_tasks: - task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters) - ( - results, - samples, - configs, - versions, - num_fewshot, - higher_is_better, - ) = consolidate_results(eval_tasks) - - ### Calculate group metrics ### - if bool(results): - for group, task_list in reversed(task_hierarchy.items()): - if len(task_list) == 0: - # task_hierarchy entries are either - # `group_name: [subtask1, subtask2, ...]` - # or `task_name: []`. - # we only want to operate on groups here. - continue - - # collect all higher_is_better values for metrics - # in the group's subtasks. - # TODO: clean this up ; unify with the below metric_list loop? - _higher_is_better = {} - for task in task_list: - for m, h in higher_is_better[task].items(): - if m not in _higher_is_better.keys(): - _higher_is_better[m] = h - if m in _higher_is_better and _higher_is_better[m] is not None and _higher_is_better[m] != h: - eval_logger.warning( - f"Higher_is_better values for metric {m} in group {group} are not consistent." - + "Defaulting to None." 
- ) - _higher_is_better[m] = None - higher_is_better[group] = _higher_is_better - - # collect all metric keys used by a subtask in the group. - metric_list = list( - { - key - for task in task_list - for key in results[task].keys() - if "_stderr" not in key and key not in ["alias", "samples"] - } - ) - for metric in metric_list: - stderr = "_stderr,".join(metric.split(",")) - - # gather metrics, sizes, and stderrs from subtasks - metrics = [results[task][metric] for task in task_list if metric in results[task]] # TODO: copy? - stderrs = [results[task][stderr] for task in task_list if stderr in results[task]] - sizes = [results[task]["samples"] for task in task_list if metric in results[task]] - - # compute group's pooled metric and stderr - results[group][metric] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes) - # TODO: calculate grouped metric using aggregation fn - if "N/A" in stderrs: - results[group][stderr] = "N/A" - else: - results[group][stderr] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes) - # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility - # To use the old (likely incorrect) variance formula, - # comment out the above and uncomment this line: - # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, - # sizes, metrics=metrics) - - results[group]["samples"] = sum(sizes) - - results_agg = defaultdict(dict) - groups_agg = defaultdict(dict) - all_tasks_list = list(task_hierarchy.keys()) - while True: - add_tasks_list = list(k for k in results_agg.keys()) - left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list))) - if len(left_tasks_list) == 0: - break - - _task_hierarchy = {k: v for k, v in task_hierarchy.items() if k in left_tasks_list} - _results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results) - - results_agg = {**results_agg, **_results_agg} - groups_agg = {**groups_agg, **_groups_agg} - - for group_name, task_list in task_hierarchy.items(): - if task_list: - num_fewshot[group_name] = num_fewshot[task_list[0]] # TODO: validate this - - results_dict = { - "results": dict(results_agg.items()), - **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}), - "group_subtasks": dict(reversed(task_hierarchy.items())), - "configs": dict(sorted(configs.items())), - "versions": dict(sorted(versions.items())), - "n-shot": dict(sorted(num_fewshot.items())), - "higher_is_better": dict(sorted(higher_is_better.items())), - "n-samples": { - task_output.task_name: { - "original": len(task_output.task.eval_docs), - "effective": min( - limit if limit else len(task_output.task.eval_docs), - len(task_output.task.eval_docs), - ), - } - for task_output in eval_tasks - }, - } - if log_samples: - results_dict["samples"] = dict(samples) - - return results_dict - - else: - return None - - -def request_caching_arg_to_dict(cache_requests: str) -> dict: - request_caching_args = { - "cache_requests": cache_requests in {"true", "refresh"}, - "rewrite_requests_cache": cache_requests == "refresh", - "delete_requests_cache": cache_requests == "delete", - } - - return request_caching_args diff --git a/neural_compressor/evaluation/lm_eval/models/huggingface.py b/neural_compressor/evaluation/lm_eval/models/huggingface.py index 98f219bc1b1..94bf5a44720 100644 --- a/neural_compressor/evaluation/lm_eval/models/huggingface.py +++ b/neural_compressor/evaluation/lm_eval/models/huggingface.py @@ -115,11 +115,12 @@ def __init__( # PEFT and quantization options peft: Optional[str] = None, 
autogptq: Optional[Union[bool, str]] = False, + pad_to_buckets: Optional[Union[bool]] = False, model_format: Optional[str] = "torch", **kwargs, ) -> None: super().__init__() - + self.pad_to_buckets = pad_to_buckets self.model_format = model_format # optionally: take in an already-initialized transformers.PreTrainedModel if not isinstance(pretrained, str): @@ -166,7 +167,9 @@ def __init__( if not (parallelize or accelerator.num_processes > 1): # use user-passed device device_list = set( - ["cuda", "cpu", "xpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())] + ["mps", "mps:0"] + ["cuda", "cpu", "xpu", "hpu"] + + [f"cuda:{i}" for i in range(torch.cuda.device_count())] + + ["mps", "mps:0"] ) if device and device in device_list: self._device = torch.device(device) @@ -511,7 +514,6 @@ def _create_model( if model_kwargs.get("load_in_4bit", None): if model_kwargs.get("bnb_4bit_compute_dtype", None): model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(model_kwargs["bnb_4bit_compute_dtype"]) - from neural_compressor.transformers import AutoModelForCausalLM if self.model_format == "onnx" and self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: if ( @@ -938,11 +940,20 @@ def _model_call(self, inps, attn_mask=None, labels=None): else: output = self.model(inps, torch.ones(inps.shape, dtype=torch.int64)).logits else: + if self.pad_to_buckets: # use buckets to pad inputs + bs, seq_length = inps.shape + padding_length = 0 + buckets = [64, 128, 256, 512, 1024, 2048, 4096, 8192] + bucket_length = [b for b in buckets if b >= seq_length][0] + padding_length = bucket_length - seq_length + inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id) output = self.model(inps) if isinstance(output, tuple): output = output[0] else: output = output.logits + if self.pad_to_buckets and padding_length != 0: # use buckets to pad inputs + output = output[:, :-padding_length, :] return output def _model_generate(self, context, max_length, stop, **generation_kwargs): diff --git a/neural_compressor/evaluation/lm_eval/utils.py b/neural_compressor/evaluation/lm_eval/utils.py index ab0dabf5497..da0d9248c42 100644 --- a/neural_compressor/evaluation/lm_eval/utils.py +++ b/neural_compressor/evaluation/lm_eval/utils.py @@ -12,175 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -# def setup_parser() -> argparse.ArgumentParser: -# parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) -# parser.add_argument( -# "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`" -# ) -# parser.add_argument( -# "--tasks", -# "-t", -# default=None, -# type=str, -# metavar="task1,task2", -# help="To get full list of tasks, use the command lm-eval --tasks list", -# ) -# parser.add_argument( -# "--model_args", -# "-a", -# default="", -# type=str, -# help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`", -# ) -# parser.add_argument( -# "--num_fewshot", -# "-f", -# type=int, -# default=None, -# metavar="N", -# help="Number of examples in few-shot context", -# ) -# parser.add_argument( -# "--batch_size", -# "-b", -# type=str, -# default=1, -# metavar="auto|auto:N|N", -# help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. 
Default 1.", -# ) -# parser.add_argument( -# "--max_batch_size", -# type=int, -# default=None, -# metavar="N", -# help="Maximal batch size to try with --batch_size auto.", -# ) -# parser.add_argument( -# "--device", -# type=str, -# default=None, -# help="Device to use (e.g. cuda, cuda:0, cpu).", -# ) -# parser.add_argument( -# "--output_path", -# "-o", -# default=None, -# type=str, -# metavar="DIR|DIR/file.json", -# help="The path to the output file where the result metrics will be saved. " + \ -# "If the path is a directory and log_samples is true, the results will be saved in the directory." + \ -# " Else the parent directory will be used.", -# ) -# parser.add_argument( -# "--limit", -# "-L", -# type=float, -# default=None, -# metavar="N|0