diff --git a/neural_compressor/transformers/__init__.py b/neural_compressor/transformers/__init__.py new file mode 100644 index 00000000000..7701cea89d9 --- /dev/null +++ b/neural_compressor/transformers/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .utils.quantization_config import GPTQConfig, RtnConfig diff --git a/neural_compressor/transformers/models/__init__.py b/neural_compressor/transformers/models/__init__.py new file mode 100644 index 00000000000..fcaf093c802 --- /dev/null +++ b/neural_compressor/transformers/models/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_auto import _BaseINCAutoModelClass diff --git a/neural_compressor/transformers/models/modeling_auto.py b/neural_compressor/transformers/models/modeling_auto.py new file mode 100644 index 00000000000..b200548d4cc --- /dev/null +++ b/neural_compressor/transformers/models/modeling_auto.py @@ -0,0 +1,623 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# coding=utf-8 +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
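For orientation, here is a minimal usage sketch (not part of the patch) of the loading entry point this file adds. It assumes a checkpoint that was previously quantized and saved with `save_low_bit`, so its `config.json` carries a `quantization_config` entry; the directory name and prompt are hypothetical.

```python
# Hypothetical example: reload a low-bit checkpoint produced by save_low_bit().
from transformers import AutoTokenizer

from neural_compressor.transformers.models.modeling_auto import AutoModelForCausalLM

saved_dir = "./opt-125m-rtn-int4"  # hypothetical directory written by save_low_bit
model = AutoModelForCausalLM.from_pretrained(saved_dir)  # dispatches to load_low_bit
tokenizer = AutoTokenizer.from_pretrained(saved_dir)

inputs = tokenizer("The weather today is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```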
+ +import copy +import os +import types + +from accelerate import init_empty_weights +from accelerate.utils import is_xpu_available + +from neural_compressor.adaptor.torch_utils.util import set_module +from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear +from neural_compressor.transformers import GPTQConfig, RtnConfig +from neural_compressor.transformers.quantization.utils import convert_dtype_torch2str, replace_linear, save_low_bit +from neural_compressor.utils import logger +from neural_compressor.utils.utility import CpuInfo, LazyImport + +torch = LazyImport("torch") +transformers = LazyImport("transformers") +transformers_configuration_utils = LazyImport("transformers.configuration_utils") + + +def build_woq_model(model, quantization_config): + bits = quantization_config.bits + for n, m in model.named_modules(): + if n in quantization_config.modules_to_not_convert: + continue + if isinstance(m, torch.nn.Linear): + zp = getattr( + quantization_config, + "zero_point", + not getattr(quantization_config, "sym", False), + ) + use_optimum_format = True + with init_empty_weights(): + new_module = INCWeightOnlyLinear( + m.in_features, + m.out_features, + dtype="int4" if bits == 4 else "int8", + bits=quantization_config.bits, + group_size=quantization_config.group_size, + zp=zp, + bias=m.bias is not None, + g_idx=True, + use_optimum_format=use_optimum_format, + ) + set_module(model, n, new_module) + return model + + +class _BaseINCAutoModelClass: + ORIG_MODEL = None + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + config = kwargs.pop("config", None) + if not isinstance(config, transformers_configuration_utils.PretrainedConfig): + config, _ = transformers.AutoConfig.from_pretrained( + pretrained_model_name_or_path, + return_unused_kwargs=True, + **kwargs, + ) + + if hasattr(config, "quantization_config"): + if config.quantization_config is None: + logger.warning( + "Quantization_config loading failed. If you want to load saved " + "low bit model, please check your quantizate_config.json." + ) + + else: + logger.info("quantization_config: {}".format(config.quantization_config)) + try: + model = cls.load_low_bit( + pretrained_model_name_or_path, + *model_args, + config=config, + **kwargs, + ) + logger.info("Saved low bit model loading successfully. Other input args " "will be ignored.") + return model + except Exception as e: + logger.error(e) + logger.error("Saved low bit model loading failed, please check your model.") + exit(0) + + @classmethod + def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): + """Load a low bit optimized model (including INT4, INT5 and INT8) from a saved ckpt. + + :param pretrained_model_name_or_path: str value, Path to load the optimized model ckpt. + # :param optimize_model: boolean value, Whether to further optimize the low_bit llm model. + # Default to be True. 
+ :return: a model instance + """ + from accelerate.big_modeling import init_empty_weights + from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code + from transformers.generation.configuration_utils import GenerationConfig + from transformers.modeling_utils import ( + _add_variant, + get_checkpoint_shard_files, + load_state_dict, + no_init_weights, + ) + from transformers.models.auto.auto_factory import _get_model_class + from transformers.models.auto.configuration_auto import AutoConfig + from transformers.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + ContextManagers, + cached_file, + download_url, + extract_commit_hash, + has_file, + is_remote_url, + is_safetensors_available, + ) + + # Autofactory + kwargs_orig = copy.deepcopy(kwargs) + # modules_to_not_convert = kwargs.pop("modules_to_not_convert", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + # Maybe needed when extract_local_archive_file + subfolder = kwargs.pop("subfolder", "") + variant = kwargs.pop("variant", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + torch_dtype = kwargs.pop("torch_dtype", "auto") + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + resume_download = kwargs.pop("resume_download", False) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + token = kwargs.pop("token", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + revision = kwargs.pop("revision", "main") + commit_hash = kwargs.pop("_commit_hash", None) + _fast_init = kwargs.pop("_fast_init", True) + device_map = kwargs.pop("device_map", "xpu" if is_xpu_available() else "cpu") + use_safetensors = kwargs.pop("use_safetensors", None) + kwarg_attn_imp = kwargs.pop("attn_implementation", None) + + # lm-eval device map is dictionary + device_map = device_map[""] if isinstance(device_map, dict) and "" in device_map else device_map + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + + if use_auth_token is not None: + logger.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. " + "Please use `token` instead." + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." 
+ ) + token = use_auth_token + + use_cpu = True if device_map == torch.device("cpu") or device_map == "cpu" else False + use_xpu = True if device_map == torch.device("xpu") or device_map == "xpu" else False + + user_agent = { + "file_type": "model", + "framework": "pytorch", + "from_auto_class": from_auto_class, + } + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + config = kwargs.pop("config", None) + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: + config._attn_implementation = kwarg_attn_imp + + quantization_config = config.quantization_config + + if quantization_config["quant_method"] == "rtn": + quantization_config = RtnConfig.from_dict(quantization_config) + elif quantization_config["quant_method"] == "gptq": + quantization_config = GPTQConfig.from_dict(quantization_config) + + assert quantization_config is not None, "Detect this model is not a low-bit model." + + if commit_hash is None: + if not isinstance(config, transformers_configuration_utils.PretrainedConfig): + # We make a call to the config file first (which may be absent) + # to get the commit hash as soon as possible. + resolved_config_file = cached_file( + pretrained_model_name_or_path, + "config.json", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + else: + commit_hash = getattr(config, "_commit_hash", None) + + has_remote_code = hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map + + has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys() + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, + pretrained_model_name_or_path, + has_local_code, + has_remote_code, + ) + if has_remote_code and trust_remote_code: + class_ref = config.auto_map[cls.ORIG_MODEL.__name__] + model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs_orig) + if os.path.isdir(pretrained_model_name_or_path): + model_class.register_for_auto_class(cls.ORIG_MODEL.__name__) + else: + cls.ORIG_MODEL.register(config.__class__, model_class, exist_ok=True) + elif type(config) in cls.ORIG_MODEL._model_mapping.keys(): + model_class = _get_model_class(config, cls.ORIG_MODEL._model_mapping) + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. 
+ is_sharded = False + sharded_metadata = None + + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(WEIGHTS_NAME, variant), + ) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(WEIGHTS_NAME, variant), + ) + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(WEIGHTS_INDEX_NAME, variant), + ) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(WEIGHTS_INDEX_NAME, variant), + ) + is_sharded = True + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_NAME, variant), + ) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_NAME, variant), + ) + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + ) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + ) + is_sharded = True + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + if use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or " + f"{_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." 
+ ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + filename, + **cached_file_kwargs, + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + if variant is not None and has_file( + pretrained_model_name_or_path, + WEIGHTS_NAME, + **has_file_kwargs, + ): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}." + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}." + ) from e + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: + resolved_archive_file = None + + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + # 2. 
If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, + # by checking its first weights entry that is of a floating type + # - we assume all floating dtype weights are of the same dtype + # we also may have config.torch_dtype available, but we won't rely on it till v5 + # Pretrained Model + + dtype_orig = None + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if torch_dtype == "auto": + if ( + hasattr(config, "torch_dtype") + and config.torch_dtype is not None + and config.torch_dtype != "auto" + ): + torch_dtype = config.torch_dtype + else: + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + else: + torch_dtype = torch.float32 + else: + assert False, f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' + + dtype_orig = model_class._set_default_torch_dtype(torch_dtype) + if quantization_config.compute_dtype is None: + if use_xpu: + quantization_config.compute_dtype = ( + "fp16" + if (torch_dtype is None or torch_dtype == torch.bfloat16) + else convert_dtype_torch2str(torch_dtype) + ) + else: + quantization_config.compute_dtype = ( + "fp32" + if ( + torch_dtype is None + or (not CpuInfo().bf16 and torch_dtype == torch.bfloat16) + or (torch_dtype == torch.float16) + ) + else convert_dtype_torch2str(torch_dtype) + ) + else: + if (not CpuInfo().bf16 and quantization_config.compute_dtype == "bf16") or ( + use_cpu and quantization_config.compute_dtype == "fp16" + ): + quantization_config.compute_dtype = "fp32" + + if quantization_config.scale_dtype is None: + quantization_config.scale_dtype = "fp32" + if quantization_config.scale_dtype not in ["fp32", "fp16", "bf16"]: + logger.warning("scale_dtype only supports fp32, bf16, fp16.") + quantization_config.scale_dtype = "fp32" + logger.warning("fp32 scale_dtype is used, please change the config.json if you don't want to use it.") + + # weight dtype is higher priority than bits in config.json when both existed. + if quantization_config.bits == 4: + if use_xpu: + quantization_config.weight_dtype = "int4_fullrange" + else: + quantization_config.weight_dtype = "int4" + logger.info( + "{} quantization weight_dtype is used due to bits is 4 in config.json.".format( + quantization_config.weight_dtype + ) + ) + elif quantization_config.bits == 8: + quantization_config.weight_dtype = "int8" + logger.info( + "{} quantization weight_dtype is used due to bits is 8 in config.json.".format( + quantization_config.weight_dtype + ) + ) + else: + logger.warning("bits number only supports 4, 8.") + quantization_config.weight_dtype = "int4" + logger.warning("int4 weight_dtype is used, please change the config.json if you don't want to use it.") + + init_contexts = [no_init_weights(_enable=_fast_init)] + init_contexts.append(init_empty_weights()) + + with ContextManagers(init_contexts): + model = model_class(config, *model_args, **kwargs) + + model = build_woq_model(model, quantization_config) + + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file) + loaded_state_dict_keys = list(state_dict.keys()) + + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = model_class._load_pretrained_model( + model, + None, + loaded_state_dict_keys, # XXX: rename? 
+ resolved_archive_file, + pretrained_model_name_or_path, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=True, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + keep_in_fp32_modules=[], + ) + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + model = replace_linear( + model, + quantization_config=quantization_config, + device="cpu" if device_map == "auto" else device_map, + empty_weights=True, + ) + + if (not use_xpu and torch_dtype == torch.float16) or ( + not use_xpu and not CpuInfo().bf16 and torch_dtype == torch.bfloat16 + ): + model.to(dtype=torch.float32) + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except (OSError, TypeError): + pass + for param in model.parameters(): + param.requires_grad_(False) + if device_map == "xpu": + model = model.to("xpu") + model.quantization_config = quantization_config + model.save_pretrained = types.MethodType(save_low_bit, model) + return model + + +class AutoModelForCausalLM(_BaseINCAutoModelClass): + ORIG_MODEL = transformers.AutoModelForCausalLM + + +class AutoModel(_BaseINCAutoModelClass): + ORIG_MODEL = transformers.AutoModel + + +class AutoModelForSeq2SeqLM(_BaseINCAutoModelClass): + ORIG_MODEL = transformers.AutoModelForSeq2SeqLM diff --git a/neural_compressor/transformers/quantization/__init__.py b/neural_compressor/transformers/quantization/__init__.py new file mode 100644 index 00000000000..5dd8b2769f6 --- /dev/null +++ b/neural_compressor/transformers/quantization/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .utils import convert_to_quantized_model, save_low_bit diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py new file mode 100644 index 00000000000..d4739a38562 --- /dev/null +++ b/neural_compressor/transformers/quantization/utils.py @@ -0,0 +1,529 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Intel Neural Compressor model convert.""" + +import gc +import json +import logging +import math +import os +import types + +from datasets import load_dataset + +from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear +from neural_compressor.torch.quantization import GPTQConfig, RTNConfig, convert, prepare +from neural_compressor.torch.utils import is_ipex_available +from neural_compressor.utils.utility import CpuInfo, LazyImport + +if is_ipex_available(): + import intel_extension_for_pytorch as ipex + +from typing import Union + +torch = LazyImport("torch") + + +logger = logging.getLogger(__name__) + + +def convert_dtype_str2torch(str_dtype): + if str_dtype == "int8": + return torch.int8 + elif str_dtype == "fp32" or str_dtype == "auto": + return torch.float + elif str_dtype == "fp16": + return torch.float16 + elif str_dtype == "bf16": + return torch.bfloat16 + else: + assert False, "Unsupported str dtype {} to torch dtype".format(str_dtype) + + +def convert_dtype_torch2str(dtype): + if dtype == torch.int8: + return "int8" + elif dtype == torch.float: + return "fp32" + elif dtype == torch.float16: + return "fp16" + elif dtype == torch.bfloat16: + return "bf16" + elif isinstance(dtype, str) and dtype in ["int8", "fp32", "fp16", "bf16"]: + return dtype + else: + assert False, "Unsupported pytorch dtype {} to str dtype".format(dtype) + + +def replace_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + device="cpu", + empty_weights=False, +): + if modules_to_not_convert is None: + # output_layer is chatglm last layer name + # embed_out is dolly_v2 last layer name + modules_to_not_convert = [] + if quantization_config.modules_to_not_convert: + modules_to_not_convert.extend(quantization_config.modules_to_not_convert) + modules_to_not_convert = list(set(modules_to_not_convert)) + model, is_replaced = _replace_linear( + model, + modules_to_not_convert, + current_key_name, + quantization_config, + device=device, + empty_weights=empty_weights, + ) + + if not is_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + return model + + +def _replace_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + is_replaced=False, + device="cpu", + empty_weights=False, +): + """Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfully or not. 
+ """ + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + is_removed = False + if ( + isinstance(module, torch.nn.Linear) + or isinstance(module, INCWeightOnlyLinear) + or (is_ipex_available() and isinstance(module, ipex.nn.utils._weight_prepack._IPEXLinear)) + ) and (name not in modules_to_not_convert): + # Check if the current key is not in the `modules_to_not_convert` + if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): + in_features = module.in_features + out_features = module.out_features + if device == "cpu" or device == torch.device("cpu") or device == "auto": + from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_linear + from intel_extension_for_pytorch.utils.weight_only_quantization import ( + _convert_optimum_format_to_desired, + ) + + qweight = module.qweight + scales = module.scales + qzeros = module.qzeros + + qweight, scales, qzeros = _convert_optimum_format_to_desired(qweight, scales, qzeros) + weight_dtype = { + 4: ipex.quantization.WoqWeightDtype.INT4, + 8: ipex.quantization.WoqWeightDtype.INT8, + } + compute_dtype = { + "fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype. + "bf16": ipex.quantization.WoqLowpMode.BF16, + "fp16": ipex.quantization.WoqLowpMode.FP16, + "int8": ipex.quantization.WoqLowpMode.INT8, + } + + ipex_qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype[quantization_config.bits], + lowp_mode=compute_dtype[quantization_config.compute_dtype], + act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + group_size=quantization_config.group_size, + ) + tmp_linear = torch.nn.Linear( + in_features, + out_features, + True if hasattr(module, "bias") and module.bias is not None else False, + ) + if tmp_linear.bias is not None and module.bias is not None: + tmp_linear.bias = torch.nn.Parameter(module.bias.float()) + + tmp_linear.qconfig = ipex_qconfig_mapping.global_qconfig + model._modules[name] = ipex_linear.from_float_and_int4_weight( + mod=tmp_linear, + qweight=qweight, + scales=scales, + zero_points=qzeros, + bias=(module.bias.float() if hasattr(module, "bias") and module.bias is not None else None), + group_size=quantization_config.group_size, + g_idx=(module.g_idx if hasattr(module, "g_idx") else None), + ) + + elif device == "xpu" or device == torch.device("xpu"): + from intel_extension_for_pytorch.nn.utils._quantize_convert import ( + WeightOnlyQuantizedLinear as ipex_linear, # pylint: disable=E0401 + ) + + model._modules[name] = ipex_linear( + in_features, + out_features, + module.bias is not None, + compute_dtype=quantization_config.compute_dtype, + compress_statistics=False, + weight_dtype=quantization_config.weight_dtype, + scale_dtype=quantization_config.scale_dtype, + blocksize=quantization_config.group_size, + scheme=quantization_config.scheme, + compression_dtype=getattr(module, "compression_dtype", torch.int32), + compression_dim=getattr(module, "compression_dim", 1), + device=device, + use_optimum_format=getattr(module, "use_optimum_format", True), + ) + if quantization_config.quant_method.value == "gptq": + g_idx = getattr( + module, + "g_idx", + torch.zeros(in_features, dtype=torch.int32).to(device), + ) + else: + g_idx = None + model._modules[name].set_scales_zps_gidx( + ( + module.scales + if hasattr(module, "scales") + else torch.ones( + ( + math.ceil(in_features / quantization_config.group_size), + out_features, + 
), + dtype=convert_dtype_str2torch(quantization_config.compute_dtype), + device=torch.device(device), + ) + ), + module.qzeros if hasattr(module, "qzeros") else None, + g_idx, + ) + else: + raise Exception("{} device Unsupported weight only quantization!".format(device)) + + is_replaced = True + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + + if device == "xpu" or device == torch.device("xpu"): + if not hasattr(module, "qweight"): + n_pack = 32 // quantization_config.bits + + weight = torch.zeros( + (math.ceil(in_features / n_pack), out_features), + dtype=torch.int32, + device=torch.device(device), + ) + model._modules[name].set_weights_bias( + module.qweight.data if hasattr(module, "qweight") else weight, + None if module.bias is None else module.bias.data, + ) + del module + gc.collect() + is_removed = True + + if not is_removed and len(list(module.children())) > 0: # pylint: disable=E1101 + _, is_replaced = _replace_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + is_replaced=is_replaced, + device=device, + empty_weights=empty_weights, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, is_replaced + + +def default_run_fn(model, tokenizer, dataset, max_length=512, n_samples=100, batch_size=8, algo="rtn"): + from torch.utils.data import DataLoader + + if isinstance(dataset, (str, bytes, os.PathLike)): + calib_dataset = load_dataset(dataset, split="train") + calib_dataset = calib_dataset.shuffle(seed=42) + if tokenizer is None: + logger.error("Please provide the tokenizer in quantization_config.") + exit(0) + + def tokenize_function(examples): + if "prompt" in examples: + example = tokenizer(examples["prompt"]) + elif "code" in examples: + example = tokenizer(examples["code"]) + elif "text" in examples: + example = tokenizer(examples["text"]) + else: + logger.error( + "Please check dataset prompt identifier," + " NeelNanda/pile-10k is default used calibration dataset." + ) + exit(0) + return example + + tokenized_dataset = calib_dataset.map(tokenize_function, batched=True) + tokenized_dataset.set_format(type="torch", columns=["input_ids"]) + tokenized_dataset = tokenized_dataset.filter(lambda x: x["input_ids"].shape[-1] >= max_length) + + def collate_batch(batch): + input_ids_padded = [] + for text in batch: + input_ids = text["input_ids"] + if len(input_ids) >= max_length: + input_ids = input_ids[:max_length] + input_ids_padded.append(input_ids) + else: + continue + assert ( + input_ids_padded != [] + ), "The dataset does not have data that meets the required input length. Please reduce seq_len." + return torch.vstack(input_ids_padded) + + calib_dataloader = DataLoader( + tokenized_dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=collate_batch, + ) + total_cnt = 0 + for i, (input_ids) in enumerate(calib_dataloader): + if total_cnt + input_ids.shape[0] > n_samples: + input_ids = input_ids[: n_samples - total_cnt, ...] + total_cnt += input_ids.shape[0] + if total_cnt >= n_samples: + break + + try: + model( + input_ids=input_ids, + ) + except ValueError: + pass + + +def convert_to_quantized_model(model, config, device="cpu"): + if device == "xpu" or device == torch.device("xpu"): + import intel_extension_for_pytorch + + assert hasattr(torch, "xpu") and torch.xpu.is_available(), "There is no xpu device in this system!" 
+        os.environ["FORCE_DEVICE"] = "cpu"
+        logger.info(
+            "Set the environment variable FORCE_DEVICE='cpu' to ensure the quantization process occurs on the CPU."
+        )
+
+    orig_dtype = torch.float32
+    for param in model.parameters():
+        orig_dtype = param.dtype
+        if orig_dtype != torch.float32:
+            model.to(dtype=torch.float32)
+        break
+
+    # mapping to INC config
+    dtype = "int4" if config.weight_dtype == "int4_fullrange" else config.weight_dtype
+    if config.quant_method.value == "rtn":
+        quant_config = RTNConfig(dtype=dtype, bits=config.bits, use_sym=config.sym, group_size=config.group_size)
+        if config.use_layer_wise:
+            quant_config.use_layer_wise = config.use_layer_wise
+            quant_config.model_path = config.model_path
+        if config.modules_to_not_convert != []:
+            for module in config.modules_to_not_convert:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, RTNConfig(dtype="fp32"))
+        logger.info(f"Do RTN algorithm with config {quant_config}")
+        model = prepare(model, quant_config)
+        model = convert(model)
+    elif config.quant_method.value == "gptq":
+        model.seqlen = config.seq_len
+        quant_config = GPTQConfig(
+            dtype=dtype,
+            bits=config.bits,
+            use_sym=config.sym,
+            group_size=config.group_size,
+            use_layer_wise=config.use_layer_wise,
+            act_order=config.desc_act,
+            percdamp=config.damp_percent,
+            block_size=config.blocksize,
+            static_groups=config.static_groups,
+            use_mse_search=config.use_mse_search,
+            true_sequential=config.true_sequential,
+        )
+        if config.use_layer_wise:
+            quant_config.use_layer_wise = config.use_layer_wise
+            quant_config.model_path = config.model_path
+        if config.modules_to_not_convert != []:
+            for module in config.modules_to_not_convert:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, GPTQConfig(dtype="fp32"))
+        logger.info(f"Do GPTQ algorithm with config {quant_config}")
+        run_fn = default_run_fn
+        run_args = (
+            config.tokenizer,
+            config.dataset,
+            config.seq_len,  # max_length
+            config.n_samples,  # n_samples
+            config.batch_size,  # batch_size
+            config.quant_method.value,  # algo
+        )
+        model = prepare(model=model, quant_config=quant_config)
+        run_fn(model, *run_args)
+        model = convert(model)
+    else:
+        assert False, "Only the RTN and GPTQ algorithms are supported."
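To show how the RTN/GPTQ dispatch above is exercised end to end, a hedged sketch follows. It assumes intel_extension_for_pytorch is installed, since the CPU path of `replace_linear` packs weights into its `WeightOnlyQuantizedLinear`; the model id is hypothetical, and `compute_dtype`/`scale_dtype` are passed explicitly because the defaults are otherwise only filled in by the loading path in `modeling_auto.py`.

```python
# Hypothetical example: drive the RTN branch of convert_to_quantized_model on CPU.
from transformers import AutoModelForCausalLM as HFAutoModelForCausalLM

from neural_compressor.transformers import RtnConfig
from neural_compressor.transformers.quantization import convert_to_quantized_model

float_model = HFAutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # hypothetical model id
woq_config = RtnConfig(bits=4, group_size=32, compute_dtype="fp32", scale_dtype="fp32")
q_model = convert_to_quantized_model(float_model, woq_config, device="cpu")
```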
+ + if device == "xpu" or device == torch.device("xpu"): + logger.warning("The recommended ipex version is higher than 2.3.10 for xpu device.") + + model.eval() + + q_model = replace_linear(model, None, None, config, device=device) + + if orig_dtype != torch.float32: + q_model.to(dtype=orig_dtype) + + return q_model.to(device) + + +def convert_to_GPTQ_checkpoints(model, quantization_config): + from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_cpu_linear + + from neural_compressor.adaptor.torch_utils.util import set_module + from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear + + dtype = "int4" if quantization_config.bits == 4 else "int8" + bits = quantization_config.bits + group_size = quantization_config.group_size + zp = False if quantization_config.sym else True + scale_dtype = quantization_config.scale_dtype + desc_act = (True if hasattr(quantization_config, "desc_act") else False,) + + for name, module in model.named_modules(): + if isinstance(module, ipex_cpu_linear): + in_features = module.in_features + out_features = module.out_features + new_module = INCWeightOnlyLinear( + in_features, + out_features, + dtype=dtype, + bits=bits, + group_size=group_size, + zp=zp, + bias=True if hasattr(module, "bias") else False, + scale_dtype=scale_dtype, + g_idx=desc_act, + use_optimum_format=True, + ) + + new_module.bits = 8 + new_module.n_pack = 32 // 8 + qweight = ( + new_module.pack_tensor_with_numpy(module._op_context.to_public(module._op_context.get_weight())) + .t() + .contiguous() + ) + new_module.bits = bits + new_module.n_pack = 32 // bits + scales = module._op_context.get_scales().t().contiguous() + bias = module._op_context.get_bias() + qzeros = new_module.pack_tensor_with_numpy( + module._op_context.get_zero_points().t().to(torch.uint8) - 1 + ).contiguous() + g_idx = module._op_context.get_g_idx() + + new_module.qweight = qweight + new_module.scales = scales + new_module.qzeros = qzeros + if g_idx is not None: + new_module.g_idx = g_idx.contiguous() + if bias is not None: + new_module.bias = bias.contiguous() + + set_module(model, name, new_module) + return model + + +def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + + assert hasattr(self, "quantization_config"), "Detected this model is not a low-bit model." 
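Continuing the sketch, the quantized model can be persisted with the `save_low_bit` helper introduced here. Binding `quantization_config` and `save_pretrained` by hand mimics what `load_low_bit` normally does, and the output directory is hypothetical.

```python
# Hypothetical example: attach save_low_bit to a quantized model and persist it.
import types

from neural_compressor.transformers.quantization import save_low_bit

q_model.quantization_config = woq_config  # config object from the quantization step above
q_model.save_pretrained = types.MethodType(save_low_bit, q_model)
q_model.save_pretrained("./opt-125m-rtn-int4")
# Expected directory contents (per the implementation):
#   config.json + model weights   from transformers' save_pretrained
#   quantize_config.json          from quantization_config.save_pretrained
#   all_checkpoint_keys.json      state_dict keys kept for low-memory reload
```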
+ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + # use transformers original `save_pretrained` function + del self.save_pretrained + + if self.device == "cpu" or self.device == torch.device("cpu"): + convert_to_GPTQ_checkpoints(self, self.quantization_config) + if self.device == "xpu" or (isinstance(self.device, torch.device) and self.device.type == "xpu"): + from intel_extension_for_pytorch.nn.utils._quantize_convert import WeightOnlyQuantizedLinear + + for name, module in self.named_modules(): + if isinstance(module, WeightOnlyQuantizedLinear): + if module.weight_transposed: + module.qweight.data = module.qweight.t_().contiguous() + module.scales.data = module.scales.t_().contiguous() + module.weight_transposed = False + + self.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + self.save_pretrained = types.MethodType(save_low_bit, self) + # We conveniently save all the keys of the model to have them on hand, + # so that when using 'low_cpumem load', + # it's not necessary to load the entire model to extract its keys + # and we can avoid gc not triggered potentially. + all_checkpoint_keys = {"all_checkpoint_keys": list(self.state_dict().keys())} + json_file_path = os.path.join(save_directory, "all_checkpoint_keys.json") + with open(json_file_path, "w") as json_file: + json.dump(all_checkpoint_keys, json_file) + if push_to_hub: + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + logger.warning.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + + token = use_auth_token + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + + if token is not None: + kwargs["token"] = token + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + self.quantization_config.save_pretrained(save_directory, **kwargs) diff --git a/neural_compressor/transformers/utils/__init__.py b/neural_compressor/transformers/utils/__init__.py new file mode 100644 index 00000000000..0370d3c0a4e --- /dev/null +++ b/neural_compressor/transformers/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""initialization.""" diff --git a/neural_compressor/transformers/utils/quantization_config.py b/neural_compressor/transformers/utils/quantization_config.py new file mode 100644 index 00000000000..a9769512b40 --- /dev/null +++ b/neural_compressor/transformers/utils/quantization_config.py @@ -0,0 +1,396 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Intel Neural Compressor Transformers-like Config.""" + +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +from neural_compressor.utils import logger +from neural_compressor.utils.utility import LazyImport + +torch = LazyImport("torch") +transformers = LazyImport("transformers") + +QUANT_CONFIG = "quantize_config.json" + +if transformers.__version__ >= "4.32.0": + from transformers.utils.quantization_config import QuantizationConfigMixin + + QuantizationConfig = QuantizationConfigMixin +else: + from transformers import PretrainedConfig + + QuantizationConfig = PretrainedConfig +from enum import Enum + + +class QuantizationMethod(str, Enum): + GPTQ = "gptq" + RTN = "rtn" + + +class INCQuantizationConfigMixin(QuantizationConfig): + """Mixin class for quantization config.""" + + def update(self, **kwargs): + """Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + def post_init_cpu(self): + r"""Safety checker that arguments are correct.""" + + if self.compute_dtype is not None and self.compute_dtype not in [ + "fp32", + "bf16", + "int8", + ]: + raise ValueError("compute_dtype must be 'fp32', 'bf16', 'int8'.") + elif self.compute_dtype is None: + self.compute_dtype = "fp32" + + if self.bits is None: + self.bits = 4 + elif self.bits is not None and self.bits not in [4, 8]: + raise ValueError(f"Only support quantization to [4, 8] bits but found {self.bits}") + + if self.scale_dtype is not None and self.scale_dtype not in ["fp32", "bf16", "fp16"]: + raise ValueError("scale_dtype must be a string in 'fp32', 'bf16' ") + elif self.scale_dtype is None: + self.scale_dtype = "fp32" + + if not isinstance(self.group_size, int): + raise ValueError("group_size must be a int") + + if not isinstance(self.scheme, str): + raise ValueError("scheme must be a string") + + def post_init_xpu(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. 
+ """ + + if self.compute_dtype is not None and self.compute_dtype not in ["fp16"]: + raise ValueError("compute_dtype must be 'fp16'.") + elif self.compute_dtype is None: + self.compute_dtype = "fp16" + + if self.bits is None: + self.bits = 4 + elif self.bits not in [4]: + raise ValueError(f"Only support quantization to [4] bits but found {self.bits}") + + if self.weight_dtype is None: + self.weight_dtype = "int4_fullrange" + elif self.weight_dtype == "int4": + self.weight_dtype = "int4_fullrange" + elif self.weight_dtype not in [ + "int4_fullrange", + ]: + raise ValueError(f"weight_dtype must be a string in 'int4_fullrange', but get {self.weight_dtype}.") + + if self.scale_dtype is not None and self.scale_dtype not in ["fp16"]: + raise ValueError("scale_dtype must be a string in 'fp16'") + elif self.scale_dtype is None: + self.scale_dtype = "fp16" + + if not isinstance(self.group_size, int): + raise ValueError("group_size must be a int") + + if self.scheme not in ["sym"]: + raise ValueError("scheme: {} is not support, only support 'sym' now!".format(self.scheme)) + + def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True): + """Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + """ + # set tokenizer to None due to it doesn't support write to json + if hasattr(self, "tokenizer"): + self.tokenizer = None + if hasattr(self, "calib_dataloader"): + self.calib_dataloader = None + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string(use_diff=use_diff)) + + def remove_redundant_parameters(self): + remove_parameters = [ + "calib_dataloader", + "dataset", + "calib_func", + "calib_iters", + "calib_len", + "double_quant_scale_dtype", + "use_double_quant", + "mse_range", + "scheme", + "tokenizer", + "use_ggml", + "use_neural_speed", + "use_quant", + "layer_wise", + "blocksize", + "nsamples", + "max_input_length", + "static_groups", + "lr", + "minmax_lr", + "iters", + "use_quant_input", + "device", + "calib_dataset", + "calib_pad_val", + "calib_shuffle", + "calib_padding", + "example_inputs", + "excluded_precisions", + "op_name_dict", + "op_type_dict", + "train_dataloader", + "train_func", + "train_iters", + "train_len", + "train_padding", + "train_dataset", + "train_pad_val", + "train_shuffle", + "train_batch_size", + ] + for parameter in remove_parameters: + if hasattr(self, parameter): + delattr(self, parameter) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + push_to_hub: bool = False, + **kwargs, + ): + """Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~PretrainedConfig.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
+        """
+        if os.path.isfile(save_directory):
+            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)
+
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_config_file = os.path.join(save_directory, QUANT_CONFIG)
+
+        self.to_json_file(output_config_file, use_diff=False)
+        logger.info(f"Configuration saved in {output_config_file}")
+
+        if push_to_hub:
+            self._upload_modified_files(
+                save_directory,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=kwargs.get("token", None),
+            )
+
+    @classmethod
+    def get_config_dict(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        cf = kwargs.pop("_configuration_file", QUANT_CONFIG)
+        return super().get_config_dict(pretrained_model_name_or_path, _configuration_file=cf, **kwargs)
+
+
+class RtnConfig(INCQuantizationConfigMixin):
+    def __init__(
+        self,
+        bits: int = 4,
+        group_size: int = 32,
+        compute_dtype: Any = None,
+        scale_dtype: Any = None,
+        sym: bool = True,
+        use_layer_wise: bool = False,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.RTN
+        self.bits = bits
+        self.compute_dtype = compute_dtype
+        self.weight_dtype = "int4" if self.bits == 4 else "int8"
+        self.scale_dtype = scale_dtype
+        self.group_size = group_size
+        self.use_layer_wise = use_layer_wise
+        self.sym = sym
+        self.scheme = "sym" if self.sym else "asym"
+
+        # "transformer.output_layer" for chatglm series model.
+        # "embed_out" for dolly v2 series model.
+        self.modules_to_not_convert = kwargs.get(
+            "modules_to_not_convert", ["lm_head", "transformer.output_layer", "embed_out"]
+        )
+        self.device = kwargs.get("device", "auto")
+        if self.use_layer_wise:
+            self.model_path = kwargs.get("model_path", None)
+            if self.model_path is None:
+                raise AssertionError(
+                    "model_path is required when use_layer_wise is enabled for weight-only quantization."
+                )
+
+    def to_diff_dict(self) -> Dict[str, Any]:
+        """Removes all attributes from config which correspond to the default config attributes
+        for better readability and serializes to a Python dictionary.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        config_dict = self.to_dict()
+
+        # get the default config dict
+        default_config_dict = RtnConfig().to_dict()
+
+        serializable_config_dict = {}
+
+        # only serialize values that differ from the default config
+        for key, value in config_dict.items():
+            if value != default_config_dict[key]:
+                serializable_config_dict[key] = value
+
+        return serializable_config_dict
+
+
+class GPTQConfig(INCQuantizationConfigMixin):
+    def __init__(
+        self,
+        bits: int = 4,
+        tokenizer: Any = None,
+        dataset: str = "NeelNanda/pile-10k",
+        batch_size: int = 8,
+        group_size: int = 32,
+        compute_dtype: Any = None,
+        scale_dtype: Any = None,
+        sym: bool = True,
+        blocksize: int = 128,
+        damp_percent: float = 0.1,
+        desc_act: bool = False,
+        n_samples: int = 128,
+        seq_len: int = 2048,
+        static_groups: bool = False,
+        use_mse_search: bool = False,
+        true_sequential: bool = False,
+        use_layer_wise: bool = False,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.GPTQ
+        self.bits = bits
+        self.tokenizer = tokenizer
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.compute_dtype = compute_dtype
+        self.weight_dtype = "int4" if self.bits == 4 else "int8"
+        self.scale_dtype = scale_dtype
+        self.sym = sym
+        self.blocksize = blocksize
+        self.n_samples = n_samples
+        self.group_size = group_size
+        self.damp_percent = damp_percent
+        self.desc_act = desc_act
+        self.static_groups = static_groups
+        self.use_mse_search = use_mse_search
+        self.true_sequential = true_sequential
+        self.use_layer_wise = use_layer_wise
+        self.seq_len = seq_len
+        self.modules_to_not_convert = kwargs.get(
+            "modules_to_not_convert", ["lm_head", "transformer.output_layer", "embed_out"]
+        )
+        self.device = kwargs.get("device", "auto")
+        self.scheme = "sym" if self.sym else "asym"
+        if self.use_layer_wise:
+            self.model_path = kwargs.get("model_path", None)
+            if self.model_path is None:
+                raise AssertionError(
+                    "model_path is required when use_layer_wise is enabled for weight-only quantization."
+                )
+
+        self.post_init_gptq()
+
+    def post_init_gptq(self):
+        r"""Safety checker that arguments are correct."""
+
+        if self.bits not in [4, 8]:
+            raise ValueError(f"Only support quantization to [4, 8] bits but found {self.bits}")
+
+        if not (0 < self.damp_percent < 1):
+            raise ValueError("damp_percent must be between 0 and 1.")
+
+    def to_diff_dict(self) -> Dict[str, Any]:
+        """Removes all attributes from config which correspond to the default config attributes
+        for better readability and serializes to a Python dictionary.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        config_dict = self.to_dict()
+
+        # get the default config dict
+        default_config_dict = GPTQConfig().to_dict()
+
+        serializable_config_dict = {}
+
+        # only serialize values that differ from the default config
+        for key, value in config_dict.items():
+            if value != default_config_dict[key]:
+                serializable_config_dict[key] = value
+
+        return serializable_config_dict
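Finally, a short sketch of the config classes in isolation: constructing a `GPTQConfig` and serializing it through the mixin's `save_pretrained`/`to_diff_dict`. The model id and output directory are hypothetical.

```python
# Hypothetical example: build a GPTQConfig and serialize it to quantize_config.json.
from transformers import AutoTokenizer

from neural_compressor.transformers import GPTQConfig

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # hypothetical model id
gptq_config = GPTQConfig(
    bits=4,
    group_size=128,
    desc_act=True,
    damp_percent=0.05,
    tokenizer=tokenizer,
    dataset="NeelNanda/pile-10k",
    seq_len=512,
)
gptq_config.save_pretrained("./gptq-cfg")  # writes ./gptq-cfg/quantize_config.json
print(gptq_config.to_diff_dict())          # only fields that differ from the defaults
```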