diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/README.md b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/README.md
index 4c3b38c0a48..1abe2633ea3 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/README.md
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/README.md
@@ -116,13 +116,18 @@ Pytorch and Intel-extension-for-pytorch version for intel GPU > 2.1 are required
 ```bash
 pip install -r requirements_GPU.txt
 pip install transformers==4.38.1 # llama use 4.38.1
-source /opt/intel/oneapi/setvars.sh
 git clone https://github.com/intel/intel-extension-for-pytorch.git ipex-gpu
 cd ipex-gpu
 git submodule update --init --recursive
 export USE_AOT_DEVLIST='pvc,ats-m150'
 export BUILD_WITH_CPU=OFF
+export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib/:$LD_LIBRARY_PATH
+export OCL_ICD_VENDORS=/etc/OpenCL/vendors
+export CCL_ROOT=${CONDA_PREFIX}
+source /opt/intel/oneapi/setvars.sh --force
+export LLM_ACC_TEST=1
+
 python setup.py install
 ```
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py
index b5dbe20126e..9245d53eb50 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py
@@ -200,7 +200,7 @@
 tokenizer.save_pretrained(args.output_dir)
 
 enable_optimize_transformers = False
-opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen"]
+opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen", "phi3"]
 
 if config.model_type in opt_gpu_model_type_list:
     enable_optimize_transformers = True
diff --git a/neural_compressor/torch/algorithms/weight_only/rtn.py b/neural_compressor/torch/algorithms/weight_only/rtn.py
index 6ce9b49fac8..d1d6912e2fa 100644
--- a/neural_compressor/torch/algorithms/weight_only/rtn.py
+++ b/neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -130,17 +130,16 @@ def convert(
 
         if use_layer_wise:
             from neural_compressor.common.utils import DEFAULT_WORKSPACE
-            from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module, register_weight_hooks
+            from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module
 
             if model_path == "":
                 model_path = model.path
             assert model_path, "model_path should not be None."
             model_path = get_path(model_path)
 
-            register_weight_hooks(model, model_path, device=device, clean_weight=True)
-
         for name, m in model.named_modules():
-
+            if use_layer_wise and len(list(m.named_children())) == 0:
+                load_module(model, name, model_path, device=device)
             if not isinstance(m, supported_layers):
                 continue
             if name in weight_config:  # pragma: no cover
@@ -192,9 +191,6 @@ def convert(
                 logger.debug(f"RTN quantized module:{name, m}")
                 logger.debug(log_msg)
 
-            if use_layer_wise:
-                load_module(model, name, model_path, device=device)
-
             # for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight.
             if is_transformers_imported():
                 transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py
index 2a6fe5aae64..e52b15a87e0 100644
--- a/neural_compressor/torch/utils/utility.py
+++ b/neural_compressor/torch/utils/utility.py
@@ -331,11 +331,11 @@ def load_empty_model(pretrained_model_name_or_path, cls=None, **kwargs):
     if cls.__base__ == _BaseAutoModelClass:
         config = AutoConfig.from_pretrained(path, **kwargs)
         with init_empty_weights():
-            model = cls.from_config(config)
+            model = cls.from_config(config, **kwargs)
     else:  # pragma: no cover
         config = cls.config_class.from_pretrained(path, **kwargs)
         with init_empty_weights():
-            model = cls(config)
+            model = cls(config, **kwargs)
     model.tie_weights()
     model.eval()
     model.path = pretrained_model_name_or_path
diff --git a/neural_compressor/transformers/models/modeling_auto.py b/neural_compressor/transformers/models/modeling_auto.py
index 3ec2d0de9a2..cd5b3fe0975 100644
--- a/neural_compressor/transformers/models/modeling_auto.py
+++ b/neural_compressor/transformers/models/modeling_auto.py
@@ -134,7 +134,33 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             (RtnConfig, AwqConfig, TeqConfig, GPTQConfig, AutoRoundConfig),
         ):
             logger.info("Applying Weight Only Quantization.")
-            if use_xpu:
+            # set use_layer_wise on client
+            if hasattr(quantization_config, "use_layer_wise"):
+                import neural_compressor.torch.utils as torch_utils
+
+                process_type = torch_utils.get_processor_type_from_user_config()
+                if process_type == torch_utils.ProcessorType.Client:
+                    quantization_config.use_layer_wise = True
+
+            if hasattr(quantization_config, "use_layer_wise") and quantization_config.use_layer_wise:
+                from transformers.dynamic_module_utils import resolve_trust_remote_code
+
+                from neural_compressor.torch import load_empty_model
+
+                trust_remote_code = kwargs.get("trust_remote_code", None)
+                has_remote_code = hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map
+                has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys()
+                trust_remote_code = resolve_trust_remote_code(
+                    trust_remote_code,
+                    pretrained_model_name_or_path,
+                    has_local_code,
+                    has_remote_code,
+                )
+
+                model = load_empty_model(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
+                if use_cpu:
+                    quantization_config.post_init_cpu()
+            elif use_xpu:
                 # TODO: if low_cpu_mem_uasge is True, gptj will have accuracy issue on CPU device.
                 kwargs["low_cpu_mem_usage"] = True
                 kwargs["device_map"] = "cpu"
diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py
index d6f90804a52..e81c3295bfa 100644
--- a/neural_compressor/transformers/quantization/utils.py
+++ b/neural_compressor/transformers/quantization/utils.py
@@ -153,7 +153,6 @@ def _replace_linear(
                         "fp16": ipex.quantization.WoqLowpMode.FP16,
                         "int8": ipex.quantization.WoqLowpMode.INT8,
                     }
-
                     ipex_qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(
                         weight_dtype=weight_dtype[quantization_config.bits],
                         lowp_mode=compute_dtype[quantization_config.compute_dtype],
@@ -366,11 +365,6 @@ def convert_to_quantized_model(model, config, device="cpu"):
 
     # mapping to INC config
     dtype = "int4" if config.weight_dtype == "int4_fullrange" else config.weight_dtype
-    import neural_compressor.torch.utils as torch_utils
-
-    process_type = torch_utils.get_processor_type_from_user_config()
-    if process_type == torch_utils.ProcessorType.Client:
-        config.use_layer_wise = True
     if config.quant_method.value == "rtn":
         quant_config = RTNConfig(
             dtype=dtype,
@@ -529,6 +523,12 @@ def convert_to_quantized_model(model, config, device="cpu"):
     if orig_dtype != torch.float32:
         q_model.to(dtype=orig_dtype)
 
+    if config.use_layer_wise and not (q_model.device == device or q_model.device.type == device):
+        logger.warning(
+            "Do not convert device to avoid out of memory. Recommend using saved quantized model to inference."
+        )
+        return q_model
+
     return q_model.to(device)
diff --git a/test/3x/torch/quantization/weight_only/test_transfomers.py b/test/3x/torch/quantization/weight_only/test_transfomers.py
index e9194d9a371..64e9b3a4e9b 100644
--- a/test/3x/torch/quantization/weight_only/test_transfomers.py
+++ b/test/3x/torch/quantization/weight_only/test_transfomers.py
@@ -115,6 +115,39 @@ def test_save_load(self):
         loaded_output = loaded_model(dummy_input)[0]
         assert torch.equal(woq_output, loaded_output), "loaded output should be same. Please double check."
 
+    def test_use_layer_wise(self):
+        model_name_or_path = self.model_name_or_path
+
+        fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+        dummy_input = fp32_model.dummy_inputs["input_ids"]
+
+        # RTN
+        # use_layer_wise=True
+        woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=True)
+        woq_model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            quantization_config=woq_config,
+        )
+        woq_output = woq_model(dummy_input)[0]
+
+        # save
+        output_dir = "./transformers_tmp"
+        woq_model.save_pretrained(output_dir)
+
+        # load
+        loaded_model = AutoModelForCausalLM.from_pretrained(output_dir)
+        loaded_output = loaded_model(dummy_input)[0]
+        assert torch.equal(woq_output, loaded_output), "loaded output should be same. Please double check."
+
+        # use_layer_wise=False
+        woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=False)
+        woq_model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            quantization_config=woq_config,
+        )
+        woq_output2 = woq_model(dummy_input)[0]
+        assert torch.equal(woq_output, woq_output2), "use_layer_wise output should be same. Please double check."
+
     def test_loading_autoawq_model(self):
         user_model = AutoModelForCausalLM.from_pretrained(self.autoawq_model)
         tokenizer = AutoTokenizer.from_pretrained(self.autoawq_model)