diff --git a/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh b/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh
index 49bafea8bd0..458292afa8d 100644
--- a/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh
+++ b/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh
@@ -13,6 +13,7 @@ echo "##[section]import check pass"
 echo "##[group]set up UT env..."
 export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH
 pip install -r /neural-compressor/test/3x/torch/requirements.txt
+pip install torch==2.5.1 torchvision==0.20.1 # For auto-round
 pip install pytest-cov
 pip install pytest-html
 echo "##[endgroup]"
diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py
index 5d0829f5161..1a52f3bb7b7 100644
--- a/neural_compressor/torch/algorithms/weight_only/autoround.py
+++ b/neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -84,7 +84,7 @@ def __init__(
         enable_torch_compile: bool = None,
         # mllm
         is_mllm: bool = False,
-        quant_nontext_module: Union[str, list] = None,
+        quant_nontext_module: bool = False,
         extra_data_dir: str = None,
         image_processor=None,
         processor=None,
@@ -150,7 +150,7 @@ def __init__(
             act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
             enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
             enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True.
-            quant_nontext_module (Union[str, list]): Whether to quantize nontext module.
+            quant_nontext_module (bool): Whether to quantize nontext module.
             is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
             extra_data_dir (str): The path for extra data such as images, audio or videos.
             processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
@@ -383,7 +383,9 @@
         template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor
     )
     dataset = template.default_dataset if dataset is None else dataset
-    if quant_nontext_module or (dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer)):
+    if quant_nontext_module or (
+        dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer, "cpu", template.model_type)
+    ):
         if quant_nontext_module:
             logger.warning(
                 "Quantitative nontext module is not supported for plain text datasets,"
@@ -399,7 +401,7 @@
         truncation = False
         gradient_accumulate_steps = batch_size * gradient_accumulate_steps
         batch_size = 1
-
+    seed = 42 # The seed is fixed to 42 in transformers
     seqlen = 2048 if seqlen is None else seqlen # set text only calibration default args
     truncation = True if truncation is None else truncation
     dataset = dataset.replace(" ", "")
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 534e848ff6c..705f66d509b 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -950,7 +950,7 @@ def __init__(
         enable_torch_compile: bool = None,
         # mllm
         is_mllm: bool = False,
-        quant_nontext_module: Union[str, list] = None,
+        quant_nontext_module: bool = False,
         extra_data_dir: str = None,
         processor=None,
         image_processor=None,
@@ -994,7 +994,7 @@
             export_format (str, optional): The format used for exporting the quantized model. Defaults to "itrex".
             enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
             enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True.
-            quant_nontext_module (Union[str, list]): Whether to quantize nontext module.
+            quant_nontext_module (bool): Whether to quantize nontext module.
             extra_data_dir (str): The path for extra data such as images, audio or videos.
             is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
             processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
diff --git a/neural_compressor/transformers/__init__.py b/neural_compressor/transformers/__init__.py
index 4eb6a044664..54b0141e21d 100644
--- a/neural_compressor/transformers/__init__.py
+++ b/neural_compressor/transformers/__init__.py
@@ -23,4 +23,5 @@
     AutoModelForCausalLM,
     AutoModel,
     AutoModelForSeq2SeqLM,
+    Qwen2VLForConditionalGeneration,
 )
diff --git a/neural_compressor/transformers/models/__init__.py b/neural_compressor/transformers/models/__init__.py
index d951600ca48..4dc24600544 100644
--- a/neural_compressor/transformers/models/__init__.py
+++ b/neural_compressor/transformers/models/__init__.py
@@ -13,4 +13,11 @@
 # limitations under the License.

 from .modeling_auto import _BaseINCAutoModelClass
-from .modeling_auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+from .modeling_auto import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    Qwen2VLForConditionalGeneration,
+    MllamaForConditionalGeneration,
+    LlavaForConditionalGeneration,
+)
diff --git a/neural_compressor/transformers/models/modeling_auto.py b/neural_compressor/transformers/models/modeling_auto.py
index 1226fd21d97..6c4a0ceda98 100644
--- a/neural_compressor/transformers/models/modeling_auto.py
+++ b/neural_compressor/transformers/models/modeling_auto.py
@@ -354,24 +354,27 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         else:
             commit_hash = getattr(config, "_commit_hash", None)

-        has_remote_code = hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map
-
-        has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys()
-        trust_remote_code = resolve_trust_remote_code(
-            trust_remote_code,
-            pretrained_model_name_or_path,
-            has_local_code,
-            has_remote_code,
-        )
-        if has_remote_code and trust_remote_code:
-            class_ref = config.auto_map[cls.ORIG_MODEL.__name__]
-            model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs_orig)
-            if os.path.isdir(pretrained_model_name_or_path):
-                model_class.register_for_auto_class(cls.ORIG_MODEL.__name__)
-            else:
-                cls.ORIG_MODEL.register(config.__class__, model_class, exist_ok=True)
-        elif type(config) in cls.ORIG_MODEL._model_mapping.keys():
-            model_class = _get_model_class(config, cls.ORIG_MODEL._model_mapping)
+        if "AutoModel" in cls.ORIG_MODEL.__name__:
+            has_remote_code = hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map
+            has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys()
+
+            trust_remote_code = resolve_trust_remote_code(
+                trust_remote_code,
+                pretrained_model_name_or_path,
+                has_local_code,
+                has_remote_code,
+            )
+            if has_remote_code and trust_remote_code:
+                class_ref = config.auto_map[cls.ORIG_MODEL.__name__]
+                model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs_orig)
+                if os.path.isdir(pretrained_model_name_or_path):
+                    model_class.register_for_auto_class(cls.ORIG_MODEL.__name__)
+                else:
+                    cls.ORIG_MODEL.register(config.__class__, model_class, exist_ok=True)
+            elif type(config) in cls.ORIG_MODEL._model_mapping.keys():
+                model_class = _get_model_class(config, cls.ORIG_MODEL._model_mapping)
+        else:
+            model_class = cls.ORIG_MODEL

         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # index of the files.
@@ -747,3 +750,15 @@ class AutoModel(_BaseINCAutoModelClass):

 class AutoModelForSeq2SeqLM(_BaseINCAutoModelClass):
     ORIG_MODEL = transformers.AutoModelForSeq2SeqLM
+
+
+class Qwen2VLForConditionalGeneration(_BaseINCAutoModelClass):
+    ORIG_MODEL = transformers.Qwen2VLForConditionalGeneration
+
+
+class MllamaForConditionalGeneration(_BaseINCAutoModelClass):
+    ORIG_MODEL = transformers.MllamaForConditionalGeneration
+
+
+class LlavaForConditionalGeneration(_BaseINCAutoModelClass):
+    ORIG_MODEL = transformers.LlavaForConditionalGeneration
diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py
index 84b49cfe24b..0ab18b91b0c 100644
--- a/neural_compressor/transformers/quantization/utils.py
+++ b/neural_compressor/transformers/quantization/utils.py
@@ -17,6 +17,7 @@
 import json
 import math
 import os
+import re
 import types

 from datasets import load_dataset
@@ -33,11 +34,16 @@
     convert,
     prepare,
 )
-from neural_compressor.torch.utils import is_ipex_available
+from neural_compressor.torch.utils import is_ipex_available, is_package_available

 if is_ipex_available():
     import intel_extension_for_pytorch as ipex

+if is_package_available("auto_round"):
+    import auto_round
+    import transformers
+    from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear as auto_round_woq_linear
+
 from typing import Union

 torch = LazyImport("torch")
@@ -126,10 +132,12 @@ def _replace_linear(
         if (
             isinstance(module, torch.nn.Linear)
             or isinstance(module, INCWeightOnlyLinear)
-            or (is_ipex_available() and isinstance(module, ipex.nn.utils._weight_prepack._IPEXLinear))
+            or (is_package_available("auto_round") and isinstance(module, auto_round_woq_linear))
         ) and (name not in modules_to_not_convert):
             # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and not any(
+                re.match(pattern, ".".join(current_key_name)) for pattern in modules_to_not_convert
+            ):
                 in_features = module.in_features
                 out_features = module.out_features
                 if device == "cpu" or device == torch.device("cpu") or device == "auto":
@@ -475,6 +483,54 @@ def convert_to_quantized_model(model, config, device="cpu"):
             run_fn(model, *run_args)
         model = convert(model)
     elif config.quant_method.value == "autoround":
+        if config.is_vlm is True:
+            from transformers import AutoProcessor, AutoTokenizer
+
+            from neural_compressor.torch.algorithms.weight_only.autoround import (
+                get_mllm_dataloader as get_autoround_dataloader,
+            )
+
+            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
+            processor = AutoProcessor.from_pretrained(model.config._name_or_path, trust_remote_code=True)
+            (
+                dataloader,
+                template,
+                config.truncation,
+                config.batch_size,
+                config.gradient_accumulate_steps,
+                config.seq_len,
+                config.n_samples,
+            ) = get_autoround_dataloader(
+                template=None,
+                model=model,
+                tokenizer=tokenizer,
+                image_processor=None,
+                dataset=config.dataset,
+                extra_data_dir=None,
+                seqlen=config.seq_len,
+                batch_size=config.batch_size,
+                split=None,
+                apply_template=None,
+                truncation=False,
+                nsamples=config.n_samples,
+                seed=42,
+                gradient_accumulate_steps=config.gradient_accumulate_steps,
+                quant_nontext_module=config.quant_nontext_module,
+                processor=processor,
+            )
+        else:
+            from neural_compressor.torch.algorithms.weight_only.autoround import (
+                get_dataloader as get_autoround_dataloader,
+            )
+
+            dataloader = get_autoround_dataloader(
+                tokenizer=config.tokenizer,
+                seqlen=config.seq_len,
+                dataset_name=config.dataset,
+                seed=42,
+                bs=config.batch_size,
+                nsamples=config.n_samples,
+            )
         quant_config = AutoRoundConfig(
             dtype=dtype,
             bits=config.bits,
@@ -486,24 +542,59 @@ def convert_to_quantized_model(model, config, device="cpu"):
             seqlen=config.seq_len,
             nsamples=config.n_samples,
             iters=config.iters,
+            batch_size=config.batch_size,
             scale_dtype=config.scale_dtype,
             use_layer_wise=config.use_layer_wise,
+            # vlm arguments
+            is_mllm=config.is_vlm,
+            quant_nontext_module=config.quant_nontext_module,
+            truncation=config.truncation,
+            gradient_accumulate_steps=config.gradient_accumulate_steps,
+            export_format=config.export_format,
         )
+
+        # vlm set non-text module config
+        if config.is_vlm is True:
+            from neural_compressor.torch.utils.utility import (
+                find_matching_blocks,
+                get_layer_names_in_block,
+                get_multimodal_block_names,
+            )
+
+            def set_nontext_module_config(model, to_quant_block_names, config):
+                all_block_list = get_multimodal_block_names(model, quant_vision=True)
+                all_block_set = set(tuple(block) for block in all_block_list)
+                quant_block_set = set(tuple(block) for block in to_quant_block_names)
+                set_to_full_prec = list(all_block_set - quant_block_set)
+                set_to_full_prec = get_layer_names_in_block(model, to_quant_block_names=set_to_full_prec)
+                for name in set_to_full_prec:
+                    config.modules_to_not_convert.append(name)
+
+                # skip layers not in blocks
+                config.modules_to_not_convert.append("model.vision_embed_tokens.img_projection*")
+                config.modules_to_not_convert.append("transformer.visual.attn_pool.*_proj")
+                config.modules_to_not_convert.append("model.mm_projector*")
+                config.modules_to_not_convert.append("multi_modal_projector")
+                config.modules_to_not_convert.append("visual.merger")
+
+            all_blocks = get_multimodal_block_names(model, quant_config.quant_nontext_module)
+            to_quant_block_names = find_matching_blocks(model, all_blocks, quant_config.to_quant_block_names)
+            set_nontext_module_config(model, to_quant_block_names, config)
+
+            for n, m in model.named_modules():
+                if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D):
+                    if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
+                        config.modules_to_not_convert.append(n)
+                        print(
+                            f"{n} will not be quantized because its shape is not divisible by 32,"
+                            " which would cause an export issue for autogptq"
+                        )
         if config.modules_to_not_convert != []:
             for module in config.modules_to_not_convert:
                 module_name = ".*" + module
                 quant_config.set_local(module_name, AutoRoundConfig(dtype="fp32"))
         logger.info(f"Do AutoRound algorithm with config {quant_config}")
-        from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader as get_autoround_dataloader
-
-        dataloader = get_autoround_dataloader(
-            tokenizer=config.tokenizer,
-            seqlen=config.seq_len,
-            dataset_name=config.dataset,
-            seed=42,
-            bs=config.batch_size,
-            nsamples=config.n_samples,
-        )
         run_fn = run_fn_for_autoround
         run_args = (dataloader,)
         model = prepare(model=model, quant_config=quant_config)
diff --git a/neural_compressor/transformers/utils/quantization_config.py b/neural_compressor/transformers/utils/quantization_config.py
index 3e72de3c330..00fe1ec0fbf 100644
--- a/neural_compressor/transformers/utils/quantization_config.py
+++ b/neural_compressor/transformers/utils/quantization_config.py
@@ -543,6 +543,12 @@ def __init__(
         iters: int = 200,
         use_layer_wise: bool = None,
         quant_lm_head: bool = False,
+        # vlm arguments
+        is_vlm: bool = False,
+        quant_nontext_module: bool = False,
+        truncation: bool = False,
+        gradient_accumulate_steps: int = 1,
+        export_format="itrex",
         **kwargs,
     ):

@@ -594,6 +600,13 @@ def __init__(
         self.use_layer_wise = use_layer_wise
         self.model_path = kwargs.get("model_path", "")

+        # vlm arguments
+        self.is_vlm = is_vlm
+        self.quant_nontext_module = quant_nontext_module
+        self.truncation = truncation
+        self.gradient_accumulate_steps = gradient_accumulate_steps
+        self.export_format = export_format
+
     def to_diff_dict(self) -> Dict[str, Any]:
         """Removes all attributes from config which correspond to the default
         config attributes for better readability and serializes to a Python dictionary.
diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py
index 46a71cc2cfd..e7664d9ad35 100644
--- a/test/3x/torch/quantization/weight_only/test_autoround.py
+++ b/test/3x/torch/quantization/weight_only/test_autoround.py
@@ -238,13 +238,13 @@ def test_mllm(self):
             image_processor=None,
             dataset="liuhaotian/llava_conv_58k",
             extra_data_dir=None,
-            seqlen=512,
+            seqlen=32,
             batch_size=1,
             split=None,
             apply_template=None,
             truncation=False,
             seed=42,
-            nsamples=5,
+            nsamples=1,
             gradient_accumulate_steps=1,
             quant_nontext_module=False,
             processor=processor,
@@ -253,9 +253,9 @@
             bits=4,
             group_size=128,
             is_mllm=True,
-            nsamples=5,
+            nsamples=1,
             batch_size=batch_size,
-            iters=2,
+            iters=1,
             seqlen=seqlen,
             quant_nontext_module=False,
             truncation=truncation,
diff --git a/test/3x/torch/quantization/weight_only/test_transfomers.py b/test/3x/torch/quantization/weight_only/test_transformers.py
similarity index 78%
rename from test/3x/torch/quantization/weight_only/test_transfomers.py
rename to test/3x/torch/quantization/weight_only/test_transformers.py
index 83f6b664da0..684bb2c14e4 100644
--- a/test/3x/torch/quantization/weight_only/test_transfomers.py
+++ b/test/3x/torch/quantization/weight_only/test_transformers.py
@@ -10,6 +10,7 @@
 from neural_compressor.torch.utils import get_ipex_version
 from neural_compressor.transformers import (
     AutoModelForCausalLM,
+    Qwen2VLForConditionalGeneration,
     AutoRoundConfig,
     AwqConfig,
     GPTQConfig,
@@ -19,6 +20,12 @@

 ipex_version = get_ipex_version()


+try:
+    import auto_round
+
+    auto_round_installed = True
+except ImportError:
+    auto_round_installed = False
 class TestTansformersLikeAPI:
     def setup_class(self):
@@ -30,6 +37,7 @@ class TestTansformersLikeAPI:
     def teardown_class(self):
         shutil.rmtree("nc_workspace", ignore_errors=True)
         shutil.rmtree("transformers_tmp", ignore_errors=True)
+        shutil.rmtree("transformers_vlm_tmp", ignore_errors=True)

     def test_quantization_for_llm(self):
         model_name_or_path = self.model_name_or_path
@@ -208,3 +216,49 @@ def test_loading_autoawq_model(self):
         else:
             target_text = ["One day, the little girl in the back of my mind will say, “I’m so glad you’"]
         assert gen_text == target_text, "loading autoawq quantized model failed."
+
+    @pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
+    def test_vlm(self):
+        model_name = "Qwen/Qwen2-VL-2B-Instruct"
+        from neural_compressor.transformers import Qwen2VLForConditionalGeneration
+        from neural_compressor.transformers import AutoModelForCausalLM
+        woq_config = AutoRoundConfig(
+            bits=4,
+            group_size=128,
+            is_vlm=True,
+            dataset="NeelNanda/pile-10k",
+            iters=1,
+            n_samples=1,
+            seq_len=32,
+            batch_size=1,
+        )
+
+        woq_model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
+
+        if hasattr(torch, "xpu") and torch.xpu.is_available():
+            from intel_extension_for_pytorch.nn.utils._quantize_convert import WeightOnlyQuantizedLinear
+        else:
+            from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear
+        assert isinstance(woq_model.model.layers[0].self_attn.k_proj, WeightOnlyQuantizedLinear), "replacing model failed."
+
+        # save
+        woq_model.save_pretrained("transformers_vlm_tmp")
+
+        # load
+        loaded_model = Qwen2VLForConditionalGeneration.from_pretrained("transformers_vlm_tmp")
+        assert isinstance(loaded_model.model.layers[0].self_attn.k_proj, WeightOnlyQuantizedLinear), "loading model failed."
+
+        # phi-3-vision-128k-instruct, disabled as CI consumes too much time
+        # woq_config = AutoRoundConfig(
+        #     bits=4,
+        #     group_size=128,
+        #     is_vlm=True,
+        #     dataset="liuhaotian/llava_conv_58k",
+        #     iters=2,
+        #     n_samples=5,
+        #     seq_len=64,
+        #     batch_size=1,
+        # )
+        # model_name = "microsoft/Phi-3-vision-128k-instruct"
+        # woq_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True, attn_implementation='eager')
+        # assert isinstance(woq_model.model.layers[0].self_attn.o_proj, WeightOnlyQuantizedLinear), "quantization failed."
diff --git a/test/3x/torch/requirements.txt b/test/3x/torch/requirements.txt
index 5b97060f9f8..d9697dcac5e 100644
--- a/test/3x/torch/requirements.txt
+++ b/test/3x/torch/requirements.txt
@@ -7,4 +7,5 @@ peft
 prettytable
 psutil
 pytest
+torchvision
 transformers
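For reference, a minimal usage sketch of the VLM weight-only quantization path this patch adds. It mirrors the new `test_vlm` case above; the checkpoint name, calibration dataset, and the tiny `iters`/`n_samples`/`seq_len` values are illustrative test settings, not tuned recommendations.

```python
# Minimal sketch of the transformers-like VLM AutoRound flow added in this patch.
# Assumes the auto_round package and the Qwen/Qwen2-VL-2B-Instruct checkpoint are available.
from neural_compressor.transformers import AutoRoundConfig, Qwen2VLForConditionalGeneration

# is_vlm=True routes calibration through the multi-modal dataloader (get_mllm_dataloader);
# quant_nontext_module stays at its default (False), so only the text blocks are quantized.
woq_config = AutoRoundConfig(
    bits=4,
    group_size=128,
    is_vlm=True,
    dataset="NeelNanda/pile-10k",
    iters=1,
    n_samples=1,
    seq_len=32,
    batch_size=1,
)

woq_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    quantization_config=woq_config,
    trust_remote_code=True,
)

# The quantized model can be saved and reloaded through the same wrapper class.
woq_model.save_pretrained("transformers_vlm_tmp")
loaded_model = Qwen2VLForConditionalGeneration.from_pretrained("transformers_vlm_tmp")
```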