diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md index a4dca68c235..41a02e0d460 100644 --- a/docs/source/quantization_weight_only.md +++ b/docs/source/quantization_weight_only.md @@ -173,22 +173,19 @@ Large language models (LLMs) have shown exceptional performance across various t |:--------------:|:----------:| | RTN | ✔ | | AWQ | ✕ | -| GPTQ | ✕ | +| GPTQ | ✔ | | TEQ | ✕ | ### Example ```python from neural_compressor import PostTrainingQuantConfig, quantization -from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell +from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model -fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) +fp32_model = load_empty_model(model_name_or_path, torchscript=True) conf = PostTrainingQuantConfig( approach="weight_only", recipes={ "layer_wise_quant": True, - "layer_wise_quant_args": { - "model_path": "facebook/opt-125m", - }, "rtn_args": {"enable_full_range": True}, }, ) @@ -201,6 +198,7 @@ q_model = quantization.fit( ) ouput_dir = "./saved_model" q_model.save(ouput_dir) +q_model = load(ouput_dir, fp32_model, weight_only=True, layer_wise=True) ``` ## Reference diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py index 149a030af09..82443f73492 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py @@ -219,7 +219,7 @@ def skip(*args, **kwargs): tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True) else: - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True) model = model.eval() @@ -294,7 +294,8 @@ def skip(*args, **kwargs): dataloader=calib_dataloader, nsamples = args.nsamples, use_max_length = args.use_max_length, - pad_max_length = args.pad_max_length + pad_max_length = args.pad_max_length, + device = DEV, ) results = lm_evaluate( diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 0d64c5fd74c..03bca5c6dda 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -3502,13 +3502,13 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): ): from .torch_utils.layer_wise_quant import LayerWiseQuant - model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + model_path = model._model.path smooth_quant = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant", False) alpha = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant_alpha", 0.5) - assert ( - model_path is not None - ), "the layer_wise_quant_args should have args model_path to load the weight of model." 
- device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu") + # device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu") + assert model_path is not None, "The model_path should not be None." + device = self.device lw_quant = LayerWiseQuant( q_model._model, model_path, @@ -4541,14 +4541,12 @@ def rtn_quantize(self, model, tune_cfg): # for layer_wise quant mode recipe_cfgs = tune_cfg.get("recipe_cfgs", None) if recipe_cfgs.get("layer_wise_quant", False): - from neural_compressor.config import options - - from .torch_utils.layer_wise_quant.utils import _get_path, load_module + from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, load_module - lwq_workspace = os.path.join(options.workspace, "lwq_tmpdir") - os.makedirs(lwq_workspace, exist_ok=True) - model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) - assert model_path, "model_path should specify in layer_wise_quant_args." + os.makedirs(LWQ_WORKSPACE, exist_ok=True) + # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + model_path = model.path + assert model_path, "model_path should not be None." model_path = _get_path(model_path) for key, config in tune_cfg["op"].items(): @@ -4584,7 +4582,7 @@ def rtn_quantize(self, model, tune_cfg): # save and clean weight from .torch_utils.layer_wise_quant.utils import clean_module_weight - torch.save(m.state_dict(), os.path.join(lwq_workspace, f"{op_name}.pt")) + torch.save(m.state_dict(), os.path.join(LWQ_WORKSPACE, f"{op_name}.pt")) clean_module_weight(m) set_module(model, op_name, m) if recipe_cfgs.get("layer_wise_quant", False): @@ -4619,6 +4617,23 @@ def gptq_quantize(self, model, tune_cfg, dataloader): ... } """ + # for layer_wise quant mode + recipe_cfgs = tune_cfg.get("recipe_cfgs", None) + model_path = None + layer_wise = False + if recipe_cfgs.get("layer_wise_quant", False): + layer_wise = True + from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, register_weight_hooks + + os.makedirs(LWQ_WORKSPACE, exist_ok=True) + # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + model_path = model.path + assert model_path, "model_path should not be None." + model_path = _get_path(model_path) + lwq_handles = register_weight_hooks( + model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE + ) + weight_config = {} for key, config in tune_cfg["op"].items(): op_name, op_type = key @@ -4643,7 +4658,15 @@ def gptq_quantize(self, model, tune_cfg, dataloader): ) # tune_cfg => weight_config model, quantization_perm = gptq_quantize( - model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, self.device + model, + weight_config, + dataloader, + nsamples, + use_max_length, + pad_max_length, + self.device, + layer_wise, + model_path, ) return model, quantization_perm diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py index b2710558e57..daaa08df8e3 100644 --- a/neural_compressor/adaptor/torch_utils/gptq.py +++ b/neural_compressor/adaptor/torch_utils/gptq.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import math import random import re @@ -175,6 +176,7 @@ def __init__( use_max_length=True, pad_max_length=2048, device=None, + layer_wise=False, ): """ Args: @@ -215,9 +217,13 @@ def __init__( self.check_layer_config() # device - self.device = model.device + self.device = device + if str(self.model.device).startswith("cuda"): + self.device = self.model.device self.is_ready = False + self.layer_wise = layer_wise + # dataloader self.use_max_length = use_max_length self.pad_max_length = pad_max_length @@ -438,11 +444,13 @@ def forward(layer, *args, **kwargs): raise ValueError # Step1: fetch the embeddings and other layers before the transformer stack. - for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): - embedding_layer = embedding_layer.to(self.device) + if not self.layer_wise: + for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): + embedding_layer = embedding_layer.to(self.device) # Step2: modify the first transformer block's forward function to obtain inputs for calibration - self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device) + if not self.layer_wise: + self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device) forward_cache = self.gptq_related_blocks["transformers"][0].forward self.gptq_related_blocks["transformers"][0].forward = partial( forward, self.gptq_related_blocks["transformers"][0] @@ -451,7 +459,8 @@ def forward(layer, *args, **kwargs): # Step3: run forward to obtain calibration datasets logger.info("Collecting calibration inputs...") for batch in tqdm(self.dataloader): - batch = move_input_to_device(batch, self.device) + if not self.layer_wise: + batch = move_input_to_device(batch, self.device) try: if isinstance(batch, tuple) or isinstance(batch, list): self.model(batch[0]) @@ -473,9 +482,10 @@ def forward(layer, *args, **kwargs): # Step 4: restore original forward function, relocate layers back to cpu. 
self.gptq_related_blocks["transformers"][0].forward = forward_cache - self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu() - for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): - embedding_layer.to(self.device) + if not self.layer_wise: + self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu() + for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): + embedding_layer.to(self.device) torch.cuda.empty_cache() # end logger.info("GPTQ quantization prepared.") @@ -501,7 +511,7 @@ def update_blockwise_hidden_states(self, outs): self.cache_positional_arguments[0] = outs[:] @torch.no_grad() - def execute_quantization(self, means=None, stds=None): + def execute_quantization(self, means=None, stds=None, model_path=None): """Run quantization.""" # Step1: prepare quantization (calibration datasets) @@ -513,7 +523,11 @@ def execute_quantization(self, means=None, stds=None): tblock_length = len(self.gptq_related_blocks["transformers"]) for block_idx in range(tblock_length): logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..") - transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + if not self.layer_wise: + # if we do not apply layer-wise feature, we still place the entire block on the GPU + transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + else: + transformer_block = self.gptq_related_blocks["transformers"][block_idx] # .to(self.device) # Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized. sub_layers = find_layers(transformer_block) sub_layers_to_quant = {} @@ -534,8 +548,16 @@ def execute_quantization(self, means=None, stds=None): # weight_config_this_layer = self.weight_config.get( # self.get_full_layer_name(layer_name, block_idx), None # ) - weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) - gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name]) + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + weight_config_this_layer = self.get_layer_config(full_layer_name) + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import load_value + + W = load_value(self.model, full_layer_name + ".weight", model_path) + else: + W = sub_layers[layer_name].weight.data.clone() + + gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device) # gptq_for_this_block[layer_name].quantizer = Quantizer() gptq_for_this_block[layer_name].quantizer.configure( weight_config_this_layer["wbits"], @@ -555,7 +577,6 @@ def tmp(_, inp, out): for layer_name in sub_layers: handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name))) idx = self.cache_key_arguments.pop("i") - # import pdb;pdb.set_trace() for j in range(len(self.dataloader)): cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) @@ -570,12 +591,44 @@ def tmp(_, inp, out): # ) weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) logger.info(f"Quantizing layer {layer_name}") - scale, zp = gptq_for_this_block[layer_name].fasterquant( + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import load_value + + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + W = 
load_value(self.model, full_layer_name + ".weight", model_path) + else: + W = sub_layers[layer_name].weight.data.clone() + scale, zp, Q = gptq_for_this_block[layer_name].fasterquant( + W, blocksize=weight_config_this_layer["block_size"], percdamp=weight_config_this_layer["percdamp"], groupsize=weight_config_this_layer["group_size"], act_order=weight_config_this_layer["act_order"], ) + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import ( + LWQ_WORKSPACE, + clean_module_weight, + load_value, + set_module_tensor_to_device, + ) + + sub_layer = sub_layers[layer_name] + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + for n, p in sub_layer.named_parameters(): + param_name = full_layer_name + "." + n + if n == "weight": + set_module_tensor_to_device(self.model, param_name, self.device, Q) + else: + value = load_value(self.model, param_name, model_path) + set_module_tensor_to_device(self.model, param_name, self.device, value) + # sub_layer.weight.data = Q + torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt") + clean_module_weight(sub_layer) + del Q + gc.collect() + else: + sub_layers[layer_name].weight.data = Q gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale} if not weight_config_this_layer["sym"]: gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp @@ -594,7 +647,10 @@ def tmp(_, inp, out): out = transformer_block(*cache_positional_batch, **cache_keyword_batch)[0] outs.append(out) self.cache_key_arguments["i"] = idx - self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() + if self.layer_wise: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block + else: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() del gptq_for_this_block torch.cuda.empty_cache() # iteratively replace the input with output, thus layerwise quantization can continue. 
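
Note on the layer-wise branch added above: when `self.layer_wise` is set, `execute_quantization` streams each weight in from disk with `load_value`, writes the quantized tensor `Q` back through `set_module_tensor_to_device`, checkpoints the layer's state dict under `LWQ_WORKSPACE`, and then strips the dense weight again with `clean_module_weight`, so only one transformer block's parameters are ever fully materialized. The standalone sketch below mirrors that load → quantize → checkpoint → offload cycle using plain PyTorch only; the `fake_quantize` helper, workspace path, and per-layer file layout are illustrative assumptions, not the adaptor's actual helpers.

```python
import os

import torch
import torch.nn as nn

WORKSPACE = "./lwq_tmpdir"  # stand-in for LWQ_WORKSPACE (assumption)
os.makedirs(WORKSPACE, exist_ok=True)


def fake_quantize(weight: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Toy symmetric round-to-nearest stand-in for the real GPTQ weight update."""
    qmax = 2 ** (bits - 1) - 1
    scale = weight.abs().max() / qmax
    return (weight / scale).round().clamp(-qmax - 1, qmax) * scale


def process_layer(layer: nn.Linear, name: str, weight_path: str) -> None:
    # 1. load only this layer's dense weight from disk (nothing else is resident)
    weight = torch.load(weight_path)
    # 2. quantize and write the result back into the live module
    layer.weight.data = fake_quantize(weight)
    # 3. checkpoint the quantized layer so the final model can be assembled later
    torch.save(layer.state_dict(), os.path.join(WORKSPACE, f"{name}.pt"))
    # 4. release the dense copy again by swapping the parameter to the meta device
    layer.weight = nn.Parameter(torch.empty_like(layer.weight, device="meta"))


# usage: pretend fc1's weight was shipped as a per-layer file
layer = nn.Linear(16, 16)
torch.save(layer.weight.data.clone(), os.path.join(WORKSPACE, "fc1.weight.pt"))
process_layer(layer, "fc1", os.path.join(WORKSPACE, "fc1.weight.pt"))
print(layer.weight.device)  # meta — the dense weight now lives only on disk
```

Keeping the dense copy on disk and the live parameter on the meta device is what lets GPTQ run over models that do not fit in host memory.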
@@ -617,10 +673,10 @@ class GPTQ: GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers (https://arxiv.org/abs/2210.17323) """ - def __init__(self, layer): + def __init__(self, layer, W, device="cpu"): self.layer = layer - self.device = self.layer.weight.device - W = layer.weight.data.clone() + self.device = device + # W = layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d): W = W.flatten(1) if isinstance(self.layer, transformers.Conv1D): @@ -661,8 +717,9 @@ def add_batch(self, inp, out): # self.H += 2 / self.nsamples * inp.matmul(inp.t()) self.H += inp.matmul(inp.t()) # H = X*X, which should be a sysm matrix - def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False): - W = self.layer.weight.data.clone() + def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False): + # W = self.layer.weight.data.clone() + weight_shape, weight_dtype = W.shape, W.data.dtype if isinstance(self.layer, nn.Conv2d): W = W.flatten(1) if isinstance(self.layer, transformers.Conv1D): @@ -740,7 +797,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals # logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}") # logger.info(f"{torch.sum(Losses)}") - if self.device != torch.device("cpu"): + if str(self.device).startswith("cuda"): torch.cuda.synchronize() logger.info(f"time {(time.time() - tick)}") logger.info(f"error {torch.sum(Losses).item()}") @@ -751,7 +808,8 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals if isinstance(self.layer, transformers.Conv1D): Q = Q.t() - self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + # self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + Q = Q.reshape(weight_shape).to(weight_dtype) if DEBUG: logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}") @@ -760,7 +818,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals zero.append(self.quantizer.zero) scale = torch.cat(scale, dim=1) zero = torch.cat(zero, dim=1) - return scale, zero + return scale, zero, Q def free(self): if DEBUG: diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py index 6f01b1288bf..f347da31bb9 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py @@ -15,5 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Torch layer-wise quantization module.""" -from .utils import load_shell +from .utils import load_empty_model from .quantize import LayerWiseQuant diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py index 1746ad82140..9e3f8789dad 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py @@ -40,7 +40,7 @@ update_module, ) -TMP_DIR = os.path.join(default_workspace, "layer_wise_quant_tmp_dir") +TMP_DIR = os.path.join(default_workspace, "lwq_tmpdir") def mk_tmp_dir(): @@ -92,7 +92,7 @@ def __init__( alpha=0.5, ): """Init LayerWiseQuant.""" - # self.q_model = load_shell(pretrained_model_name_or_path, cls) + # self.q_model = load_empty_model(pretrained_model_name_or_path, cls) self.q_model = q_model self.fp32_model = deepcopy(self.q_model) self.path = _get_path(pretrained_model_name_or_path) diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py index e932c40480a..8bd3d32d320 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py @@ -25,7 +25,7 @@ torch = LazyImport("torch") from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device -from transformers import AutoConfig +from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.auto.auto_factory import _BaseAutoModelClass from ....config import options @@ -107,7 +107,7 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None): return file_path -def load_shell(pretrained_model_name_or_path, cls, **kwargs): +def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs): """Load a empty model.""" is_local = os.path.isdir(pretrained_model_name_or_path) if is_local: # pragma: no cover @@ -124,6 +124,7 @@ def load_shell(pretrained_model_name_or_path, cls, **kwargs): model = cls(config) model.tie_weights() model.eval() + model.path = pretrained_model_name_or_path return model @@ -223,7 +224,10 @@ def load_module(model, module_name, path, device="cpu"): set_module_tensor_to_device(model, param_name, device, value) -def register_weight_hooks(model, path, device="cpu", clean_weight=True): +def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_path=None): + if saved_path: + os.makedirs(saved_path, exist_ok=True) + def forward_pre_hook(name): def hook(module, input): state_dict = None @@ -241,6 +245,9 @@ def hook(module, input): def forward_hook(name): def hook(module, input, output): + if saved_path: + file_path = os.path.join(saved_path, f"{name}.pt") + torch.save(module.state_dict(), file_path) clean_module_weight(module) return hook diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index b9376a72c7f..7ba86eaa344 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -470,15 +470,27 @@ def rtn_quantize( def gptq_quantize( - model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, pad_max_length=2048, device=None + model, + weight_config={}, + dataloader=None, + nsamples=128, + use_max_length=True, + pad_max_length=2048, + device=None, + layer_wise=False, + model_path=None, ): """Run 
weight-only quantization with.""" # TODO: unify weight_config keys, add docstring, and support default config assert isinstance(model, torch.nn.Module), "only support torch module" + if layer_wise: + assert model_path is not None, "model_path should not be None when use layer_wise mode" from .gptq import GPTQuantizer - gptq_quantizer = GPTQuantizer(model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device) - fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization() + gptq_quantizer = GPTQuantizer( + model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device, layer_wise=layer_wise + ) + fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization(model_path=model_path) logger.info("GPTQ quantizing done.") return fp32_modified_model, gptq_config diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py index 95f273eeab5..eeada402f35 100644 --- a/neural_compressor/model/torch_model.py +++ b/neural_compressor/model/torch_model.py @@ -356,7 +356,8 @@ def save(self, root=None): if os.path.exists(os.path.join(LWQ_WORKSPACE, f"{name}.pt")): state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{name}.pt")) model_path = _get_path( - self.q_config["recipe_cfgs"]["layer_wise_quant_args"].get("model_path") + # self.q_config["recipe_cfgs"]["layer_wise_quant_args"].get("model_path") + self._model.path ) for n, p in module.named_parameters(): param_name = name + "." + n diff --git a/test/algorithm/test_layer_wise_quant.py b/test/algorithm/test_layer_wise_quant.py index e25036bb021..2eef1e89bd9 100644 --- a/test/algorithm/test_layer_wise_quant.py +++ b/test/algorithm/test_layer_wise_quant.py @@ -8,14 +8,14 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from neural_compressor import PostTrainingQuantConfig, quantization -from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell +from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model from neural_compressor.utils.pytorch import load class TestLayerWise(unittest.TestCase): def test_layer_wise(self): model_name_or_path = "facebook/opt-125m" - fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) + fp32_model = load_empty_model(model_name_or_path, torchscript=True) class TestDataset(Dataset): def __init__(self, size=5, shape=128): @@ -65,7 +65,7 @@ def test_util(self): ) model_name_or_path = "facebook/opt-125m" - model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) + model = load_empty_model(model_name_or_path, torchscript=True) children = get_children(model) named_children = get_named_children(model) self.assertEqual(children, [v for k, v in named_children]) diff --git a/test/algorithm/test_lwq_weight_only.py b/test/algorithm/test_lwq_weight_only.py index c8840123382..3735f0b2581 100644 --- a/test/algorithm/test_lwq_weight_only.py +++ b/test/algorithm/test_lwq_weight_only.py @@ -1,21 +1,22 @@ import shutil import sys import unittest +from copy import deepcopy sys.path.insert(0, "./") import torch from torch.utils.data import DataLoader, Dataset -from transformers import AutoModelForCausalLM, AutoTokenizer from neural_compressor import PostTrainingQuantConfig, quantization -from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell +from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model from neural_compressor.utils.pytorch import load class TestLayerWise(unittest.TestCase): - def 
test_layer_wise(self): - model_name_or_path = "facebook/opt-125m" - fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) + @classmethod + def setUpClass(self): + self.model_name_or_path = "facebook/opt-125m" + self.fp32_model = load_empty_model(self.model_name_or_path, torchscript=True) class TestDataset(Dataset): def __init__(self, size=5, shape=128): @@ -29,30 +30,58 @@ def __len__(self): return self.len eval_dataset = TestDataset() - eval_dataloader = DataLoader(eval_dataset, batch_size=8) + self.eval_dataloader = DataLoader(eval_dataset, batch_size=8) + @classmethod + def tearDownClass(cls): + shutil.rmtree("./saved_model", ignore_errors=True) + + def test_rtn_lwq(self): conf = PostTrainingQuantConfig( approach="weight_only", recipes={ "layer_wise_quant": True, - "layer_wise_quant_args": { - "model_path": "facebook/opt-125m", - }, + # "layer_wise_quant_args": { + # "model_path": "facebook/opt-125m", + # }, "rtn_args": {"enable_full_range": True}, }, ) q_model = quantization.fit( - fp32_model, + deepcopy(self.fp32_model), conf, - calib_dataloader=eval_dataloader, + calib_dataloader=self.eval_dataloader, eval_func=lambda x: 0.1, ) ouput_dir = "./saved_model" q_model.save(ouput_dir) - load_model = load(ouput_dir, fp32_model, weight_only=True) + load_model = load(ouput_dir, deepcopy(self.fp32_model), weight_only=True) + self.assertNotEqual(load_model.lm_head.weight.device.type, "meta") + + def test_gptq_lwq(self): + conf = PostTrainingQuantConfig( + approach="weight_only", + op_type_dict={ + ".*": { # re.match + "weight": { + "bits": 4, # 1-8 bits + "group_size": 32, + "scheme": "sym", + "algorithm": "GPTQ", + }, + }, + }, + recipes={ + "gptq_args": {"actorder": True, "mse": True, "perchannel": False}, + "layer_wise_quant": True, + }, + ) + q_model = quantization.fit(deepcopy(self.fp32_model), conf, calib_dataloader=self.eval_dataloader) + ouput_dir = "./saved_model" + q_model.save(ouput_dir) + load_model = load(ouput_dir, deepcopy(self.fp32_model), weight_only=True, layer_wise=True) self.assertNotEqual(load_model.lm_head.weight.device.type, "meta") - shutil.rmtree(ouput_dir) if __name__ == "__main__":
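
For quick reference, here is a condensed end-to-end sketch of the workflow this patch enables, assembled from the updated documentation example and the new `test_gptq_lwq` case: build an empty shell with `load_empty_model` (which now records `model.path`), run layer-wise GPTQ through `quantization.fit`, then reload with `layer_wise=True`. The dummy calibration dataset below is an assumption standing in for real calibration data, and running the sketch requires downloading `facebook/opt-125m`.

```python
from copy import deepcopy

import torch
from torch.utils.data import DataLoader, Dataset

from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model
from neural_compressor.utils.pytorch import load


class DummyCalibDataset(Dataset):
    """Tiny all-ones token batches; a placeholder for a real calibration set."""

    def __init__(self, size=5, shape=128):
        self.len = size
        self.input_ids = torch.ones([size, shape], dtype=torch.long)

    def __getitem__(self, index):
        return self.input_ids[index]

    def __len__(self):
        return self.len


model_name_or_path = "facebook/opt-125m"
# the shell keeps weights on the meta device; model.path records where to stream them from
fp32_model = load_empty_model(model_name_or_path, torchscript=True)
calib_dataloader = DataLoader(DummyCalibDataset(), batch_size=1)

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={".*": {"weight": {"bits": 4, "group_size": 32, "scheme": "sym", "algorithm": "GPTQ"}}},
    recipes={"layer_wise_quant": True, "gptq_args": {"actorder": True, "mse": True, "perchannel": False}},
)
q_model = quantization.fit(deepcopy(fp32_model), conf, calib_dataloader=calib_dataloader)

output_dir = "./saved_model"
q_model.save(output_dir)
# reload against a fresh empty shell; layer_wise=True restores the per-layer checkpoints
reloaded = load(output_dir, deepcopy(fp32_model), weight_only=True, layer_wise=True)
```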