From 8f77f17c81cd9811b66bab3746597f84b3737f3f Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Fri, 31 May 2024 09:26:54 +0800 Subject: [PATCH 01/24] Enhance woq model loading & support hf woq model loading Signed-off-by: yuwenzho --- .../torch/algorithms/weight_only/save_load.py | 476 +++++++++++++++++- .../torch/quantization/load_entry.py | 57 ++- .../weight_only/test_autoround.py | 2 +- .../quantization/weight_only/test_awq.py | 2 +- .../quantization/weight_only/test_gptq.py | 2 +- .../weight_only/test_load_woq_hf_model.py | 17 + .../quantization/weight_only/test_rtn.py | 2 +- .../quantization/weight_only/test_teq.py | 2 +- 8 files changed, 529 insertions(+), 31 deletions(-) create mode 100644 test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index 3029cc0eaed..ad5f71f372e 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -26,7 +26,7 @@ def save(model, output_dir="./saved_results"): if not os.path.exists(output_dir): os.mkdir(output_dir) - qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) + qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME) # saving process save_config_mapping(model.qconfig, qconfig_file_path) @@ -38,14 +38,476 @@ def save(model, output_dir="./saved_results"): # MethodType 'save' not in state_dict del model.save - torch.save(model, qmodel_file_path) + torch.save(model.state_dict(), qmodel_weight_file_path) - logger.info("Save quantized model to {}.".format(qmodel_file_path)) + logger.info("Save quantized model weight to {}.".format(qmodel_weight_file_path)) logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path)) -def load(output_dir="./saved_results"): - qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) - model = torch.load(qmodel_file_path) - logger.info("Quantized model loading successful.") +def load(model_name_or_path, model=None, format="default", *hf_model_args, **hf_model_kwargs): + if format == "huggingface": + model = _load_hf_woq_model(model_name_or_path, *hf_model_args, **hf_model_kwargs) + logger.info("Quantized huggingface model loading successful.") + return model + elif format == "default": + qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), WEIGHT_NAME) + assert os.path.exists(qmodel_weight_file_path), \ + "Cannot load model weight from path {}".format(qmodel_weight_file_path) + + qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), QCONFIG_NAME) + assert os.path.exists(qconfig_file_path), \ + "Cannot load model quantization config from path {}".format(qconfig_file_path) + + assert model is not None, "Can't get origin model. Please pass `model` to load function." 
+ + model = _load_inc_woq_model(qmodel_weight_file_path, qconfig_file_path, model) + logger.info("Quantized model loading successful.") + return model + else: + raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format)) + + +def _build_woq_model(model, quantization_config, loaded_state_dict_keys): + """Build weight-only quantization model.""" + from .modules import WeightOnlyLinear, MulLinear + from neural_compressor.torch.utils import set_module + + for name, module in model.named_modules(): + # get quantization config of module + module_name_type = str((name, type(module).__name__)) + module_quantization_config = quantization_config + if module_name_type in quantization_config: + module_quantization_config = quantization_config[module_name_type] + + if isinstance(module, torch.nn.Linear): + # module without qweight means it is not quantized, then skip it + loaded_state_dict_keys_set = set(loaded_state_dict_keys) + if name + ".qweight" not in loaded_state_dict_keys_set and \ + name + ".linear.qweight" not in loaded_state_dict_keys_set: + continue + + # insert MulLinear module + if name + ".linear.qweight" in loaded_state_dict_keys_set: + new_module = MulLinear(module) + set_module(model, name, new_module) + name += ".linear" + + # replace `torch.nn.Linear` with `WeightOnlyLinear` + zp = True if name + ".qzeros" in loaded_state_dict_keys else False + g_idx = True if name + ".g_idx" in loaded_state_dict_keys else False + new_module = WeightOnlyLinear( + module.in_features, + module.out_features, + bits=module_quantization_config.get("bits", 4), + group_size=module_quantization_config.get("group_size", 32), + dtype="int", + zp=zp, + bias=module.bias is not None, + g_idx=g_idx, + use_optimum_format=True, + ) + set_module(model, name, new_module) + + return model + +def _load_inc_woq_model(qmodel_weight_file_path, qconfig_file_path, origin_model): + qweights = torch.load(qmodel_weight_file_path) + + quantization_config = {} + with open(qconfig_file_path, 'r') as file: + quantization_config = json.load(file) + + model = _build_woq_model(origin_model, quantization_config, qweights.keys()) + model.load_state_dict(qweights, assign=True) + model.eval() + return model + +def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): + import copy + from accelerate.big_modeling import init_empty_weights + from transformers import AutoConfig, AutoModelForCausalLM + from transformers.models.auto.auto_factory import _get_model_class + from transformers.configuration_utils import PretrainedConfig + from transformers.modeling_utils import ( + load_state_dict, + _add_variant, + get_checkpoint_shard_files, + no_init_weights, + ) + from transformers.dynamic_module_utils import ( + resolve_trust_remote_code, + get_class_from_dynamic_module, + ) + from transformers.utils import ( + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + SAFE_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + is_remote_url, + download_url, + is_safetensors_available, + cached_file, + has_file, + extract_commit_hash, + ContextManagers, + ) + + + # Autofactory + kwargs_orig = copy.deepcopy(kwargs) + trust_remote_code = kwargs.pop("trust_remote_code", None) + subfolder = kwargs.pop("subfolder", "") + variant = kwargs.pop("variant", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + torch_dtype = kwargs.pop("torch_dtype", "auto") + cache_dir = kwargs.pop("cache_dir", None) + force_download = 
kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + resume_download = kwargs.pop("resume_download", False) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + token = kwargs.pop("token", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + revision = kwargs.pop("revision", "main") + commit_hash = kwargs.pop("_commit_hash", None) + _fast_init = kwargs.pop("_fast_init", True) + use_safetensors = kwargs.pop("use_safetensors", None) + kwarg_attn_imp = kwargs.pop("attn_implementation", None) + + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + quantization_config = config.quantization_config + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + + if use_auth_token is not None: + logger.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. " + "Please use `token` instead." + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + user_agent = { + "file_type": "model", + "framework": "pytorch", + "from_auto_class": from_auto_class, + } + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: + config._attn_implementation = kwarg_attn_imp + + if commit_hash is None: + if not isinstance(config, PretrainedConfig): + # We make a call to the config file first (which may be absent) + # to get the commit hash as soon as possible. + resolved_config_file = cached_file( + pretrained_model_name_or_path, + "config.json", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + else: + commit_hash = getattr(config, "_commit_hash", None) + + has_remote_code = ( + hasattr(config, "auto_map") and AutoModelForCausalLM.__name__ in config.auto_map + ) + + has_local_code = type(config) in AutoModelForCausalLM._model_mapping.keys() + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, + pretrained_model_name_or_path, + has_local_code, + has_remote_code, + ) + + if has_remote_code and trust_remote_code: + class_ref = config.auto_map[AutoModelForCausalLM.__name__] + model_class = get_class_from_dynamic_module( + class_ref, pretrained_model_name_or_path, **kwargs_orig + ) + if os.path.isdir(pretrained_model_name_or_path): + model_class.register_for_auto_class(AutoModelForCausalLM.__name__) + else: + AutoModelForCausalLM.register(config.__class__, model_class, exist_ok=True) + elif type(config) in AutoModelForCausalLM._model_mapping.keys(): + model_class = _get_model_class(config, AutoModelForCausalLM._model_mapping) + + init_contexts = [no_init_weights(_enable=_fast_init)] + init_contexts.append(init_empty_weights()) + + with ContextManagers(init_contexts): + model = model_class(config, *model_args, **kwargs) + + is_sharded = False + sharded_metadata = None + + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = 
os.path.isdir(pretrained_model_name_or_path) + if is_local: + if os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(WEIGHTS_NAME, variant), + ) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(WEIGHTS_NAME, variant), + ) + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(WEIGHTS_INDEX_NAME, variant), + ) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(WEIGHTS_INDEX_NAME, variant), + ) + is_sharded = True + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_NAME, variant), + ) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_NAME, variant), + ) + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + ) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + ) + is_sharded = True + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + if use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or " + f"{_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. 
+ filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + if variant is not None and has_file( + pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs + ): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}." + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}." + ) from e + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info( + f"loading weights file {filename} from cache at {resolved_archive_file}" + ) + else: + resolved_archive_file = None + + if is_sharded: + # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + # 2. 
If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, + # by checking its first weights entry that is of a floating type + # - we assume all floating dtype weights are of the same dtype + dtype_orig = None + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if torch_dtype == "auto": + if ( + hasattr(config, "torch_dtype") + and config.torch_dtype is not None + and config.torch_dtype != "auto" + ): + torch_dtype = config.torch_dtype + else: + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + else: + torch_dtype = torch.float32 + else: + assert ( + False + ), f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' + + dtype_orig = model_class._set_default_torch_dtype(torch_dtype) + + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file) + loaded_state_dict_keys = list(state_dict.keys()) + + model = _build_woq_model(model, quantization_config, loaded_state_dict_keys) + + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = model_class._load_pretrained_model( + model, + None, + loaded_state_dict_keys, + resolved_archive_file, + pretrained_model_name_or_path, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=True, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + keep_in_fp32_modules=[], + ) + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + return model diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index 35e5fd1208e..b77494de3e8 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -31,28 +31,47 @@ } -def load(output_dir="./saved_results", model=None): - from neural_compressor.common.base_config import ConfigRegistry +def load(model_name_or_path="./saved_results", model=None, format="default", *hf_model_args, **hf_model_kwargs): + """Load quantized model. - qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json") - with open(qconfig_file_path, "r") as f: - per_op_qconfig = json.load(f) + Args: + model_name_or_path (str, optional): local path where quantized weights or model are saved + or huggingface model id. Defaults to "./saved_results". + model (torch.nn.Module, optional): original model. Require to pass when loading INC WOQ quantized model + or loading FP8 model. Defaults to None. + format (str, optional): 'defult' for loading INC quantized model. + 'huggingface' now only for loading huggingface WOQ causal language model. Defaults to "default". - if " " in per_op_qconfig.keys(): # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ... 
- from neural_compressor.torch.algorithms.static_quant import load + Returns: + torch.nn.Module: quantized model + """ + if format == "default": + from neural_compressor.common.base_config import ConfigRegistry + from neural_compressor.torch.algorithms.static_quant import load as static_quant_load + from neural_compressor.torch.algorithms.weight_only.save_load import load as woq_load + from neural_compressor.torch.algorithms.habana_fp8 import load as habana_fp8_load - return load(output_dir) - else: - config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) - # select load function - config_object = config_mapping[next(iter(config_mapping))] - if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)): # WOQ - from neural_compressor.torch.algorithms.weight_only.save_load import load + qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json") + with open(qconfig_file_path, "r") as f: + per_op_qconfig = json.load(f) + + if " " in per_op_qconfig.keys(): # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ... + return static_quant_load(model_name_or_path) + else: + config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) + # select load function + config_object = config_mapping[next(iter(config_mapping))] - return load(output_dir) + if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)): # WOQ + return woq_load(model_name_or_path, model=model, format=format) - model.qconfig = config_mapping - if isinstance(config_object, FP8Config): # FP8 - from neural_compressor.torch.algorithms.habana_fp8 import load + model.qconfig = config_mapping + if isinstance(config_object, FP8Config): # FP8 + return habana_fp8_load(model, model_name_or_path) + elif format == "huggingface": + # now only support load huggingface WOQ causal language model + from neural_compressor.torch.algorithms.weight_only.save_load import load as woq_load - return load(model, output_dir) # pylint: disable=E1121 + return woq_load(model_name_or_path, format=format, *hf_model_args, **hf_model_kwargs) + else: + raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format)) diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py index 6ff08696a01..246e295fde7 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -148,7 +148,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results") + loaded_model = load("saved_results", model=copy.deepcopy(self.gptj)) loaded_out = loaded_model(self.inp)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." 
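The hunks above change the default-format round trip: `save` now writes only the quantized state_dict plus the qconfig JSON, so `load` needs the original FP32 module to rebuild the WeightOnlyLinear graph before the saved state_dict is assigned. A minimal sketch of that intended usage follows; it assumes an RTN-quantized causal LM, and the tiny model id and output directory are illustrative stand-ins rather than values taken from this patch.

    import copy
    import torch
    from transformers import AutoModelForCausalLM

    from neural_compressor.torch.quantization import RTNConfig, convert, load, prepare

    # Any causal LM works; this tiny checkpoint only stands in for the models used in the tests.
    fp32_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")

    # Quantize with RTN (no calibration needed) and save: the quantized weights and the
    # per-op qconfig JSON land in the output directory.
    quant_config = RTNConfig()
    prepared = prepare(copy.deepcopy(fp32_model), quant_config)
    q_model = convert(prepared)
    q_model.save("saved_results")

    # Default format now requires the original model so the quantized modules can be
    # rebuilt from the qconfig before load_state_dict(..., assign=True) runs.
    loaded_model = load("saved_results", model=copy.deepcopy(fp32_model))
    example_inputs = torch.tensor([[10, 20, 30]], dtype=torch.long)
    with torch.no_grad():
        out = loaded_model(example_inputs)[0]

This is why the test updates in this patch pass `model=copy.deepcopy(...)` into `load` where they previously passed only the output directory.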
diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py index 8d2a9472405..5452a02daf1 100644 --- a/test/3x/torch/quantization/weight_only/test_awq.py +++ b/test/3x/torch/quantization/weight_only/test_awq.py @@ -131,7 +131,7 @@ def calib_func(model): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results") + loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py index cd48edd8c35..622a7c24539 100644 --- a/test/3x/torch/quantization/weight_only/test_gptq.py +++ b/test/3x/torch/quantization/weight_only/test_gptq.py @@ -254,7 +254,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results") + loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance( diff --git a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py new file mode 100644 index 00000000000..77c161430b0 --- /dev/null +++ b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py @@ -0,0 +1,17 @@ +import torch +from transformers import AutoTokenizer +from neural_compressor.torch.utils import accelerator + +device = accelerator.current_device_name() + +class TestHFModelLoad: + def setup_class(self): + self.model_name = "TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ" + self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long).to(device) + + def test_load_hf_woq_model(self): + from neural_compressor.torch.quantization import load + + qmodel = load(self.model_name, format="huggingface") + output = qmodel(self.example_inputs)[0] + assert len(output) > 0, "Not loading the model correctly" diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index 53aa19f9424..ab396db1e63 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -290,7 +290,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results") + loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." 
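The new test file exercises the other branch: with `format="huggingface"` the architecture is rebuilt from the Hub config under `init_empty_weights()`, WOQ linear modules are inserted based on the checkpoint's `quantization_config`, and transformers' checkpoint loading fills in the weights, so no original model object is needed. The sketch below mirrors test_load_woq_hf_model.py; the tokenizer usage and prompt are illustrative additions, and extra keyword arguments (for example `trust_remote_code`) are forwarded to the underlying transformers loading code.

    import torch
    from transformers import AutoTokenizer

    from neural_compressor.torch.quantization import load

    model_name = "TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ"  # same Hub id as the new test

    # No `model=` argument: the graph comes from the Hub config and the quantized
    # modules are created from the checkpoint's quantization_config.
    qmodel = load(model_name, format="huggingface")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer("def print_hello():", return_tensors="pt")
    with torch.no_grad():
        logits = qmodel(**inputs)[0]
    print(logits.shape)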
diff --git a/test/3x/torch/quantization/weight_only/test_teq.py b/test/3x/torch/quantization/weight_only/test_teq.py index 79447054050..04e28afeac1 100644 --- a/test/3x/torch/quantization/weight_only/test_teq.py +++ b/test/3x/torch/quantization/weight_only/test_teq.py @@ -141,7 +141,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results") + loaded_model = load("saved_results", model=copy.deepcopy(self.gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." From 507339aeb783152790e2699060f29da9208c5ac7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 May 2024 01:28:34 +0000 Subject: [PATCH 02/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/weight_only/save_load.py | 83 ++++++++----------- .../torch/quantization/load_entry.py | 2 +- .../weight_only/test_load_woq_hf_model.py | 2 + 3 files changed, 36 insertions(+), 51 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index ad5f71f372e..e8e2a29522c 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -51,12 +51,14 @@ def load(model_name_or_path, model=None, format="default", *hf_model_args, **hf_ return model elif format == "default": qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), WEIGHT_NAME) - assert os.path.exists(qmodel_weight_file_path), \ - "Cannot load model weight from path {}".format(qmodel_weight_file_path) + assert os.path.exists(qmodel_weight_file_path), "Cannot load model weight from path {}".format( + qmodel_weight_file_path + ) qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), QCONFIG_NAME) - assert os.path.exists(qconfig_file_path), \ - "Cannot load model quantization config from path {}".format(qconfig_file_path) + assert os.path.exists(qconfig_file_path), "Cannot load model quantization config from path {}".format( + qconfig_file_path + ) assert model is not None, "Can't get origin model. Please pass `model` to load function." 
@@ -69,9 +71,10 @@ def load(model_name_or_path, model=None, format="default", *hf_model_args, **hf_ def _build_woq_model(model, quantization_config, loaded_state_dict_keys): """Build weight-only quantization model.""" - from .modules import WeightOnlyLinear, MulLinear from neural_compressor.torch.utils import set_module + from .modules import MulLinear, WeightOnlyLinear + for name, module in model.named_modules(): # get quantization config of module module_name_type = str((name, type(module).__name__)) @@ -82,8 +85,10 @@ def _build_woq_model(model, quantization_config, loaded_state_dict_keys): if isinstance(module, torch.nn.Linear): # module without qweight means it is not quantized, then skip it loaded_state_dict_keys_set = set(loaded_state_dict_keys) - if name + ".qweight" not in loaded_state_dict_keys_set and \ - name + ".linear.qweight" not in loaded_state_dict_keys_set: + if ( + name + ".qweight" not in loaded_state_dict_keys_set + and name + ".linear.qweight" not in loaded_state_dict_keys_set + ): continue # insert MulLinear module @@ -110,11 +115,12 @@ def _build_woq_model(model, quantization_config, loaded_state_dict_keys): return model + def _load_inc_woq_model(qmodel_weight_file_path, qconfig_file_path, origin_model): qweights = torch.load(qmodel_weight_file_path) quantization_config = {} - with open(qconfig_file_path, 'r') as file: + with open(qconfig_file_path, "r") as file: quantization_config = json.load(file) model = _build_woq_model(origin_model, quantization_config, qweights.keys()) @@ -122,37 +128,30 @@ def _load_inc_woq_model(qmodel_weight_file_path, qconfig_file_path, origin_model model.eval() return model + def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): import copy + from accelerate.big_modeling import init_empty_weights from transformers import AutoConfig, AutoModelForCausalLM - from transformers.models.auto.auto_factory import _get_model_class from transformers.configuration_utils import PretrainedConfig - from transformers.modeling_utils import ( - load_state_dict, - _add_variant, - get_checkpoint_shard_files, - no_init_weights, - ) - from transformers.dynamic_module_utils import ( - resolve_trust_remote_code, - get_class_from_dynamic_module, - ) + from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code + from transformers.modeling_utils import _add_variant, get_checkpoint_shard_files, load_state_dict, no_init_weights + from transformers.models.auto.auto_factory import _get_model_class from transformers.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, - SAFE_WEIGHTS_NAME, - SAFE_WEIGHTS_INDEX_NAME, - is_remote_url, - download_url, - is_safetensors_available, + ContextManagers, cached_file, - has_file, + download_url, extract_commit_hash, - ContextManagers, + has_file, + is_remote_url, + is_safetensors_available, ) - # Autofactory kwargs_orig = copy.deepcopy(kwargs) trust_remote_code = kwargs.pop("trust_remote_code", None) @@ -188,9 +187,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): "Please use `token` instead." ) if token is not None: - raise ValueError( - "`token` and `use_auth_token` are both specified. Please set only the argument `token`." - ) + raise ValueError("`token` and `use_auth_token` are both specified. 
Please set only the argument `token`.") token = use_auth_token user_agent = { @@ -226,9 +223,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): else: commit_hash = getattr(config, "_commit_hash", None) - has_remote_code = ( - hasattr(config, "auto_map") and AutoModelForCausalLM.__name__ in config.auto_map - ) + has_remote_code = hasattr(config, "auto_map") and AutoModelForCausalLM.__name__ in config.auto_map has_local_code = type(config) in AutoModelForCausalLM._model_mapping.keys() trust_remote_code = resolve_trust_remote_code( @@ -240,9 +235,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): if has_remote_code and trust_remote_code: class_ref = config.auto_map[AutoModelForCausalLM.__name__] - model_class = get_class_from_dynamic_module( - class_ref, pretrained_model_name_or_path, **kwargs_orig - ) + model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs_orig) if os.path.isdir(pretrained_model_name_or_path): model_class.register_for_auto_class(AutoModelForCausalLM.__name__) else: @@ -389,9 +382,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): "proxies": proxies, "token": token, } - if variant is not None and has_file( - pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs - ): + if variant is not None and has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named" f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" @@ -419,9 +410,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): logger.info(f"loading weights file {archive_file}") resolved_archive_file = archive_file else: - logger.info( - f"loading weights file {filename} from cache at {resolved_archive_file}" - ) + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") else: resolved_archive_file = None @@ -451,11 +440,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): if torch_dtype is not None: if isinstance(torch_dtype, str): if torch_dtype == "auto": - if ( - hasattr(config, "torch_dtype") - and config.torch_dtype is not None - and config.torch_dtype != "auto" - ): + if hasattr(config, "torch_dtype") and config.torch_dtype is not None and config.torch_dtype != "auto": torch_dtype = config.torch_dtype else: if is_sharded and "dtype" in sharded_metadata: @@ -463,9 +448,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): else: torch_dtype = torch.float32 else: - assert ( - False - ), f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' + assert False, f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' dtype_orig = model_class._set_default_torch_dtype(torch_dtype) diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index b77494de3e8..aa0daa39b18 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -47,9 +47,9 @@ def load(model_name_or_path="./saved_results", model=None, format="default", *hf """ if format == "default": from neural_compressor.common.base_config import ConfigRegistry + from neural_compressor.torch.algorithms.habana_fp8 import load as habana_fp8_load from neural_compressor.torch.algorithms.static_quant import load as 
static_quant_load from neural_compressor.torch.algorithms.weight_only.save_load import load as woq_load - from neural_compressor.torch.algorithms.habana_fp8 import load as habana_fp8_load qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json") with open(qconfig_file_path, "r") as f: diff --git a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py index 77c161430b0..9e3b00cbeb3 100644 --- a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py +++ b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py @@ -1,9 +1,11 @@ import torch from transformers import AutoTokenizer + from neural_compressor.torch.utils import accelerator device = accelerator.current_device_name() + class TestHFModelLoad: def setup_class(self): self.model_name = "TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ" From 2a7f5df5d089367ee2b6f206ba5c00c6075b2212 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Fri, 31 May 2024 08:15:15 +0000 Subject: [PATCH 03/24] enhance code & fix bug Signed-off-by: yuwenzho --- .../torch/algorithms/weight_only/save_load.py | 14 +++++++++++++- .../torch/quantization/load_entry.py | 19 +++++++++++-------- .../weight_only/test_load_woq_hf_model.py | 2 -- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index e8e2a29522c..bc008f6b8b3 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -130,8 +130,20 @@ def _load_inc_woq_model(qmodel_weight_file_path, qconfig_file_path, origin_model def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): + # check required package + try: + import transformers + except ImportError: + logger.error("`transformers` package is required for loading hugginface weight-only quantization model.") + + try: + import accelerate + except ImportError: + logger.error("`accelerate` package is required for loading hugginface weight-only quantization model.") + + # below codes are refer to load_low_bit function in + # https://github.com/intel/intel-extension-for-transformers/blob/v1.4.2/intel_extension_for_transformers/transformers/modeling/modeling_auto.py#L1464 import copy - from accelerate.big_modeling import init_empty_weights from transformers import AutoConfig, AutoModelForCausalLM from transformers.configuration_utils import PretrainedConfig diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index aa0daa39b18..bd16e7d0c05 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -47,31 +47,34 @@ def load(model_name_or_path="./saved_results", model=None, format="default", *hf """ if format == "default": from neural_compressor.common.base_config import ConfigRegistry - from neural_compressor.torch.algorithms.habana_fp8 import load as habana_fp8_load - from neural_compressor.torch.algorithms.static_quant import load as static_quant_load - from neural_compressor.torch.algorithms.weight_only.save_load import load as woq_load qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json") with open(qconfig_file_path, "r") as f: per_op_qconfig = json.load(f) if " " in per_op_qconfig.keys(): # ipex qconfig format: {' ': 
{'q_op_infos': {'0': {'op_type': ... - return static_quant_load(model_name_or_path) + from neural_compressor.torch.algorithms.static_quant import load + + return load(model_name_or_path) else: config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) # select load function config_object = config_mapping[next(iter(config_mapping))] if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)): # WOQ - return woq_load(model_name_or_path, model=model, format=format) + from neural_compressor.torch.algorithms.weight_only.save_load import load + + return load(model_name_or_path, model=model, format=format) model.qconfig = config_mapping if isinstance(config_object, FP8Config): # FP8 - return habana_fp8_load(model, model_name_or_path) + from neural_compressor.torch.algorithms.habana_fp8 import load + + return load(model, model_name_or_path) elif format == "huggingface": # now only support load huggingface WOQ causal language model - from neural_compressor.torch.algorithms.weight_only.save_load import load as woq_load + from neural_compressor.torch.algorithms.weight_only.save_load import load - return woq_load(model_name_or_path, format=format, *hf_model_args, **hf_model_kwargs) + return load(model_name_or_path, format=format, *hf_model_args, **hf_model_kwargs) else: raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format)) diff --git a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py index 9e3b00cbeb3..9f9c468d551 100644 --- a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py +++ b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py @@ -1,11 +1,9 @@ import torch -from transformers import AutoTokenizer from neural_compressor.torch.utils import accelerator device = accelerator.current_device_name() - class TestHFModelLoad: def setup_class(self): self.model_name = "TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ" From 76edd2b2563ba88c2f694c9ba8255596f2cea060 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 May 2024 08:17:25 +0000 Subject: [PATCH 04/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/torch/algorithms/weight_only/save_load.py | 1 + test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py | 1 + 2 files changed, 2 insertions(+) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index bc008f6b8b3..536f7f6c875 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -144,6 +144,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): # below codes are refer to load_low_bit function in # https://github.com/intel/intel-extension-for-transformers/blob/v1.4.2/intel_extension_for_transformers/transformers/modeling/modeling_auto.py#L1464 import copy + from accelerate.big_modeling import init_empty_weights from transformers import AutoConfig, AutoModelForCausalLM from transformers.configuration_utils import PretrainedConfig diff --git a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py index 9f9c468d551..77a1964db82 100644 --- 
a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py +++ b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py @@ -4,6 +4,7 @@ device = accelerator.current_device_name() + class TestHFModelLoad: def setup_class(self): self.model_name = "TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ" From 9124e04adcfcaf8d7d85c227b8f0a83c08f96905 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Sun, 2 Jun 2024 20:26:23 -0700 Subject: [PATCH 05/24] update load API usage Signed-off-by: yuwenzho --- .../language-modeling/quantization/llm/run_clm_no_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py index 2556820284a..ac06fd13775 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py @@ -366,7 +366,7 @@ def run_fn(model): user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(user_model) user_model = convert(user_model) - + user_model.save(args.output_dir) @@ -376,9 +376,10 @@ def run_fn(model): print("load int8 model") from neural_compressor.torch.quantization import load + user_model, _ = get_user_model() tokenizer = AutoTokenizer.from_pretrained(args.model) config = AutoConfig.from_pretrained(args.model) - user_model = load(os.path.abspath(os.path.expanduser(args.output_dir))) + user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)), model=user_model) setattr(user_model, "config", config) else: user_model, tokenizer = get_user_model() From 091676c3830b8b5f16f2db53178d4e8e0b1d6da1 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Mon, 3 Jun 2024 07:20:32 +0000 Subject: [PATCH 06/24] fix bug Signed-off-by: yuwenzho --- neural_compressor/torch/algorithms/weight_only/save_load.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index 536f7f6c875..04d4bddd9ce 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -80,7 +80,7 @@ def _build_woq_model(model, quantization_config, loaded_state_dict_keys): module_name_type = str((name, type(module).__name__)) module_quantization_config = quantization_config if module_name_type in quantization_config: - module_quantization_config = quantization_config[module_name_type] + module_quantization_config = [config for config in quantization_config[module_name_type].values()][0] if isinstance(module, torch.nn.Linear): # module without qweight means it is not quantized, then skip it @@ -98,8 +98,8 @@ def _build_woq_model(model, quantization_config, loaded_state_dict_keys): name += ".linear" # replace `torch.nn.Linear` with `WeightOnlyLinear` - zp = True if name + ".qzeros" in loaded_state_dict_keys else False - g_idx = True if name + ".g_idx" in loaded_state_dict_keys else False + zp = True if name + ".qzeros" in loaded_state_dict_keys_set else False + g_idx = True if name + ".g_idx" in loaded_state_dict_keys_set else False new_module = WeightOnlyLinear( module.in_features, module.out_features, From 6a0eec44be5f2d2f0287e5429ec61c460e614f7d Mon Sep 17 00:00:00 2001 
From: yuwenzho Date: Mon, 3 Jun 2024 16:20:30 +0300 Subject: [PATCH 07/24] fix bug Signed-off-by: yuwenzho --- neural_compressor/torch/algorithms/weight_only/save_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index 04d4bddd9ce..0369193af9f 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -103,9 +103,9 @@ def _build_woq_model(model, quantization_config, loaded_state_dict_keys): new_module = WeightOnlyLinear( module.in_features, module.out_features, + dtype=module_quantization_config.get("dtype", "int"), bits=module_quantization_config.get("bits", 4), group_size=module_quantization_config.get("group_size", 32), - dtype="int", zp=zp, bias=module.bias is not None, g_idx=g_idx, From d25763d0280ddd2daeec128540e9835fe760c6f2 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 5 Jun 2024 03:11:34 +0000 Subject: [PATCH 08/24] enhance code Signed-off-by: yuwenzho --- .../torch/algorithms/weight_only/save_load.py | 52 ++++++++++++------- .../torch/algorithms/weight_only/utility.py | 26 ++++++++++ .../torch/quantization/load_entry.py | 9 ++-- neural_compressor/torch/utils/constants.py | 8 +++ 4 files changed, 73 insertions(+), 22 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index 0369193af9f..ec1c10a6e94 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -16,11 +16,12 @@ import json import os +import re import torch from neural_compressor.common.utils import load_config_mapping, save_config_mapping -from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger +from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger, LoadFormat def save(model, output_dir="./saved_results"): @@ -44,12 +45,12 @@ def save(model, output_dir="./saved_results"): logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path)) -def load(model_name_or_path, model=None, format="default", *hf_model_args, **hf_model_kwargs): - if format == "huggingface": +def load(model_name_or_path, model=None, format=LoadFormat.DEFAULT, *hf_model_args, **hf_model_kwargs): + if format == LoadFormat.HUGGINGFACE: model = _load_hf_woq_model(model_name_or_path, *hf_model_args, **hf_model_kwargs) logger.info("Quantized huggingface model loading successful.") return model - elif format == "default": + elif format == LoadFormat.DEFAULT: qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), WEIGHT_NAME) assert os.path.exists(qmodel_weight_file_path), "Cannot load model weight from path {}".format( qmodel_weight_file_path @@ -72,15 +73,19 @@ def load(model_name_or_path, model=None, format="default", *hf_model_args, **hf_ def _build_woq_model(model, quantization_config, loaded_state_dict_keys): """Build weight-only quantization model.""" from neural_compressor.torch.utils import set_module - - from .modules import MulLinear, WeightOnlyLinear + from .modules import MulLinear for name, module in model.named_modules(): + _is_autoround = False # get quantization config of module - module_name_type = str((name, type(module).__name__)) module_quantization_config = quantization_config - if module_name_type in quantization_config: - 
module_quantization_config = [config for config in quantization_config[module_name_type].values()][0] + # pattern will map (module_name, moduele_type) + pattern = fr"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))" + for q_config_key, q_config_value in quantization_config.items(): + if re.search(pattern, q_config_key): + if isinstance(q_config_value, dict) and [algo for algo in q_config_value.keys()][0] == "autoround": + _is_autoround = True + module_quantization_config = [config for config in q_config_value.values()][0] if isinstance(module, torch.nn.Linear): # module without qweight means it is not quantized, then skip it @@ -100,16 +105,31 @@ def _build_woq_model(model, quantization_config, loaded_state_dict_keys): # replace `torch.nn.Linear` with `WeightOnlyLinear` zp = True if name + ".qzeros" in loaded_state_dict_keys_set else False g_idx = True if name + ".g_idx" in loaded_state_dict_keys_set else False - new_module = WeightOnlyLinear( + + kwargs = {} + if _is_autoround: + from auto_round.export.export_to_itrex.model_wrapper import ( + WeightOnlyLinear as AutoRoundWeightOnlyLinear + ) + from .utility import convert_dtype_str2torch + WeightOnlyLinearClass = AutoRoundWeightOnlyLinear + kwargs["groupsize"] = module_quantization_config.get("group_size", 32) + kwargs["scale_dtype"] = convert_dtype_str2torch(module_quantization_config.get("scale_dtype", "fp16")) + else: + from .modules import WeightOnlyLinear as INCWeightOnlyLinear + WeightOnlyLinearClass = INCWeightOnlyLinear + kwargs["group_size"] = module_quantization_config.get("group_size", 32) + kwargs["g_idx"] = g_idx + + new_module = WeightOnlyLinearClass( module.in_features, module.out_features, dtype=module_quantization_config.get("dtype", "int"), bits=module_quantization_config.get("bits", 4), - group_size=module_quantization_config.get("group_size", 32), zp=zp, bias=module.bias is not None, - g_idx=g_idx, use_optimum_format=True, + **kwargs, ) set_module(model, name, new_module) @@ -133,13 +153,9 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): # check required package try: import transformers - except ImportError: - logger.error("`transformers` package is required for loading hugginface weight-only quantization model.") - - try: import accelerate - except ImportError: - logger.error("`accelerate` package is required for loading hugginface weight-only quantization model.") + except ImportError as e: + raise e # below codes are refer to load_low_bit function in # https://github.com/intel/intel-extension-for-transformers/blob/v1.4.2/intel_extension_for_transformers/transformers/modeling/modeling_auto.py#L1464 diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py index 31cbe3bc342..4bb66626814 100644 --- a/neural_compressor/torch/algorithms/weight_only/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -1101,3 +1101,29 @@ def forward(self, *args, **kwargs): with torch.no_grad(): self.args_list.append(args) self.kwargs_list.append(kwargs) + + +def convert_dtype_str2torch(str_dtype): + """Converts a string dtype to its corresponding PyTorch dtype. + + Args: + str_dtype (str): The string representation of the dtype. + + Returns: + torch.dtype: The PyTorch dtype. + + Raises: + AssertionError: If the input str_dtype is unsupported. 
+ """ + if isinstance(str_dtype, torch.dtype) or str_dtype is None: + return str_dtype + if str_dtype == "int8": + return torch.int8 + elif str_dtype == "fp32" or str_dtype == "float32" or str_dtype == "auto": + return torch.float + elif str_dtype == "fp16" or str_dtype == "float16": + return torch.float16 + elif str_dtype == "bf16" or str_dtype == "bfloat16": + return torch.bfloat16 + else: + assert False, "Unsupported str dtype {} to torch dtype".format(str_dtype) diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index bd16e7d0c05..daa4eec90fb 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -25,6 +25,7 @@ RTNConfig, TEQConfig, ) +from neural_compressor.torch.utils import LoadFormat config_name_mapping = { FP8_QUANT: FP8Config, @@ -45,7 +46,7 @@ def load(model_name_or_path="./saved_results", model=None, format="default", *hf Returns: torch.nn.Module: quantized model """ - if format == "default": + if format == LoadFormat.DEFAULT.value: from neural_compressor.common.base_config import ConfigRegistry qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json") @@ -64,17 +65,17 @@ def load(model_name_or_path="./saved_results", model=None, format="default", *hf if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)): # WOQ from neural_compressor.torch.algorithms.weight_only.save_load import load - return load(model_name_or_path, model=model, format=format) + return load(model_name_or_path, model=model, format=LoadFormat.DEFAULT) model.qconfig = config_mapping if isinstance(config_object, FP8Config): # FP8 from neural_compressor.torch.algorithms.habana_fp8 import load return load(model, model_name_or_path) - elif format == "huggingface": + elif format == LoadFormat.HUGGINGFACE.value: # now only support load huggingface WOQ causal language model from neural_compressor.torch.algorithms.weight_only.save_load import load - return load(model_name_or_path, format=format, *hf_model_args, **hf_model_kwargs) + return load(model_name_or_path, format=LoadFormat.HUGGINGFACE, *hf_model_args, **hf_model_kwargs) else: raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format)) diff --git a/neural_compressor/torch/utils/constants.py b/neural_compressor/torch/utils/constants.py index 429851e311b..c4960dd4ea7 100644 --- a/neural_compressor/torch/utils/constants.py +++ b/neural_compressor/torch/utils/constants.py @@ -53,3 +53,11 @@ PT2E_STATIC_QUANT = "pt2e_static_quant" PT2E_DYNAMIC_QUANT = "pt2e_dynamic_quant" + + +# load format name +from enum import Enum + +class LoadFormat(Enum): + DEFAULT = "default" + HUGGINGFACE = "huggingface" From 219fb548405bd1c60e09c6d26ea18f0c4786a770 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 03:15:53 +0000 Subject: [PATCH 09/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/weight_only/save_load.py | 12 ++++++++---- neural_compressor/torch/utils/constants.py | 1 + 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index ec1c10a6e94..5d1d8af7e4a 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py 
+++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -21,7 +21,7 @@ import torch from neural_compressor.common.utils import load_config_mapping, save_config_mapping -from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger, LoadFormat +from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, LoadFormat, logger def save(model, output_dir="./saved_results"): @@ -73,6 +73,7 @@ def load(model_name_or_path, model=None, format=LoadFormat.DEFAULT, *hf_model_ar def _build_woq_model(model, quantization_config, loaded_state_dict_keys): """Build weight-only quantization model.""" from neural_compressor.torch.utils import set_module + from .modules import MulLinear for name, module in model.named_modules(): @@ -80,7 +81,7 @@ def _build_woq_model(model, quantization_config, loaded_state_dict_keys): # get quantization config of module module_quantization_config = quantization_config # pattern will map (module_name, moduele_type) - pattern = fr"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))" + pattern = rf"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))" for q_config_key, q_config_value in quantization_config.items(): if re.search(pattern, q_config_key): if isinstance(q_config_value, dict) and [algo for algo in q_config_value.keys()][0] == "autoround": @@ -109,14 +110,17 @@ def _build_woq_model(model, quantization_config, loaded_state_dict_keys): kwargs = {} if _is_autoround: from auto_round.export.export_to_itrex.model_wrapper import ( - WeightOnlyLinear as AutoRoundWeightOnlyLinear + WeightOnlyLinear as AutoRoundWeightOnlyLinear, ) + from .utility import convert_dtype_str2torch + WeightOnlyLinearClass = AutoRoundWeightOnlyLinear kwargs["groupsize"] = module_quantization_config.get("group_size", 32) kwargs["scale_dtype"] = convert_dtype_str2torch(module_quantization_config.get("scale_dtype", "fp16")) else: from .modules import WeightOnlyLinear as INCWeightOnlyLinear + WeightOnlyLinearClass = INCWeightOnlyLinear kwargs["group_size"] = module_quantization_config.get("group_size", 32) kwargs["g_idx"] = g_idx @@ -152,8 +156,8 @@ def _load_inc_woq_model(qmodel_weight_file_path, qconfig_file_path, origin_model def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): # check required package try: - import transformers import accelerate + import transformers except ImportError as e: raise e diff --git a/neural_compressor/torch/utils/constants.py b/neural_compressor/torch/utils/constants.py index c4960dd4ea7..a655a70b8ed 100644 --- a/neural_compressor/torch/utils/constants.py +++ b/neural_compressor/torch/utils/constants.py @@ -58,6 +58,7 @@ # load format name from enum import Enum + class LoadFormat(Enum): DEFAULT = "default" HUGGINGFACE = "huggingface" From 3c6b1f5476672f4db0fa7f175718b8c1f542676f Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 5 Jun 2024 06:06:22 +0000 Subject: [PATCH 10/24] enhance coverage Signed-off-by: yuwenzho --- .../torch/algorithms/weight_only/save_load.py | 24 +++++++++---------- .../weight_only/test_woq_utility.py | 17 +++++++++++++ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index 5d1d8af7e4a..b4fdb2e2617 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -158,7 +158,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): 
try: import accelerate import transformers - except ImportError as e: + except ImportError as e: # pragma: no cover raise e # below codes are refer to load_low_bit function in @@ -214,7 +214,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): if use_safetensors is None and not is_safetensors_available(): use_safetensors = False - if use_auth_token is not None: + if use_auth_token is not None: # pragma: no cover logger.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. " "Please use `token` instead." @@ -228,14 +228,14 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): "framework": "pytorch", "from_auto_class": from_auto_class, } - if from_pipeline is not None: + if from_pipeline is not None: # pragma: no cover user_agent["using_pipeline"] = from_pipeline - if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: # pragma: no cover config._attn_implementation = kwarg_attn_imp if commit_hash is None: - if not isinstance(config, PretrainedConfig): + if not isinstance(config, PretrainedConfig): # pragma: no cover # We make a call to the config file first (which may be absent) # to get the commit hash as soon as possible. resolved_config_file = cached_file( @@ -266,7 +266,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): has_remote_code, ) - if has_remote_code and trust_remote_code: + if has_remote_code and trust_remote_code: # pragma: no cover class_ref = config.auto_map[AutoModelForCausalLM.__name__] model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs_orig) if os.path.isdir(pretrained_model_name_or_path): @@ -285,7 +285,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): is_sharded = False sharded_metadata = None - if pretrained_model_name_or_path is not None: + if pretrained_model_name_or_path is not None: # pragma: no cover pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) if is_local: @@ -444,10 +444,10 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): resolved_archive_file = archive_file else: logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") - else: + else: # pragma: no cover resolved_archive_file = None - if is_sharded: + if is_sharded: # pragma: no cover # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. 
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( pretrained_model_name_or_path, @@ -475,17 +475,17 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): if torch_dtype == "auto": if hasattr(config, "torch_dtype") and config.torch_dtype is not None and config.torch_dtype != "auto": torch_dtype = config.torch_dtype - else: + else: # pragma: no cover if is_sharded and "dtype" in sharded_metadata: torch_dtype = sharded_metadata["dtype"] else: torch_dtype = torch.float32 - else: + else: # pragma: no cover assert False, f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' dtype_orig = model_class._set_default_torch_dtype(torch_dtype) - if is_sharded: + if is_sharded: # pragma: no cover loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: # Time to load the checkpoint diff --git a/test/3x/torch/algorithms/weight_only/test_woq_utility.py b/test/3x/torch/algorithms/weight_only/test_woq_utility.py index f672ec0ac1c..2e19e253d50 100644 --- a/test/3x/torch/algorithms/weight_only/test_woq_utility.py +++ b/test/3x/torch/algorithms/weight_only/test_woq_utility.py @@ -11,3 +11,20 @@ def test_quant_tensor_id(shape): output = quant_tensor(input) id2 = id(output) assert id1 == id2, "quant_tensor function is an in-place operator" + +def test_convert_dtype_str2torch(): + from neural_compressor.torch.algorithms.weight_only.utility import convert_dtype_str2torch + + # Test for supported dtypes + assert convert_dtype_str2torch("int8") == torch.int8 + assert convert_dtype_str2torch("fp32") == torch.float + assert convert_dtype_str2torch("float32") == torch.float + assert convert_dtype_str2torch("auto") == torch.float + assert convert_dtype_str2torch("fp16") == torch.float16 + assert convert_dtype_str2torch("float16") == torch.float16 + assert convert_dtype_str2torch("bf16") == torch.bfloat16 + assert convert_dtype_str2torch("bfloat16") == torch.bfloat16 + + # Test for unsupported dtypes + with pytest.raises(AssertionError): + convert_dtype_str2torch("int16") From d1cdd6214801b9740a9b73c83b51052eb4ad05cb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 06:08:01 +0000 Subject: [PATCH 11/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/weight_only/save_load.py | 24 +++++++++---------- .../weight_only/test_woq_utility.py | 1 + 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index b4fdb2e2617..5e79076956e 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -158,7 +158,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): try: import accelerate import transformers - except ImportError as e: # pragma: no cover + except ImportError as e: # pragma: no cover raise e # below codes are refer to load_low_bit function in @@ -214,7 +214,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): if use_safetensors is None and not is_safetensors_available(): use_safetensors = False - if use_auth_token is not None: # pragma: no cover + if use_auth_token is not None: # pragma: no cover logger.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. 
" "Please use `token` instead." @@ -228,14 +228,14 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): "framework": "pytorch", "from_auto_class": from_auto_class, } - if from_pipeline is not None: # pragma: no cover + if from_pipeline is not None: # pragma: no cover user_agent["using_pipeline"] = from_pipeline - if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: # pragma: no cover + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: # pragma: no cover config._attn_implementation = kwarg_attn_imp if commit_hash is None: - if not isinstance(config, PretrainedConfig): # pragma: no cover + if not isinstance(config, PretrainedConfig): # pragma: no cover # We make a call to the config file first (which may be absent) # to get the commit hash as soon as possible. resolved_config_file = cached_file( @@ -266,7 +266,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): has_remote_code, ) - if has_remote_code and trust_remote_code: # pragma: no cover + if has_remote_code and trust_remote_code: # pragma: no cover class_ref = config.auto_map[AutoModelForCausalLM.__name__] model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs_orig) if os.path.isdir(pretrained_model_name_or_path): @@ -285,7 +285,7 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): is_sharded = False sharded_metadata = None - if pretrained_model_name_or_path is not None: # pragma: no cover + if pretrained_model_name_or_path is not None: # pragma: no cover pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) if is_local: @@ -444,10 +444,10 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): resolved_archive_file = archive_file else: logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") - else: # pragma: no cover + else: # pragma: no cover resolved_archive_file = None - if is_sharded: # pragma: no cover + if is_sharded: # pragma: no cover # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. 
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( pretrained_model_name_or_path, @@ -475,17 +475,17 @@ def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): if torch_dtype == "auto": if hasattr(config, "torch_dtype") and config.torch_dtype is not None and config.torch_dtype != "auto": torch_dtype = config.torch_dtype - else: # pragma: no cover + else: # pragma: no cover if is_sharded and "dtype" in sharded_metadata: torch_dtype = sharded_metadata["dtype"] else: torch_dtype = torch.float32 - else: # pragma: no cover + else: # pragma: no cover assert False, f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' dtype_orig = model_class._set_default_torch_dtype(torch_dtype) - if is_sharded: # pragma: no cover + if is_sharded: # pragma: no cover loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: # Time to load the checkpoint diff --git a/test/3x/torch/algorithms/weight_only/test_woq_utility.py b/test/3x/torch/algorithms/weight_only/test_woq_utility.py index 2e19e253d50..712ba52d889 100644 --- a/test/3x/torch/algorithms/weight_only/test_woq_utility.py +++ b/test/3x/torch/algorithms/weight_only/test_woq_utility.py @@ -12,6 +12,7 @@ def test_quant_tensor_id(shape): id2 = id(output) assert id1 == id2, "quant_tensor function is an in-place operator" + def test_convert_dtype_str2torch(): from neural_compressor.torch.algorithms.weight_only.utility import convert_dtype_str2torch From 11ae6fd04f7f9a30240146433b05cbb71741b634 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Thu, 6 Jun 2024 03:25:09 +0000 Subject: [PATCH 12/24] enhance code Signed-off-by: yuwenzho --- docs/3x/PT_WeightOnlyQuant.md | 14 +++++----- .../quantization/habana_fp8/run_llm.py | 14 +++++----- .../quantization/llm/run_clm_no_trainer.py | 2 +- .../torch/algorithms/weight_only/save_load.py | 27 +++++++++++++++--- .../torch/quantization/load_entry.py | 28 +++++++++++-------- .../torch/quantization/habana_fp8/test_fp8.py | 2 +- .../torch/quantization/test_smooth_quant.py | 2 +- .../weight_only/test_autoround.py | 2 +- .../quantization/weight_only/test_awq.py | 2 +- .../quantization/weight_only/test_gptq.py | 2 +- .../weight_only/test_load_woq_hf_model.py | 2 +- .../quantization/weight_only/test_rtn.py | 2 +- .../quantization/weight_only/test_teq.py | 2 +- 13 files changed, 63 insertions(+), 38 deletions(-) diff --git a/docs/3x/PT_WeightOnlyQuant.md b/docs/3x/PT_WeightOnlyQuant.md index e7e5c543215..08e585ddc28 100644 --- a/docs/3x/PT_WeightOnlyQuant.md +++ b/docs/3x/PT_WeightOnlyQuant.md @@ -31,13 +31,13 @@ Theoretically, round-to-nearest (RTN) is the most straightforward way to quantiz ## Supported Matrix -| Algorithms/Backend | PyTorch eager mode | +| Algorithms/Backend | PyTorch eager mode | |--------------|----------| | RTN | ✔ | | GPTQ | ✔ | | AutoRound| ✔ | | AWQ | ✔ | -| TEQ | ✔ | +| TEQ | ✔ | | HQQ | ✔ | > **RTN:** A quantification method that we can think of very intuitively. It does not require additional datasets and is a very fast quantization method. Generally speaking, RTN will convert the weight into a uniformly distributed integer data type, but some algorithms, such as Qlora, propose a non-uniform NF4 data type and prove its theoretical optimality. 
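The RTN note carried along in the hunk above is easier to follow with a concrete picture of group-wise round-to-nearest weight quantization. The snippet below is only an illustration of the idea under simple assumptions (asymmetric scheme, 4 bits, groups of 32 input channels); the helper name is made up and this is not the neural_compressor kernel.

```python
import torch

def rtn_quantize_weight(weight: torch.Tensor, bits: int = 4, group_size: int = 32):
    """Toy asymmetric round-to-nearest quantization of a 2D weight, per group of input channels."""
    out_features, in_features = weight.shape
    assert in_features % group_size == 0, "a real implementation would pad or fall back to per-row groups"
    w = weight.reshape(out_features, in_features // group_size, group_size)
    qmax = 2**bits - 1
    w_min, w_max = w.amin(dim=-1, keepdim=True), w.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-5) / qmax   # one scale per group
    zero_point = torch.round(-w_min / scale)         # one zero point per group
    q = torch.clamp(torch.round(w / scale) + zero_point, 0, qmax)
    dequant = (q - zero_point) * scale               # what a WOQ linear reconstructs at runtime
    return q.reshape_as(weight), dequant.reshape_as(weight)

w = torch.randn(8, 64)
q, w_hat = rtn_quantize_weight(w)
print("max abs reconstruction error:", (w - w_hat).abs().max().item())
```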
@@ -64,8 +64,8 @@ WeightOnlyQuant quantization for PyTorch is using prepare and convert [APIs](./P | bits (int)| [1, ..., 8] | | group_size (int)| [-1, 1, ..., $C_{in}$] | | use_sym (bool)| [True, False] | -| use_double_quant (bool) | [True, False] | -| double_quant_dtype (str) | ['int'] | +| use_double_quant (bool) | [True, False] | +| double_quant_dtype (str) | ['int'] | | double_quant_bits (int) | [1, ..., bits] | | double_quant_use_sym (bool) | [True, False] | | double_quant_group_size (int) | [-1, 1, ..., $C_{in}$] | @@ -98,7 +98,7 @@ model = convert(model) #### GPTQ | gptq_args | comments | default value | |----------|-------------|-------------------------------------------------------------------| -| use_mse_search (bool) | Enables mean squared error (MSE) search | False +| use_mse_search (bool) | Enables mean squared error (MSE) search | False | use_layer_wise (bool) | Enables quantize model per layer | False | | model_path (str) | Model path that is used to load state_dict per layer | | | use_double_quant (bool) | Enables double quantization | False | @@ -120,7 +120,7 @@ model = convert(model) #### AutoRound | autoround_args | comments | default value | |----------|-------------|-------------------------------------------------------------------| -| enable_full_range (bool) | Whether to enable full range quantization | False +| enable_full_range (bool) | Whether to enable full range quantization | False | batch_size (int) | Batch size for training | 8 | | lr_scheduler | The learning rate scheduler to be used | None | | enable_quanted_input (bool) | Whether to use quantized input data | True | @@ -251,7 +251,7 @@ from neural_compressor.torch.quantization import load orig_model = YOURMODEL() loaded_model = load( - "saved_results", model=orig_model + model=orig_model, checkpoint_dir="saved_results" ) # Please note that the model parameter passes the original model. 
``` diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py index 5cd0f046aba..2d8e61479f3 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py @@ -63,7 +63,7 @@ parser.add_argument("--calib_iters", default=100, type=int, help="calibration iters.") parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], \ - type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa", + type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa", "rte", "openbookqa", "lambada_standard", "wikitext"], help="tasks list for accuracy validation") parser.add_argument("--limit", default=None, type=int, @@ -117,10 +117,10 @@ for examples in calib_dataset: calib_data.append( tokenizer( - examples["text"], - return_tensors="pt", - max_length=64, - padding="max_length", + examples["text"], + return_tensors="pt", + max_length=64, + padding="max_length", truncation=True ) ) @@ -143,7 +143,7 @@ def calib_func(model): if args.load: from neural_compressor.torch.quantization import load - user_model = load("saved_results", user_model) + user_model = load(model=user_model, checkpoint_dir="saved_results") if args.approach in ["dynamic", "static"] or args.load: @@ -154,7 +154,7 @@ def calib_func(model): -# If torch.matmul and torch.bmm are not replaced by INC module, +# If torch.matmul and torch.bmm are not replaced by INC module, # Below codes can make torch.matmul and torch.bmm run on fp8 by injection. if not args.skip_fp8_mm and args.precision in ['fp8_e4m3', 'fp8_e5m2']: def replace_torch_mm_bmm(): diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py index ac06fd13775..b26d98d4266 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py @@ -379,7 +379,7 @@ def run_fn(model): user_model, _ = get_user_model() tokenizer = AutoTokenizer.from_pretrained(args.model) config = AutoConfig.from_pretrained(args.model) - user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)), model=user_model) + user_model = load(model=user_model, checkpoint_dir=os.path.abspath(os.path.expanduser(args.output_dir))) setattr(user_model, "config", config) else: user_model, tokenizer = get_user_model() diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index 5e79076956e..fdce080bde8 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -45,18 +45,37 @@ def save(model, output_dir="./saved_results"): logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path)) -def load(model_name_or_path, model=None, format=LoadFormat.DEFAULT, *hf_model_args, **hf_model_kwargs): +def load(model, checkpoint_dir=None, format=LoadFormat.DEFAULT, *hf_model_args, **hf_model_kwargs): + """Load quantized weight-only quantization model. + + 1. 
Load INC weight-only quantized model in local. + 2. Load HuggingFace weight-only quantized model, including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub. + + Args: + model (Union[torch.nn.Module], str): torch model or hugginface model id. + if 'format' is set to 'huggingface', it means the model_name_or_path of huggingface weight-only quantized model . + if 'format' is set to 'default', it means the fp32 model and the 'checkpoint_dir' + parameter should not be None. it coworks with 'checkpoint_dir' parameter to load INC + weight-only quantized model in local. + checkpoint_dir (str, optional): local path where quantized weights are saved. + Only needed if 'format' is set to 'default'. + format (str, optional): 'defult' for loading INC weight-only quantized model. + 'huggingface' for loading huggingface WOQ causal language model. Defaults to "default". + + Returns: + torch.nn.Module: quantized model + """ if format == LoadFormat.HUGGINGFACE: - model = _load_hf_woq_model(model_name_or_path, *hf_model_args, **hf_model_kwargs) + model = _load_hf_woq_model(model, *hf_model_args, **hf_model_kwargs) logger.info("Quantized huggingface model loading successful.") return model elif format == LoadFormat.DEFAULT: - qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), WEIGHT_NAME) + qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), WEIGHT_NAME) assert os.path.exists(qmodel_weight_file_path), "Cannot load model weight from path {}".format( qmodel_weight_file_path ) - qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), QCONFIG_NAME) + qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), QCONFIG_NAME) assert os.path.exists(qconfig_file_path), "Cannot load model quantization config from path {}".format( qconfig_file_path ) diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index e5d098f2921..55076a4efee 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -32,16 +32,22 @@ } -def load(model_name_or_path="./saved_results", model=None, format="default", *hf_model_args, **hf_model_kwargs): +def load(model=None, checkpoint_dir="./saved_results", format="default", *hf_model_args, **hf_model_kwargs): """Load quantized model. + 1. Load INC quantized model in local. + 2. Load HuggingFace quantized model, including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub. + Args: - model_name_or_path (str, optional): local path where quantized weights or model are saved - or huggingface model id. Defaults to "./saved_results". - model (torch.nn.Module, optional): original model, suggest to use empty tensor. - Require to pass when loading INC WOQ quantized model or loading FP8 model. Defaults to None. + model (Union[torch.nn.Module], str): torch model or hugginface model_name_or_path. + if 'format' is set to 'huggingface', it means the huggingface model_name_or_path. + if 'format' is set to 'default', it means the fp32 model and the 'checkpoint_dir' + parameter should not be None. it coworks with 'checkpoint_dir' parameter to load INC + quantized model in local. + checkpoint_dir (str, optional): local path where quantized weights or model are saved. + Only needed if 'format' is set to 'default'. format (str, optional): 'defult' for loading INC quantized model. 
- 'huggingface' now only for loading huggingface WOQ causal language model. Defaults to "default". + 'huggingface' for loading huggingface WOQ causal language model. Defaults to "default". Returns: torch.nn.Module: quantized model @@ -49,14 +55,14 @@ def load(model_name_or_path="./saved_results", model=None, format="default", *hf if format == LoadFormat.DEFAULT.value: from neural_compressor.common.base_config import ConfigRegistry - qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json") + qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), "qconfig.json") with open(qconfig_file_path, "r") as f: per_op_qconfig = json.load(f) if " " in per_op_qconfig.keys(): # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ... from neural_compressor.torch.algorithms.static_quant import load - return load(model_name_or_path) + return load(checkpoint_dir) else: config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) # select load function @@ -65,17 +71,17 @@ def load(model_name_or_path="./saved_results", model=None, format="default", *hf if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)): # WOQ from neural_compressor.torch.algorithms.weight_only.save_load import load - return load(model_name_or_path, model=model, format=LoadFormat.DEFAULT) + return load(model=model, checkpoint_dir=checkpoint_dir, format=LoadFormat.DEFAULT) model.qconfig = config_mapping if isinstance(config_object, FP8Config): # FP8 from neural_compressor.torch.algorithms.habana_fp8 import load - return load(model, model_name_or_path) + return load(model, checkpoint_dir) elif format == LoadFormat.HUGGINGFACE.value: # now only support load huggingface WOQ causal language model from neural_compressor.torch.algorithms.weight_only.save_load import load - return load(model_name_or_path, format=LoadFormat.HUGGINGFACE, *hf_model_args, **hf_model_kwargs) + return load(model=model, format=LoadFormat.HUGGINGFACE, *hf_model_args, **hf_model_kwargs) else: raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format)) diff --git a/test/3x/torch/quantization/habana_fp8/test_fp8.py b/test/3x/torch/quantization/habana_fp8/test_fp8.py index 8fafc302f65..3f73a7a8dbc 100644 --- a/test/3x/torch/quantization/habana_fp8/test_fp8.py +++ b/test/3x/torch/quantization/habana_fp8/test_fp8.py @@ -153,7 +153,7 @@ def calib_func(model): from neural_compressor.torch.quantization import load m = copy.deepcopy(self.model) - m = load("saved_results", m) + m = load(model=m, checkpoint_dir="saved_results") recovered_out = m(inp) assert (recovered_out == fp8_out).all(), "Unexpected result. Please double check." assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check." diff --git a/test/3x/torch/quantization/test_smooth_quant.py b/test/3x/torch/quantization/test_smooth_quant.py index 1d49ad27763..180786e954c 100644 --- a/test/3x/torch/quantization/test_smooth_quant.py +++ b/test/3x/torch/quantization/test_smooth_quant.py @@ -150,7 +150,7 @@ def test_sq_save_load(self): from neural_compressor.torch.quantization import load # load using saved model - loaded_model = load("saved_results") + loaded_model = load(checkpoint_dir="saved_results") loaded_out = loaded_model(example_inputs) # set a big atol to avoid random issue assert torch.allclose(inc_out, loaded_out, atol=2e-02), "Unexpected result. Please double check." 
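For readers tracking the API change in these tests, a full save/load round trip with the keyword-style signature introduced at this point in the series (`model=` plus `checkpoint_dir=`) would read roughly as below. This is a hedged sketch that mirrors the tests: the tiny model id is a placeholder, and the attached `save` method and RTN flow are assumed from the surrounding patches rather than guaranteed for every release.

```python
import copy
import torch
from transformers import AutoModelForCausalLM
from neural_compressor.torch.quantization import RTNConfig, convert, load, prepare

fp32_model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM"  # placeholder model, as in the tests
)
example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)

# quantize with RTN (no calibration needed), then persist weights + qconfig.json
q_model = prepare(copy.deepcopy(fp32_model), quant_config=RTNConfig())
q_model = convert(q_model)
q_model.save("saved_results")

# reload: the original fp32 model supplies the architecture, the directory supplies the weights
loaded_model = load(model=copy.deepcopy(fp32_model), checkpoint_dir="saved_results")
with torch.no_grad():
    assert torch.allclose(q_model(example_inputs)[0], loaded_model(example_inputs)[0])
```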
diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py index 8208853f851..cb281d698bd 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -117,7 +117,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results", model=copy.deepcopy(self.gptj)) + loaded_model = load(model=copy.deepcopy(self.gptj), checkpoint_dir="saved_results") loaded_out = loaded_model(self.inp)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance( diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py index 5452a02daf1..78959b35e30 100644 --- a/test/3x/torch/quantization/weight_only/test_awq.py +++ b/test/3x/torch/quantization/weight_only/test_awq.py @@ -131,7 +131,7 @@ def calib_func(model): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj)) + loaded_model = load(model=copy.deepcopy(self.tiny_gptj), checkpoint_dir="saved_results") loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py index 622a7c24539..cc3846137d1 100644 --- a/test/3x/torch/quantization/weight_only/test_gptq.py +++ b/test/3x/torch/quantization/weight_only/test_gptq.py @@ -254,7 +254,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj)) + loaded_model = load(model=copy.deepcopy(self.tiny_gptj), checkpoint_dir="saved_results") loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." 
assert isinstance( diff --git a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py index 77a1964db82..709609bd58e 100644 --- a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py +++ b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py @@ -13,6 +13,6 @@ def setup_class(self): def test_load_hf_woq_model(self): from neural_compressor.torch.quantization import load - qmodel = load(self.model_name, format="huggingface") + qmodel = load(model=self.model_name, format="huggingface") output = qmodel(self.example_inputs)[0] assert len(output) > 0, "Not loading the model correctly" diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index ab396db1e63..db06e41890f 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -290,7 +290,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj)) + loaded_model = load(model=copy.deepcopy(self.tiny_gptj), checkpoint_dir="saved_results") loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." diff --git a/test/3x/torch/quantization/weight_only/test_teq.py b/test/3x/torch/quantization/weight_only/test_teq.py index 04e28afeac1..80128d354b3 100644 --- a/test/3x/torch/quantization/weight_only/test_teq.py +++ b/test/3x/torch/quantization/weight_only/test_teq.py @@ -141,7 +141,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results", model=copy.deepcopy(self.gptj)) + loaded_model = load(model=copy.deepcopy(self.gptj), checkpoint_dir="saved_results") loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." From 67fedf0b36c52ad8b053391d278c5ea91219f6df Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Thu, 6 Jun 2024 03:37:27 +0000 Subject: [PATCH 13/24] fix pylint Signed-off-by: yuwenzho --- neural_compressor/torch/algorithms/weight_only/save_load.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index fdce080bde8..60603d5c22d 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -49,11 +49,13 @@ def load(model, checkpoint_dir=None, format=LoadFormat.DEFAULT, *hf_model_args, """Load quantized weight-only quantization model. 1. Load INC weight-only quantized model in local. - 2. Load HuggingFace weight-only quantized model, including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub. + 2. Load HuggingFace weight-only quantized model, + including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub. Args: model (Union[torch.nn.Module], str): torch model or hugginface model id. - if 'format' is set to 'huggingface', it means the model_name_or_path of huggingface weight-only quantized model . 
+ if 'format' is set to 'huggingface', it means the model_name_or_path of + huggingface weight-only quantized model. if 'format' is set to 'default', it means the fp32 model and the 'checkpoint_dir' parameter should not be None. it coworks with 'checkpoint_dir' parameter to load INC weight-only quantized model in local. From 124678c88629349f044276b6f287d3d8eba81ea8 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 11 Jun 2024 03:27:01 +0000 Subject: [PATCH 14/24] enhance load API Signed-off-by: yuwenzho --- .../quantization/llm/run_clm_no_trainer.py | 2 +- .../torch/algorithms/weight_only/save_load.py | 32 ++++++++++-------- .../torch/quantization/load_entry.py | 33 +++++++++++-------- .../torch/quantization/habana_fp8/test_fp8.py | 2 +- .../torch/quantization/test_smooth_quant.py | 2 +- .../weight_only/test_autoround.py | 2 +- .../quantization/weight_only/test_awq.py | 2 +- .../quantization/weight_only/test_gptq.py | 2 +- .../weight_only/test_load_woq_hf_model.py | 2 +- .../quantization/weight_only/test_rtn.py | 2 +- .../quantization/weight_only/test_teq.py | 2 +- 11 files changed, 47 insertions(+), 36 deletions(-) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py index b26d98d4266..f592e9db8e3 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py @@ -379,7 +379,7 @@ def run_fn(model): user_model, _ = get_user_model() tokenizer = AutoTokenizer.from_pretrained(args.model) config = AutoConfig.from_pretrained(args.model) - user_model = load(model=user_model, checkpoint_dir=os.path.abspath(os.path.expanduser(args.output_dir))) + user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)), user_model) setattr(user_model, "config", config) else: user_model, tokenizer = get_user_model() diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index 60603d5c22d..5dc794e7792 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -45,7 +45,7 @@ def save(model, output_dir="./saved_results"): logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path)) -def load(model, checkpoint_dir=None, format=LoadFormat.DEFAULT, *hf_model_args, **hf_model_kwargs): +def load(model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, device='cpu', *model_args, **kwargs): """Load quantized weight-only quantization model. 1. Load INC weight-only quantized model in local. @@ -53,38 +53,42 @@ def load(model, checkpoint_dir=None, format=LoadFormat.DEFAULT, *hf_model_args, including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub. Args: - model (Union[torch.nn.Module], str): torch model or hugginface model id. - if 'format' is set to 'huggingface', it means the model_name_or_path of - huggingface weight-only quantized model. - if 'format' is set to 'default', it means the fp32 model and the 'checkpoint_dir' - parameter should not be None. it coworks with 'checkpoint_dir' parameter to load INC + model_name_or_path (str): torch checkpoint directory or hugginface model_name_or_path. 
+ If 'format' is set to 'huggingface', it means the huggingface model_name_or_path. + If 'format' is set to 'default', it means the 'checkpoint_dir'. + Parameter should not be None. it coworks with 'original_model' parameter to load INC weight-only quantized model in local. - checkpoint_dir (str, optional): local path where quantized weights are saved. - Only needed if 'format' is set to 'default'. + original_model (torch.nn.module, optional): original model before quantization. + Needed if 'format' is set to 'default' and not TorchScript model.Defaults to None. format (str, optional): 'defult' for loading INC weight-only quantized model. 'huggingface' for loading huggingface WOQ causal language model. Defaults to "default". - + model_args (sequence of positional arguments, optional): + all remaining positional arguments for loading huggingface models. + will be passed to the huggingface model's `__init__` method. + kwargs (remaining dictionary of keyword arguments, optional): + remaining dictionary of keyword arguments for loading huggingface models. + will be passed to the huggingface model's `__init__` method, such as 'trust_remote_code', 'revision'. Returns: torch.nn.Module: quantized model """ if format == LoadFormat.HUGGINGFACE: - model = _load_hf_woq_model(model, *hf_model_args, **hf_model_kwargs) + model = _load_hf_woq_model(model_name_or_path, *model_args, **kwargs) logger.info("Quantized huggingface model loading successful.") return model elif format == LoadFormat.DEFAULT: - qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), WEIGHT_NAME) + qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), WEIGHT_NAME) assert os.path.exists(qmodel_weight_file_path), "Cannot load model weight from path {}".format( qmodel_weight_file_path ) - qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), QCONFIG_NAME) + qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), QCONFIG_NAME) assert os.path.exists(qconfig_file_path), "Cannot load model quantization config from path {}".format( qconfig_file_path ) - assert model is not None, "Can't get origin model. Please pass `model` to load function." + assert original_model is not None, "Can't get original model. Please pass `original_model` to load function." - model = _load_inc_woq_model(qmodel_weight_file_path, qconfig_file_path, model) + model = _load_inc_woq_model(qmodel_weight_file_path, qconfig_file_path, original_model) logger.info("Quantized model loading successful.") return model else: diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index 55076a4efee..2349e2034f5 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -32,37 +32,44 @@ } -def load(model=None, checkpoint_dir="./saved_results", format="default", *hf_model_args, **hf_model_kwargs): +def load(model_name_or_path, original_model=None, format='default', device='cpu', *model_args, **kwargs): """Load quantized model. 1. Load INC quantized model in local. 2. Load HuggingFace quantized model, including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub. Args: - model (Union[torch.nn.Module], str): torch model or hugginface model_name_or_path. - if 'format' is set to 'huggingface', it means the huggingface model_name_or_path. 
- if 'format' is set to 'default', it means the fp32 model and the 'checkpoint_dir' - parameter should not be None. it coworks with 'checkpoint_dir' parameter to load INC + model_name_or_path (str): torch checkpoint directory or hugginface model_name_or_path. + If 'format' is set to 'huggingface', it means the huggingface model_name_or_path. + If 'format' is set to 'default', it means the 'checkpoint_dir'. + Parameter should not be None. it coworks with 'original_model' parameter to load INC quantized model in local. - checkpoint_dir (str, optional): local path where quantized weights or model are saved. - Only needed if 'format' is set to 'default'. + original_model (torch.nn.module, optional): original model before quantization. + Needed if 'format' is set to 'default' and not TorchScript model.Defaults to None. format (str, optional): 'defult' for loading INC quantized model. 'huggingface' for loading huggingface WOQ causal language model. Defaults to "default". - + device (str, optional): 'cpu', 'hpu' or 'cuda'. specify the device the model will be loaded to. + model_args (sequence of positional arguments, optional): + all remaining positional arguments for loading huggingface models. + Will be passed to the huggingface model's `__init__` method. + kwargs (remaining dictionary of keyword arguments, optional): + remaining dictionary of keyword arguments for loading huggingface models. + Will be passed to the huggingface model's `__init__` method, such as 'trust_remote_code', 'revision'. Returns: torch.nn.Module: quantized model """ + # TODO: When loading WOQ model, use different WeightOnlyLinear module according to device. if format == LoadFormat.DEFAULT.value: from neural_compressor.common.base_config import ConfigRegistry - qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), "qconfig.json") + qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json") with open(qconfig_file_path, "r") as f: per_op_qconfig = json.load(f) if " " in per_op_qconfig.keys(): # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ... 
from neural_compressor.torch.algorithms.static_quant import load - return load(checkpoint_dir) + return load(model_name_or_path) else: config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) # select load function @@ -71,17 +78,17 @@ def load(model=None, checkpoint_dir="./saved_results", format="default", *hf_mod if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)): # WOQ from neural_compressor.torch.algorithms.weight_only.save_load import load - return load(model=model, checkpoint_dir=checkpoint_dir, format=LoadFormat.DEFAULT) + return load(model_name_or_path, original_model, format=LoadFormat.DEFAULT) model.qconfig = config_mapping if isinstance(config_object, FP8Config): # FP8 from neural_compressor.torch.algorithms.habana_fp8 import load - return load(model, checkpoint_dir) + return load(model_name_or_path, original_model) elif format == LoadFormat.HUGGINGFACE.value: # now only support load huggingface WOQ causal language model from neural_compressor.torch.algorithms.weight_only.save_load import load - return load(model=model, format=LoadFormat.HUGGINGFACE, *hf_model_args, **hf_model_kwargs) + return load(model_name_or_path, format=LoadFormat.HUGGINGFACE, *model_args, **kwargs) else: raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format)) diff --git a/test/3x/torch/quantization/habana_fp8/test_fp8.py b/test/3x/torch/quantization/habana_fp8/test_fp8.py index 3f73a7a8dbc..8fafc302f65 100644 --- a/test/3x/torch/quantization/habana_fp8/test_fp8.py +++ b/test/3x/torch/quantization/habana_fp8/test_fp8.py @@ -153,7 +153,7 @@ def calib_func(model): from neural_compressor.torch.quantization import load m = copy.deepcopy(self.model) - m = load(model=m, checkpoint_dir="saved_results") + m = load("saved_results", m) recovered_out = m(inp) assert (recovered_out == fp8_out).all(), "Unexpected result. Please double check." assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check." diff --git a/test/3x/torch/quantization/test_smooth_quant.py b/test/3x/torch/quantization/test_smooth_quant.py index 180786e954c..1d49ad27763 100644 --- a/test/3x/torch/quantization/test_smooth_quant.py +++ b/test/3x/torch/quantization/test_smooth_quant.py @@ -150,7 +150,7 @@ def test_sq_save_load(self): from neural_compressor.torch.quantization import load # load using saved model - loaded_model = load(checkpoint_dir="saved_results") + loaded_model = load("saved_results") loaded_out = loaded_model(example_inputs) # set a big atol to avoid random issue assert torch.allclose(inc_out, loaded_out, atol=2e-02), "Unexpected result. Please double check." diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py index cb281d698bd..f1539b072b7 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -117,7 +117,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load(model=copy.deepcopy(self.gptj), checkpoint_dir="saved_results") + loaded_model = load("saved_results", copy.deepcopy(self.gptj)) loaded_out = loaded_model(self.inp)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." 
assert isinstance( diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py index 78959b35e30..0a6d5f4f687 100644 --- a/test/3x/torch/quantization/weight_only/test_awq.py +++ b/test/3x/torch/quantization/weight_only/test_awq.py @@ -131,7 +131,7 @@ def calib_func(model): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load(model=copy.deepcopy(self.tiny_gptj), checkpoint_dir="saved_results") + loaded_model = load("saved_results", copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py index cc3846137d1..be408af2564 100644 --- a/test/3x/torch/quantization/weight_only/test_gptq.py +++ b/test/3x/torch/quantization/weight_only/test_gptq.py @@ -254,7 +254,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load(model=copy.deepcopy(self.tiny_gptj), checkpoint_dir="saved_results") + loaded_model = load("saved_results", copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance( diff --git a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py index 709609bd58e..887332a2ad0 100644 --- a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py +++ b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py @@ -13,6 +13,6 @@ def setup_class(self): def test_load_hf_woq_model(self): from neural_compressor.torch.quantization import load - qmodel = load(model=self.model_name, format="huggingface") + qmodel = load(model_name_or_path=self.model_name, format="huggingface") output = qmodel(self.example_inputs)[0] assert len(output) > 0, "Not loading the model correctly" diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index db06e41890f..a2f63480a4e 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -290,7 +290,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load(model=copy.deepcopy(self.tiny_gptj), checkpoint_dir="saved_results") + loaded_model = load("saved_results", copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." 
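Alongside the local round trip exercised above, the `test_load_woq_hf_model` change shows the other entry path with the signature as revised in this patch (positional `model_name_or_path` first): pulling an already-quantized WOQ checkpoint straight from the Hugging Face Hub. A hedged usage sketch follows; the repo id is just an example of a GPTQ-style checkpoint, and extra keyword arguments such as `trust_remote_code` or `revision` are forwarded to the underlying Hugging Face loading per the docstring.

```python
import torch
from neural_compressor.torch.quantization import load

# format="huggingface" needs no original fp32 model; the config and quantized
# weights are resolved from the hub repo (the id below is only a placeholder).
qmodel = load(
    model_name_or_path="TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
    format="huggingface",
)
example_inputs = torch.ones([1, 10], dtype=torch.long)
with torch.no_grad():
    logits = qmodel(example_inputs)[0]
print(logits.shape)
```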
diff --git a/test/3x/torch/quantization/weight_only/test_teq.py b/test/3x/torch/quantization/weight_only/test_teq.py index 80128d354b3..9f4df1c4226 100644 --- a/test/3x/torch/quantization/weight_only/test_teq.py +++ b/test/3x/torch/quantization/weight_only/test_teq.py @@ -141,7 +141,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load(model=copy.deepcopy(self.gptj), checkpoint_dir="saved_results") + loaded_model = load("saved_results", copy.deepcopy(self.gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." From f3ba356200746e3196adde97b7e90554e4b5a15a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 03:31:49 +0000 Subject: [PATCH 15/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/torch/algorithms/weight_only/save_load.py | 2 +- neural_compressor/torch/quantization/load_entry.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index 9f63bc285e4..8912789a4f2 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -44,7 +44,7 @@ def save(model, output_dir="./saved_results"): logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path)) -def load(model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, device='cpu', *model_args, **kwargs): +def load(model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, device="cpu", *model_args, **kwargs): """Load quantized weight-only quantization model. 1. Load INC weight-only quantized model in local. diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index 2349e2034f5..ed09bb4bbab 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -32,7 +32,7 @@ } -def load(model_name_or_path, original_model=None, format='default', device='cpu', *model_args, **kwargs): +def load(model_name_or_path, original_model=None, format="default", device="cpu", *model_args, **kwargs): """Load quantized model. 1. Load INC quantized model in local. From eb95cff851367608d1c2fdacf26dc7e440abde50 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 11 Jun 2024 03:38:51 +0000 Subject: [PATCH 16/24] enhance docstring Signed-off-by: yuwenzho --- .../torch/quantization/load_entry.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index ed09bb4bbab..10429a845b3 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -32,12 +32,29 @@ } -def load(model_name_or_path, original_model=None, format="default", device="cpu", *model_args, **kwargs): +def load(model_name_or_path, original_model=None, format='default', device='cpu', *model_args, **kwargs): """Load quantized model. 1. Load INC quantized model in local. 2. Load HuggingFace quantized model, including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub. 
+ case 1: WOQ + # huggingface model + from neural_compressor.torch.quantization import load + load(model_name_or_path=model_name_or_path) + + # local model + from neural_compressor.torch.quantization import load + load(model_name_or_path="saved_results", original_model=fp32_model) +
+ case 2: INT8/FP8 + from neural_compressor.torch.quantization import load + load(model_name_or_path="saved_results", original_model=fp32_model) +
+ case 3: TorchScript (IPEX) + from neural_compressor.torch.quantization import load + load(model_name_or_path="saved_results")
+ Args: model_name_or_path (str): torch checkpoint directory or huggingface model_name_or_path. If 'format' is set to 'huggingface', it means the huggingface model_name_or_path.
@@ -80,7 +97,7 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu" return load(model_name_or_path, original_model, format=LoadFormat.DEFAULT) - model.qconfig = config_mapping + original_model.qconfig = config_mapping if isinstance(config_object, FP8Config): # FP8 from neural_compressor.torch.algorithms.habana_fp8 import load
From 097e1a178a12657f457d3ca8ce751e8dd5382cc4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 03:40:29 +0000 Subject: [PATCH 17/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/torch/quantization/load_entry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index 10429a845b3..c5e0f4df369 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -32,7 +32,7 @@ } -def load(model_name_or_path, original_model=None, format='default', device='cpu', *model_args, **kwargs): +def load(model_name_or_path, original_model=None, format="default", device="cpu", *model_args, **kwargs): """Load quantized model. 1. Load INC quantized model in local.
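Behind the docstring cases above, the `format="default"` branch decides which loader to call by inspecting `qconfig.json` in the checkpoint directory. The sketch below is a simplified illustration of that dispatch, not the actual load_entry internals: the function name and return strings are made up, and only the `" "` key convention for IPEX-style configs and the algorithm names are taken from the diffs above.

```python
import json
import os

def sketch_default_format_dispatch(model_name_or_path: str) -> str:
    """Rough illustration of how the 'default' branch picks a loader from qconfig.json."""
    qconfig_file = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json")
    with open(qconfig_file) as f:
        per_op_qconfig = json.load(f)

    if " " in per_op_qconfig:
        # IPEX/TorchScript style config: no original fp32 model is required
        return "static_quant loader"

    # otherwise the first entry indicates which algorithm produced the checkpoint,
    # e.g. {"('lm_head', 'Linear')": {"rtn": {...}}} for a weight-only checkpoint
    first_op_config = next(iter(per_op_qconfig.values()))
    if isinstance(first_op_config, dict) and any(
        algo in first_op_config for algo in ("rtn", "gptq", "awq", "teq", "autoround")
    ):
        return "weight-only loader (original_model required)"
    return "fp8 loader (original_model required)"
```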
From 92ad0e5a377a3c79e8e6b91b44f21bf16e6fa0dc Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 11 Jun 2024 05:48:24 +0000 Subject: [PATCH 18/24] update load API usage Signed-off-by: yuwenzho --- .../language-modeling/quantization/habana_fp8/run_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py index 2d8e61479f3..e77ef2c6a33 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py @@ -143,7 +143,7 @@ def calib_func(model): if args.load: from neural_compressor.torch.quantization import load - user_model = load(model=user_model, checkpoint_dir="saved_results") + user_model = load("saved_results", user_model) if args.approach in ["dynamic", "static"] or args.load: From fba5711c5ff9c9f67356f10ed441b7a3c62f9f30 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 11 Jun 2024 05:51:31 +0000 Subject: [PATCH 19/24] update load API usage Signed-off-by: yuwenzho --- docs/3x/PT_WeightOnlyQuant.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/3x/PT_WeightOnlyQuant.md b/docs/3x/PT_WeightOnlyQuant.md index 08e585ddc28..37cc934592a 100644 --- a/docs/3x/PT_WeightOnlyQuant.md +++ b/docs/3x/PT_WeightOnlyQuant.md @@ -251,8 +251,8 @@ from neural_compressor.torch.quantization import load orig_model = YOURMODEL() loaded_model = load( - model=orig_model, checkpoint_dir="saved_results" -) # Please note that the model parameter passes the original model. + "saved_results", original_model=orig_model +) # Please note that the original_model parameter passes the original model. ``` From 939a261dd4554ea6e45dc2ae23586ab7eda4ed68 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 12 Jun 2024 02:18:25 +0000 Subject: [PATCH 20/24] enhance code Signed-off-by: yuwenzho --- .../torch/algorithms/weight_only/save_load.py | 984 ++++++++++-------- .../torch/quantization/load_entry.py | 14 +- .../weight_only/test_load_woq_hf_model.py | 8 +- 3 files changed, 542 insertions(+), 464 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index 8912789a4f2..aac5290a424 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -17,7 +17,7 @@ import json import os import re - +import copy import torch from neural_compressor.common.utils import load_config_mapping, save_config_mapping @@ -44,7 +44,7 @@ def save(model, output_dir="./saved_results"): logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path)) -def load(model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, device="cpu", *model_args, **kwargs): +def load(model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, device="cpu", **kwargs): """Load quantized weight-only quantization model. 1. Load INC weight-only quantized model in local. @@ -61,493 +61,567 @@ def load(model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, dev Needed if 'format' is set to 'default' and not TorchScript model.Defaults to None. format (str, optional): 'defult' for loading INC weight-only quantized model. 
'huggingface' for loading huggingface WOQ causal language model. Defaults to "default". - model_args (sequence of positional arguments, optional): - all remaining positional arguments for loading huggingface models. - will be passed to the huggingface model's `__init__` method. kwargs (remaining dictionary of keyword arguments, optional): remaining dictionary of keyword arguments for loading huggingface models. will be passed to the huggingface model's `__init__` method, such as 'trust_remote_code', 'revision'. Returns: torch.nn.Module: quantized model """ - if format == LoadFormat.HUGGINGFACE: - model = _load_hf_woq_model(model_name_or_path, *model_args, **kwargs) - logger.info("Quantized huggingface model loading successful.") + model_loader = WOQModelLoader(model_name_or_path, original_model, format, device="cpu", **kwargs) + model = model_loader.load_woq_model() + return model + + +class WOQModelLoader: + def __init__(self, model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, device="cpu", **kwargs): + # TODO: When loading WOQ model, use different WeightOnlyLinear module according to device. + self.model_name_or_path = model_name_or_path + self.original_model = original_model + self.format = format + self.device = device + self.kwargs = kwargs + self.quantization_config = {} + self.loaded_state_dict_keys = {} + + def load_woq_model(self): + if self.format == LoadFormat.HUGGINGFACE: + model = self.load_hf_format_woq_model() + logger.info("Quantized huggingface model loading successful.") + elif self.format == LoadFormat.DEFAULT: + qmodel_weight_file_path = os.path.join( + os.path.abspath(os.path.expanduser(self.model_name_or_path)), WEIGHT_NAME) + assert os.path.exists(qmodel_weight_file_path), "Cannot load model weight from path {}".format( + qmodel_weight_file_path + ) + + qconfig_file_path = os.path.join( + os.path.abspath(os.path.expanduser(self.model_name_or_path)), QCONFIG_NAME) + assert os.path.exists(qconfig_file_path), "Cannot load model quantization config from path {}".format( + qconfig_file_path + ) + + assert self.original_model is not None, \ + "Can't get original model. Please pass `original_model` to load function." + + model = self.load_inc_format_woq_model(qmodel_weight_file_path, qconfig_file_path) + logger.info("Quantized model loading successful.") + else: + raise ValueError( + f"`format` in load function can only be 'huggingface' or 'default', but get {self.format}") + return model - elif format == LoadFormat.DEFAULT: - qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), WEIGHT_NAME) - assert os.path.exists(qmodel_weight_file_path), "Cannot load model weight from path {}".format( - qmodel_weight_file_path - ) - qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), QCONFIG_NAME) - assert os.path.exists(qconfig_file_path), "Cannot load model quantization config from path {}".format( - qconfig_file_path - ) + def load_inc_format_woq_model(self, qmodel_weight_file_path, qconfig_file_path): + qweights = torch.load(qmodel_weight_file_path) + self.loaded_state_dict_keys = qweights.keys() - assert original_model is not None, "Can't get original model. Please pass `original_model` to load function." 
+ with open(qconfig_file_path, "r") as file: + self.quantization_config = json.load(file) - model = _load_inc_woq_model(qmodel_weight_file_path, qconfig_file_path, original_model) - logger.info("Quantized model loading successful.") + model = self._build_woq_model() + model.load_state_dict(qweights, assign=True) + model.eval() return model - else: - raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format)) - - -def _build_woq_model(model, quantization_config, loaded_state_dict_keys): - """Build weight-only quantization model.""" - from neural_compressor.torch.utils import set_module - - from .modules import MulLinear - - for name, module in model.named_modules(): - _is_autoround = False - # get quantization config of module - module_quantization_config = quantization_config - # pattern will map (module_name, moduele_type) - pattern = rf"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))" - for q_config_key, q_config_value in quantization_config.items(): - if re.search(pattern, q_config_key): - if isinstance(q_config_value, dict) and [algo for algo in q_config_value.keys()][0] == "autoround": - _is_autoround = True - module_quantization_config = [config for config in q_config_value.values()][0] - - if isinstance(module, torch.nn.Linear): - # module without qweight means it is not quantized, then skip it - loaded_state_dict_keys_set = set(loaded_state_dict_keys) - if ( - name + ".qweight" not in loaded_state_dict_keys_set - and name + ".linear.qweight" not in loaded_state_dict_keys_set - ): - continue - - # insert MulLinear module - if name + ".linear.qweight" in loaded_state_dict_keys_set: - new_module = MulLinear(module) - set_module(model, name, new_module) - name += ".linear" - - # replace `torch.nn.Linear` with `WeightOnlyLinear` - zp = True if name + ".qzeros" in loaded_state_dict_keys_set else False - g_idx = True if name + ".g_idx" in loaded_state_dict_keys_set else False - - kwargs = {} - if _is_autoround: - from auto_round.export.export_to_itrex.model_wrapper import ( - WeightOnlyLinear as AutoRoundWeightOnlyLinear, - ) - from .utility import convert_dtype_str2torch + def load_hf_format_woq_model(self): + # check required package + self._check_required_packages() - WeightOnlyLinearClass = AutoRoundWeightOnlyLinear - kwargs["groupsize"] = module_quantization_config.get("group_size", 32) - kwargs["scale_dtype"] = convert_dtype_str2torch(module_quantization_config.get("scale_dtype", "fp16")) - else: - from .modules import WeightOnlyLinear as INCWeightOnlyLinear - - WeightOnlyLinearClass = INCWeightOnlyLinear - kwargs["group_size"] = module_quantization_config.get("group_size", 32) - kwargs["g_idx"] = g_idx - - new_module = WeightOnlyLinearClass( - module.in_features, - module.out_features, - dtype=module_quantization_config.get("dtype", "int"), - bits=module_quantization_config.get("bits", 4), - zp=zp, - bias=module.bias is not None, - use_optimum_format=True, - **kwargs, - ) - set_module(model, name, new_module) + # get model_class and config + model_class, config = self._get_model_class_and_config() + self.quantization_config = config.quantization_config - return model + # get loaded_state_dict_keys + self.loaded_state_dict_keys = self._get_loaded_state_dict_keys(config) + # initiate the huggingface model + self.original_model = self._init_hf_model(model_class, config) -def _load_inc_woq_model(qmodel_weight_file_path, qconfig_file_path, origin_model): - qweights = torch.load(qmodel_weight_file_path) + # build 
weight-only quantization model with WeightOnlyLinear module + model = self._build_woq_model() - quantization_config = {} - with open(qconfig_file_path, "r") as file: - quantization_config = json.load(file) + # load quantized weight to woq model + model = self._load_pretrained_weight(model, model_class) - model = _build_woq_model(origin_model, quantization_config, qweights.keys()) - model.load_state_dict(qweights, assign=True) - model.eval() - return model + return model + def _build_woq_model(self): + """Build weight-only quantization model.""" + from neural_compressor.torch.utils import set_module + + from .modules import MulLinear + + for name, module in self.original_model.named_modules(): + _is_autoround = False + # get quantization config of module + module_quantization_config = self.quantization_config + # pattern will map (module_name, moduele_type) + pattern = rf"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))" + for q_config_key, q_config_value in self.quantization_config.items(): + if re.search(pattern, q_config_key): + if isinstance(q_config_value, dict) and [algo for algo in q_config_value.keys()][0] == "autoround": + _is_autoround = True + module_quantization_config = [config for config in q_config_value.values()][0] + + if isinstance(module, torch.nn.Linear): + # module without qweight means it is not quantized, then skip it + loaded_state_dict_keys_set = set(self.loaded_state_dict_keys) + if ( + name + ".qweight" not in loaded_state_dict_keys_set + and name + ".linear.qweight" not in loaded_state_dict_keys_set + ): + continue + + # insert MulLinear module + if name + ".linear.qweight" in loaded_state_dict_keys_set: + new_module = MulLinear(module) + set_module(self.original_model, name, new_module) + name += ".linear" + + # replace `torch.nn.Linear` with `WeightOnlyLinear` + zp = True if name + ".qzeros" in loaded_state_dict_keys_set else False + g_idx = True if name + ".g_idx" in loaded_state_dict_keys_set else False + + kwargs = {} + if _is_autoround: + from auto_round.export.export_to_itrex.model_wrapper import ( + WeightOnlyLinear as AutoRoundWeightOnlyLinear, + ) -def _load_hf_woq_model(pretrained_model_name_or_path, *model_args, **kwargs): - # check required package - try: - import accelerate - import transformers - except ImportError as e: # pragma: no cover - raise e - - # below codes are refer to load_low_bit function in - # https://github.com/intel/intel-extension-for-transformers/blob/v1.4.2/intel_extension_for_transformers/transformers/modeling/modeling_auto.py#L1464 - import copy - - from accelerate.big_modeling import init_empty_weights - from transformers import AutoConfig, AutoModelForCausalLM - from transformers.configuration_utils import PretrainedConfig - from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code - from transformers.modeling_utils import _add_variant, get_checkpoint_shard_files, load_state_dict, no_init_weights - from transformers.models.auto.auto_factory import _get_model_class - from transformers.utils import ( - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - WEIGHTS_INDEX_NAME, - WEIGHTS_NAME, - ContextManagers, - cached_file, - download_url, - extract_commit_hash, - has_file, - is_remote_url, - is_safetensors_available, - ) - - # Autofactory - kwargs_orig = copy.deepcopy(kwargs) - trust_remote_code = kwargs.pop("trust_remote_code", None) - subfolder = kwargs.pop("subfolder", "") - variant = kwargs.pop("variant", None) - offload_folder = kwargs.pop("offload_folder", None) - 
offload_state_dict = kwargs.pop("offload_state_dict", False) - torch_dtype = kwargs.pop("torch_dtype", "auto") - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - resume_download = kwargs.pop("resume_download", False) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - token = kwargs.pop("token", None) - from_pipeline = kwargs.pop("_from_pipeline", None) - from_auto_class = kwargs.pop("_from_auto", False) - revision = kwargs.pop("revision", "main") - commit_hash = kwargs.pop("_commit_hash", None) - _fast_init = kwargs.pop("_fast_init", True) - use_safetensors = kwargs.pop("use_safetensors", None) - kwarg_attn_imp = kwargs.pop("attn_implementation", None) - - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - quantization_config = config.quantization_config - - if use_safetensors is None and not is_safetensors_available(): - use_safetensors = False - - if use_auth_token is not None: # pragma: no cover - logger.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. " - "Please use `token` instead." + from .utility import convert_dtype_str2torch + + WeightOnlyLinearClass = AutoRoundWeightOnlyLinear + kwargs["groupsize"] = module_quantization_config.get("group_size", 32) + kwargs["scale_dtype"] = convert_dtype_str2torch(module_quantization_config.get("scale_dtype", "fp16")) + else: + from .modules import WeightOnlyLinear as INCWeightOnlyLinear + + WeightOnlyLinearClass = INCWeightOnlyLinear + kwargs["group_size"] = module_quantization_config.get("group_size", 32) + kwargs["g_idx"] = g_idx + + new_module = WeightOnlyLinearClass( + module.in_features, + module.out_features, + dtype=module_quantization_config.get("dtype", "int"), + bits=module_quantization_config.get("bits", 4), + zp=zp, + bias=module.bias is not None, + use_optimum_format=True, + **kwargs, + ) + set_module(self.original_model, name, new_module) + woq_model = self.original_model + return woq_model + + def _check_required_packages(self): + try: + import accelerate + import transformers + except ImportError as e: # pragma: no cover + raise e + + def _get_model_class_and_config(self): + from transformers import AutoConfig, AutoModelForCausalLM + from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code + from transformers.models.auto.auto_factory import _get_model_class + + # Autofactory + kwargs_orig = copy.deepcopy(self.kwargs) + trust_remote_code = self.kwargs.pop("trust_remote_code", None) + kwarg_attn_imp = self.kwargs.pop("attn_implementation", None) + + config = AutoConfig.from_pretrained(self.model_name_or_path) + # quantization_config = config.quantization_config + + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: # pragma: no cover + config._attn_implementation = kwarg_attn_imp + + has_remote_code = hasattr(config, "auto_map") and AutoModelForCausalLM.__name__ in config.auto_map + + has_local_code = type(config) in AutoModelForCausalLM._model_mapping.keys() + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, + self.model_name_or_path, + has_local_code, + has_remote_code, ) - if token is not None: - raise ValueError("`token` and `use_auth_token` are both specified. 
Please set only the argument `token`.") - token = use_auth_token - - user_agent = { - "file_type": "model", - "framework": "pytorch", - "from_auto_class": from_auto_class, - } - if from_pipeline is not None: # pragma: no cover - user_agent["using_pipeline"] = from_pipeline - - if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: # pragma: no cover - config._attn_implementation = kwarg_attn_imp - - if commit_hash is None: - if not isinstance(config, PretrainedConfig): # pragma: no cover - # We make a call to the config file first (which may be absent) - # to get the commit hash as soon as possible. - resolved_config_file = cached_file( - pretrained_model_name_or_path, - "config.json", - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - _raise_exceptions_for_missing_entries=False, - _raise_exceptions_for_connection_errors=False, + + if has_remote_code and trust_remote_code: # pragma: no cover + class_ref = config.auto_map[AutoModelForCausalLM.__name__] + model_class = get_class_from_dynamic_module(class_ref, self.model_name_or_path, **kwargs_orig) + if os.path.isdir(self.model_name_or_path): + model_class.register_for_auto_class(AutoModelForCausalLM.__name__) + else: + AutoModelForCausalLM.register(config.__class__, model_class, exist_ok=True) + elif type(config) in AutoModelForCausalLM._model_mapping.keys(): + model_class = _get_model_class(config, AutoModelForCausalLM._model_mapping) + + return model_class, config + + def _get_loaded_state_dict_keys(self, config): + from transformers.configuration_utils import PretrainedConfig + from transformers.modeling_utils import _add_variant, get_checkpoint_shard_files, load_state_dict + from transformers.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + cached_file, + download_url, + extract_commit_hash, + has_file, + is_remote_url, + is_safetensors_available, + ) + + subfolder = self.kwargs.pop("subfolder", "") + variant = self.kwargs.pop("variant", None) + cache_dir = self.kwargs.pop("cache_dir", None) + force_download = self.kwargs.pop("force_download", False) + proxies = self.kwargs.pop("proxies", None) + resume_download = self.kwargs.pop("resume_download", False) + local_files_only = self.kwargs.pop("local_files_only", False) + offload_folder = self.kwargs.pop("offload_folder", None) + offload_state_dict = self.kwargs.pop("offload_state_dict", False) + use_auth_token = self.kwargs.pop("use_auth_token", None) + token = self.kwargs.pop("token", None) + from_pipeline = self.kwargs.pop("_from_pipeline", None) + from_auto_class = self.kwargs.pop("_from_auto", False) + revision = self.kwargs.pop("revision", "main") + commit_hash = self.kwargs.pop("_commit_hash", None) + use_safetensors = self.kwargs.pop("use_safetensors", None) + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + + if use_auth_token is not None: # pragma: no cover + logger.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. " + "Please use `token` instead." 
) - commit_hash = extract_commit_hash(resolved_config_file, commit_hash) - else: - commit_hash = getattr(config, "_commit_hash", None) - - has_remote_code = hasattr(config, "auto_map") and AutoModelForCausalLM.__name__ in config.auto_map - - has_local_code = type(config) in AutoModelForCausalLM._model_mapping.keys() - trust_remote_code = resolve_trust_remote_code( - trust_remote_code, - pretrained_model_name_or_path, - has_local_code, - has_remote_code, - ) - - if has_remote_code and trust_remote_code: # pragma: no cover - class_ref = config.auto_map[AutoModelForCausalLM.__name__] - model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs_orig) - if os.path.isdir(pretrained_model_name_or_path): - model_class.register_for_auto_class(AutoModelForCausalLM.__name__) - else: - AutoModelForCausalLM.register(config.__class__, model_class, exist_ok=True) - elif type(config) in AutoModelForCausalLM._model_mapping.keys(): - model_class = _get_model_class(config, AutoModelForCausalLM._model_mapping) - - init_contexts = [no_init_weights(_enable=_fast_init)] - init_contexts.append(init_empty_weights()) - - with ContextManagers(init_contexts): - model = model_class(config, *model_args, **kwargs) - - is_sharded = False - sharded_metadata = None - - if pretrained_model_name_or_path is not None: # pragma: no cover - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - is_local = os.path.isdir(pretrained_model_name_or_path) - if is_local: - if os.path.isfile( - os.path.join( - pretrained_model_name_or_path, - subfolder, - _add_variant(WEIGHTS_NAME, variant), - ) - ): - # Load from a PyTorch checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, - subfolder, - _add_variant(WEIGHTS_NAME, variant), - ) - elif os.path.isfile( - os.path.join( - pretrained_model_name_or_path, - subfolder, - _add_variant(WEIGHTS_INDEX_NAME, variant), - ) - ): - # Load from a sharded PyTorch checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, - subfolder, - _add_variant(WEIGHTS_INDEX_NAME, variant), + if token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + token = use_auth_token + + user_agent = { + "file_type": "model", + "framework": "pytorch", + "from_auto_class": from_auto_class, + } + if from_pipeline is not None: # pragma: no cover + user_agent["using_pipeline"] = from_pipeline + + if commit_hash is None: + if not isinstance(config, PretrainedConfig): # pragma: no cover + # We make a call to the config file first (which may be absent) + # to get the commit hash as soon as possible. 
+ resolved_config_file = cached_file( + self.model_name_or_path, + "config.json", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, ) - is_sharded = True - elif os.path.isfile( - os.path.join( - pretrained_model_name_or_path, - subfolder, - _add_variant(SAFE_WEIGHTS_NAME, variant), - ) - ): - # Load from a safetensors checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, - subfolder, - _add_variant(SAFE_WEIGHTS_NAME, variant), - ) - elif os.path.isfile( - os.path.join( - pretrained_model_name_or_path, - subfolder, - _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), - ) - ): - # Load from a safetensors checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, - subfolder, - _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), - ) - is_sharded = True - elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): - archive_file = pretrained_model_name_or_path - is_local = True - elif is_remote_url(pretrained_model_name_or_path): - filename = pretrained_model_name_or_path - resolved_archive_file = download_url(pretrained_model_name_or_path) - else: - if use_safetensors is not False: - filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) else: - filename = _add_variant(WEIGHTS_NAME, variant) - try: - # Load from URL or cache if already cached - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "resume_download": resume_download, - "local_files_only": local_files_only, - "token": token, - "user_agent": user_agent, - "revision": revision, - "subfolder": subfolder, - "_raise_exceptions_for_gated_repo": False, - "_raise_exceptions_for_missing_entries": False, - "_commit_hash": commit_hash, - } - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) - - # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None - # result when internet is up, the repo and revision exist, but the file does not. - if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - resolved_archive_file = cached_file( - pretrained_model_name_or_path, - _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), - **cached_file_kwargs, + commit_hash = getattr(config, "_commit_hash", None) + + is_sharded = False + sharded_metadata = None + + if self.model_name_or_path is not None: # pragma: no cover + self.model_name_or_path = str(self.model_name_or_path) + is_local = os.path.isdir(self.model_name_or_path) + if is_local: + if os.path.isfile( + os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(WEIGHTS_NAME, variant), ) - if resolved_archive_file is not None: - is_sharded = True - elif use_safetensors: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or " - f"{_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " - "and thus cannot be loaded with `safetensors`. Please make sure that the model has " - "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." 
- ) - else: - # This repo has no safetensors file of any kind, we switch to PyTorch. - filename = _add_variant(WEIGHTS_NAME, variant) - resolved_archive_file = cached_file( - pretrained_model_name_or_path, filename, **cached_file_kwargs - ) - if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - resolved_archive_file = cached_file( - pretrained_model_name_or_path, + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(WEIGHTS_NAME, variant), + ) + elif os.path.isfile( + os.path.join( + self.model_name_or_path, + subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant), - **cached_file_kwargs, ) - if resolved_archive_file is not None: - is_sharded = True - - if resolved_archive_file is None: - # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error - # message. - has_file_kwargs = { - "revision": revision, + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(WEIGHTS_INDEX_NAME, variant), + ) + is_sharded = True + elif os.path.isfile( + os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_NAME, variant), + ) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_NAME, variant), + ) + elif os.path.isfile( + os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + ) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + ) + is_sharded = True + elif os.path.isfile(os.path.join(subfolder, self.model_name_or_path)): + archive_file = self.model_name_or_path + is_local = True + elif is_remote_url(self.model_name_or_path): + filename = self.model_name_or_path + resolved_archive_file = download_url(self.model_name_or_path) + else: + if use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, } - if variant is not None and has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" - f" {variant}. Use `variant=None` to load this model from those weights." + resolved_archive_file = cached_file(self.model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. 
+ resolved_archive_file = cached_file( + self.model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, ) - else: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)}." + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + raise EnvironmentError( + f"{self.model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or " + f"{_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + self.model_name_or_path, filename, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + self.model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, ) - except EnvironmentError: - # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted - # to the original exception. - raise - except Exception as e: - # For any other exception, we throw a generic error. - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" - " from 'https://huggingface.co/models', make sure you don't have a local directory with the" - f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}." - ) from e - - if is_local: - logger.info(f"loading weights file {archive_file}") - resolved_archive_file = archive_file + if resolved_archive_file is not None: + is_sharded = True + + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + if variant is not None and has_file(self.model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{self.model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{self.model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}." + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{self.model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{self.model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}." 
+ ) from e + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: # pragma: no cover + resolved_archive_file = None + + if is_sharded: # pragma: no cover + # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + self.model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + self.kwargs["sharded_metadata"] = sharded_metadata + + if is_sharded: # pragma: no cover + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: - logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") - else: # pragma: no cover - resolved_archive_file = None - - if is_sharded: # pragma: no cover - # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path, + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file) + loaded_state_dict_keys = list(state_dict.keys()) + + # set kwargs for next functions to use + self.kwargs["is_sharded"] = is_sharded + self.kwargs["offload_folder"] = offload_folder + self.kwargs["offload_state_dict"] = offload_state_dict + self.kwargs["resolved_archive_file"] = resolved_archive_file + + return loaded_state_dict_keys + + def _init_hf_model(self, model_class, config): + from accelerate.big_modeling import init_empty_weights + from transformers.modeling_utils import no_init_weights + from transformers.utils import ContextManagers + + _fast_init = self.kwargs.pop("_fast_init", True) + torch_dtype = self.kwargs.pop("torch_dtype", "auto") + is_sharded = self.kwargs.pop("is_sharded", False) + sharded_metadata = self.kwargs.pop("sharded_metadata", None) + offload_folder = self.kwargs.pop("offload_folder", None) + offload_state_dict = self.kwargs.pop("offload_state_dict", False) + resolved_archive_file = self.kwargs.pop("resolved_archive_file", None) + + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + # 2. 
If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, + # by checking its first weights entry that is of a floating type + # - we assume all floating dtype weights are of the same dtype + dtype_orig = None + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if torch_dtype == "auto": + if hasattr(config, "torch_dtype") and config.torch_dtype is not None and config.torch_dtype != "auto": + torch_dtype = config.torch_dtype + else: # pragma: no cover + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + else: + torch_dtype = torch.float32 + else: # pragma: no cover + assert False, f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' + + dtype_orig = model_class._set_default_torch_dtype(torch_dtype) + + init_contexts = [no_init_weights(_enable=_fast_init)] + init_contexts.append(init_empty_weights()) + + with ContextManagers(init_contexts): + model = model_class(config, **self.kwargs) + + # set kwargs for next functions to use + self.kwargs["resolved_archive_file"] = resolved_archive_file + self.kwargs["sharded_metadata"] = sharded_metadata + self.kwargs["torch_dtype"] = torch_dtype + self.kwargs["dtype_orig"] = dtype_orig + self.kwargs["_fast_init"] = _fast_init + self.kwargs["offload_folder"] = offload_folder + self.kwargs["offload_state_dict"] = offload_state_dict + + return model + + def _load_pretrained_weight(self, model, model_class): + resolved_archive_file = self.kwargs.pop("resolved_archive_file", None) + sharded_metadata = self.kwargs.pop("sharded_metadata", None) + torch_dtype = self.kwargs.pop("torch_dtype", torch.float32) + dtype_orig = self.kwargs.pop("dtype_orig", None) + _fast_init = self.kwargs.pop("_fast_init", True) + offload_folder = self.kwargs.pop("offload_folder", None) + offload_state_dict = self.kwargs.pop("offload_state_dict", False) + + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = model_class._load_pretrained_model( + model, + None, + self.loaded_state_dict_keys, resolved_archive_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder, - _commit_hash=commit_hash, + self.model_name_or_path, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=True, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + keep_in_fp32_modules=[], ) - # set dtype to instantiate the model under: - # 1. If torch_dtype is not None, we use that dtype - # 2. 
If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, - # by checking its first weights entry that is of a floating type - # - we assume all floating dtype weights are of the same dtype - dtype_orig = None - if torch_dtype is not None: - if isinstance(torch_dtype, str): - if torch_dtype == "auto": - if hasattr(config, "torch_dtype") and config.torch_dtype is not None and config.torch_dtype != "auto": - torch_dtype = config.torch_dtype - else: # pragma: no cover - if is_sharded and "dtype" in sharded_metadata: - torch_dtype = sharded_metadata["dtype"] - else: - torch_dtype = torch.float32 - else: # pragma: no cover - assert False, f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' - - dtype_orig = model_class._set_default_torch_dtype(torch_dtype) - - if is_sharded: # pragma: no cover - loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] - else: - # Time to load the checkpoint - state_dict = load_state_dict(resolved_archive_file) - loaded_state_dict_keys = list(state_dict.keys()) - - model = _build_woq_model(model, quantization_config, loaded_state_dict_keys) - - # restore default dtype - if dtype_orig is not None: - torch.set_default_dtype(dtype_orig) - - ( - model, - missing_keys, - unexpected_keys, - mismatched_keys, - offload_index, - error_msgs, - ) = model_class._load_pretrained_model( - model, - None, - loaded_state_dict_keys, - resolved_archive_file, - pretrained_model_name_or_path, - sharded_metadata=sharded_metadata, - _fast_init=_fast_init, - low_cpu_mem_usage=True, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - dtype=torch_dtype, - keep_in_fp32_modules=[], - ) - - # make sure token embedding weights are still tied if needed - model.tie_weights() - - # Set model in evaluation mode to deactivate DropOut modules by default - model.eval() + # make sure token embedding weights are still tied if needed + model.tie_weights() - return model + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + return model diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index c5e0f4df369..f23727cb4c4 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -32,7 +32,7 @@ } -def load(model_name_or_path, original_model=None, format="default", device="cpu", *model_args, **kwargs): +def load(model_name_or_path, original_model=None, format="default", device="cpu", **kwargs): """Load quantized model. 1. Load INC quantized model in local. @@ -61,19 +61,17 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu" If 'format' is set to 'default', it means the 'checkpoint_dir'. Parameter should not be None. it coworks with 'original_model' parameter to load INC quantized model in local. - original_model (torch.nn.module, optional): original model before quantization. - Needed if 'format' is set to 'default' and not TorchScript model.Defaults to None. + original_model (torch.nn.module or TorchScript model with IPEX or fx graph with pt2e, optional): + original model before quantization. Needed if 'format' is set to 'default' and not TorchScript model. + Defaults to None. format (str, optional): 'defult' for loading INC quantized model. 'huggingface' for loading huggingface WOQ causal language model. Defaults to "default". device (str, optional): 'cpu', 'hpu' or 'cuda'. specify the device the model will be loaded to. 
- model_args (sequence of positional arguments, optional): - all remaining positional arguments for loading huggingface models. - Will be passed to the huggingface model's `__init__` method. kwargs (remaining dictionary of keyword arguments, optional): remaining dictionary of keyword arguments for loading huggingface models. Will be passed to the huggingface model's `__init__` method, such as 'trust_remote_code', 'revision'. Returns: - torch.nn.Module: quantized model + The quantized model """ # TODO: When loading WOQ model, use different WeightOnlyLinear module according to device. if format == LoadFormat.DEFAULT.value: @@ -106,6 +104,6 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu" # now only support load huggingface WOQ causal language model from neural_compressor.torch.algorithms.weight_only.save_load import load - return load(model_name_or_path, format=LoadFormat.HUGGINGFACE, *model_args, **kwargs) + return load(model_name_or_path, format=LoadFormat.HUGGINGFACE, **kwargs) else: raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format)) diff --git a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py index 887332a2ad0..c12197d211c 100644 --- a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py +++ b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py @@ -13,6 +13,12 @@ def setup_class(self): def test_load_hf_woq_model(self): from neural_compressor.torch.quantization import load - qmodel = load(model_name_or_path=self.model_name, format="huggingface") + qmodel = load(model_name_or_path=self.model_name, format="huggingface", torch_dtype=torch.float32) + + woq_linear_num = 0 + for _, module in qmodel.named_modules(): + if module.__class__.__name__ == "WeightOnlyLinear": + woq_linear_num += 1 + assert woq_linear_num == 154, "Incorrect number of WeightOnlyLinear modules" output = qmodel(self.example_inputs)[0] assert len(output) > 0, "Not loading the model correctly" From f18602caa201cfc4f1d63d8013699133aa1eec3e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 02:20:04 +0000 Subject: [PATCH 21/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/weight_only/save_load.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index aac5290a424..eaaeed8a57f 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -14,10 +14,11 @@ # pylint:disable=import-error +import copy import json import os import re -import copy + import torch from neural_compressor.common.utils import load_config_mapping, save_config_mapping @@ -89,25 +90,25 @@ def load_woq_model(self): logger.info("Quantized huggingface model loading successful.") elif self.format == LoadFormat.DEFAULT: qmodel_weight_file_path = os.path.join( - os.path.abspath(os.path.expanduser(self.model_name_or_path)), WEIGHT_NAME) + os.path.abspath(os.path.expanduser(self.model_name_or_path)), WEIGHT_NAME + ) assert os.path.exists(qmodel_weight_file_path), "Cannot load model weight from path {}".format( qmodel_weight_file_path ) - qconfig_file_path = os.path.join( - 
os.path.abspath(os.path.expanduser(self.model_name_or_path)), QCONFIG_NAME) + qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(self.model_name_or_path)), QCONFIG_NAME) assert os.path.exists(qconfig_file_path), "Cannot load model quantization config from path {}".format( qconfig_file_path ) - assert self.original_model is not None, \ - "Can't get original model. Please pass `original_model` to load function." + assert ( + self.original_model is not None + ), "Can't get original model. Please pass `original_model` to load function." model = self.load_inc_format_woq_model(qmodel_weight_file_path, qconfig_file_path) logger.info("Quantized model loading successful.") else: - raise ValueError( - f"`format` in load function can only be 'huggingface' or 'default', but get {self.format}") + raise ValueError(f"`format` in load function can only be 'huggingface' or 'default', but get {self.format}") return model @@ -135,7 +136,7 @@ def load_hf_format_woq_model(self): self.loaded_state_dict_keys = self._get_loaded_state_dict_keys(config) # initiate the huggingface model - self.original_model = self._init_hf_model(model_class, config) + self.original_model = self._init_hf_model(model_class, config) # build weight-only quantization model with WeightOnlyLinear module model = self._build_woq_model() @@ -192,7 +193,9 @@ def _build_woq_model(self): WeightOnlyLinearClass = AutoRoundWeightOnlyLinear kwargs["groupsize"] = module_quantization_config.get("group_size", 32) - kwargs["scale_dtype"] = convert_dtype_str2torch(module_quantization_config.get("scale_dtype", "fp16")) + kwargs["scale_dtype"] = convert_dtype_str2torch( + module_quantization_config.get("scale_dtype", "fp16") + ) else: from .modules import WeightOnlyLinear as INCWeightOnlyLinear @@ -301,7 +304,9 @@ def _get_loaded_state_dict_keys(self, config): "Please use `token` instead." ) if token is not None: - raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) token = use_auth_token user_agent = { @@ -446,9 +451,7 @@ def _get_loaded_state_dict_keys(self, config): else: # This repo has no safetensors file of any kind, we switch to PyTorch. filename = _add_variant(WEIGHTS_NAME, variant) - resolved_archive_file = cached_file( - self.model_name_or_path, filename, **cached_file_kwargs - ) + resolved_archive_file = cached_file(self.model_name_or_path, filename, **cached_file_kwargs) if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): # Maybe the checkpoint is sharded, we try to grab the index name in this case. 
resolved_archive_file = cached_file( @@ -554,7 +557,11 @@ def _init_hf_model(self, model_class, config): if torch_dtype is not None: if isinstance(torch_dtype, str): if torch_dtype == "auto": - if hasattr(config, "torch_dtype") and config.torch_dtype is not None and config.torch_dtype != "auto": + if ( + hasattr(config, "torch_dtype") + and config.torch_dtype is not None + and config.torch_dtype != "auto" + ): torch_dtype = config.torch_dtype else: # pragma: no cover if is_sharded and "dtype" in sharded_metadata: From a33c766e8865f009ab02fa4ccd24cedd85574113 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 12 Jun 2024 08:14:29 +0000 Subject: [PATCH 22/24] enhance code Signed-off-by: yuwenzho --- .../torch/algorithms/weight_only/__init__.py | 3 ++ .../torch/algorithms/weight_only/save_load.py | 34 +++++++-------- .../torch/quantization/load_entry.py | 43 +++++++++---------- 3 files changed, 41 insertions(+), 39 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/__init__.py b/neural_compressor/torch/algorithms/weight_only/__init__.py index 28f108cb636..fc9ef0a5b3b 100644 --- a/neural_compressor/torch/algorithms/weight_only/__init__.py +++ b/neural_compressor/torch/algorithms/weight_only/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +from .save_load import save, load diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index eaaeed8a57f..3a9ff2effa8 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -68,7 +68,7 @@ def load(model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, dev Returns: torch.nn.Module: quantized model """ - model_loader = WOQModelLoader(model_name_or_path, original_model, format, device="cpu", **kwargs) + model_loader = WOQModelLoader(model_name_or_path, original_model, format, device, **kwargs) model = model_loader.load_woq_model() return model @@ -87,26 +87,34 @@ def __init__(self, model_name_or_path, original_model=None, format=LoadFormat.DE def load_woq_model(self): if self.format == LoadFormat.HUGGINGFACE: model = self.load_hf_format_woq_model() - logger.info("Quantized huggingface model loading successful.") + logger.info("Loading HuggingFace weight-only quantization model successfully.") elif self.format == LoadFormat.DEFAULT: qmodel_weight_file_path = os.path.join( os.path.abspath(os.path.expanduser(self.model_name_or_path)), WEIGHT_NAME ) - assert os.path.exists(qmodel_weight_file_path), "Cannot load model weight from path {}".format( - qmodel_weight_file_path - ) + assert os.path.exists(qmodel_weight_file_path), \ + "Cannot load model weight from path {}. " \ + "Please make sure '{}' file is saved in your '{}' directory ".format( + qmodel_weight_file_path, + WEIGHT_NAME, + self.model_name_or_path + ) qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(self.model_name_or_path)), QCONFIG_NAME) - assert os.path.exists(qconfig_file_path), "Cannot load model quantization config from path {}".format( - qconfig_file_path - ) + assert os.path.exists(qconfig_file_path), \ + "Cannot load model quantization config from path {}. 
" \ + "Please make sure '{}' file is saved in your '{}' directory".format( + qconfig_file_path, + QCONFIG_NAME, + self.model_name_or_path + ) assert ( self.original_model is not None ), "Can't get original model. Please pass `original_model` to load function." model = self.load_inc_format_woq_model(qmodel_weight_file_path, qconfig_file_path) - logger.info("Quantized model loading successful.") + logger.info("Loading weight-only quantization model successfully.") else: raise ValueError(f"`format` in load function can only be 'huggingface' or 'default', but get {self.format}") @@ -126,7 +134,6 @@ def load_inc_format_woq_model(self, qmodel_weight_file_path, qconfig_file_path): def load_hf_format_woq_model(self): # check required package - self._check_required_packages() # get model_class and config model_class, config = self._get_model_class_and_config() @@ -217,13 +224,6 @@ def _build_woq_model(self): woq_model = self.original_model return woq_model - def _check_required_packages(self): - try: - import accelerate - import transformers - except ImportError as e: # pragma: no cover - raise e - def _get_model_class_and_config(self): from transformers import AutoConfig, AutoModelForCausalLM from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index f23727cb4c4..d20f828659d 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -36,24 +36,23 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu" """Load quantized model. 1. Load INC quantized model in local. - 2. Load HuggingFace quantized model, including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub. + case 1: WOQ + from neural_compressor.torch.quantization import load + load(model_name_or_path="saved_results", original_model=fp32_model) - case 1: WOQ - # huggingface model - from neural_compressor.torch.quantization import load - load(model_name_or_path=model_name_or_path) + case 2: INT8/FP8 + from neural_compressor.torch.quantization import load + load(model_name_or_path='saved_result', original_model=fp32_model) - # local model - from neural_compressor.torch.quantization import load - load(model_name_or_path="saved_results", original_model=fp32_model) + case 3: TorchScript (IPEX) + from neural_compressor.torch.quantization import load + load(model_name_or_path='saved_result') - case 2: INT8/FP8 - from neural_compressor.torch.quantization import load - load(model_name_or_path='saved_result', original_model=fp32_model) + 2. Load HuggingFace quantized model, including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub. + case 1: WOQ + from neural_compressor.torch.quantization import load + load(model_name_or_path=model_name_or_path) - case 3: TorchScript (IPEX) - from neural_compressor.torch.quantization import load - load(model_name_or_path='saved_result') Args: model_name_or_path (str): torch checkpoint directory or hugginface model_name_or_path. @@ -82,28 +81,28 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu" per_op_qconfig = json.load(f) if " " in per_op_qconfig.keys(): # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ... 
- from neural_compressor.torch.algorithms.static_quant import load + from neural_compressor.torch.algorithms import static_quant - return load(model_name_or_path) + return static_quant.load(model_name_or_path) else: config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) # select load function config_object = config_mapping[next(iter(config_mapping))] if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)): # WOQ - from neural_compressor.torch.algorithms.weight_only.save_load import load + from neural_compressor.torch.algorithms import weight_only - return load(model_name_or_path, original_model, format=LoadFormat.DEFAULT) + return weight_only.load(model_name_or_path, original_model, format=LoadFormat.DEFAULT) original_model.qconfig = config_mapping if isinstance(config_object, FP8Config): # FP8 - from neural_compressor.torch.algorithms.habana_fp8 import load + from neural_compressor.torch.algorithms import habana_fp8 - return load(model_name_or_path, original_model) + return habana_fp8.load(model_name_or_path, original_model) elif format == LoadFormat.HUGGINGFACE.value: # now only support load huggingface WOQ causal language model - from neural_compressor.torch.algorithms.weight_only.save_load import load + from neural_compressor.torch.algorithms import weight_only - return load(model_name_or_path, format=LoadFormat.HUGGINGFACE, **kwargs) + return weight_only.load(model_name_or_path, format=LoadFormat.HUGGINGFACE, **kwargs) else: raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format)) From 42cfe39722bad7366a66ff3701e7d3e68e0e31e8 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 12 Jun 2024 08:54:21 +0000 Subject: [PATCH 23/24] enhance code Signed-off-by: yuwenzho --- neural_compressor/torch/algorithms/weight_only/save_load.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index 3a9ff2effa8..b200988038d 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -134,6 +134,11 @@ def load_inc_format_woq_model(self, qmodel_weight_file_path, qconfig_file_path): def load_hf_format_woq_model(self): # check required package + from neural_compressor.torch.utils import is_package_available + if not is_package_available("transformers"): + raise ImportError("Loading huggingface model requires transformers: `pip install transformers`") + if not is_package_available("accelerate"): + raise ImportError("Loading huggingface model requires accelerate: `pip install accelerate`") # get model_class and config model_class, config = self._get_model_class_and_config() From 61b70cc6e806e87e151321701163d461e1ff1cc7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 08:56:13 +0000 Subject: [PATCH 24/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/weight_only/save_load.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py index b200988038d..7494dac86f9 100644 --- a/neural_compressor/torch/algorithms/weight_only/save_load.py +++ 
b/neural_compressor/torch/algorithms/weight_only/save_load.py @@ -92,22 +92,20 @@ def load_woq_model(self): qmodel_weight_file_path = os.path.join( os.path.abspath(os.path.expanduser(self.model_name_or_path)), WEIGHT_NAME ) - assert os.path.exists(qmodel_weight_file_path), \ - "Cannot load model weight from path {}. " \ + assert os.path.exists(qmodel_weight_file_path), ( + "Cannot load model weight from path {}. " "Please make sure '{}' file is saved in your '{}' directory ".format( - qmodel_weight_file_path, - WEIGHT_NAME, - self.model_name_or_path - ) + qmodel_weight_file_path, WEIGHT_NAME, self.model_name_or_path + ) + ) qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(self.model_name_or_path)), QCONFIG_NAME) - assert os.path.exists(qconfig_file_path), \ - "Cannot load model quantization config from path {}. " \ + assert os.path.exists(qconfig_file_path), ( + "Cannot load model quantization config from path {}. " "Please make sure '{}' file is saved in your '{}' directory".format( - qconfig_file_path, - QCONFIG_NAME, - self.model_name_or_path + qconfig_file_path, QCONFIG_NAME, self.model_name_or_path ) + ) assert ( self.original_model is not None @@ -135,6 +133,7 @@ def load_inc_format_woq_model(self, qmodel_weight_file_path, qconfig_file_path): def load_hf_format_woq_model(self): # check required package from neural_compressor.torch.utils import is_package_available + if not is_package_available("transformers"): raise ImportError("Loading huggingface model requires transformers: `pip install transformers`") if not is_package_available("accelerate"):