From b70312a9afbbc0de7a014e5546cbfaf9b8c26f42 Mon Sep 17 00:00:00 2001 From: "Guo, Heng" Date: Fri, 13 Oct 2023 15:57:46 +0800 Subject: [PATCH 01/15] support lwq for gptq Signed-off-by: Guo, Heng --- neural_compressor/adaptor/pytorch.py | 24 ++++-- neural_compressor/adaptor/torch_utils/gptq.py | 81 ++++++++++++++----- .../torch_utils/layer_wise_quant/quantize.py | 2 +- .../torch_utils/layer_wise_quant/utils.py | 9 ++- .../adaptor/torch_utils/weight_only.py | 8 +- test/algorithm/test_lwq_weight_only.py | 48 +++++++++-- 6 files changed, 131 insertions(+), 41 deletions(-) diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index a8062aff3ed..26537f04332 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -4535,12 +4535,9 @@ def rtn_quantize(self, model, tune_cfg): # for layer_wise quant mode recipe_cfgs = tune_cfg.get("recipe_cfgs", None) if recipe_cfgs.get("layer_wise_quant", False): - from neural_compressor.config import options + from .torch_utils.layer_wise_quant.utils import _get_path, load_module, LWQ_WORKSPACE - from .torch_utils.layer_wise_quant.utils import _get_path, load_module - - lwq_workspace = os.path.join(options.workspace, "lwq_tmpdir") - os.makedirs(lwq_workspace, exist_ok=True) + os.makedirs(LWQ_WORKSPACE, exist_ok=True) model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) assert model_path, "model_path should specify in layer_wise_quant_args." model_path = _get_path(model_path) @@ -4578,7 +4575,7 @@ def rtn_quantize(self, model, tune_cfg): # save and clean weight from .torch_utils.layer_wise_quant.utils import clean_module_weight - torch.save(m.state_dict(), os.path.join(lwq_workspace, f"{op_name}.pt")) + torch.save(m.state_dict(), os.path.join(LWQ_WORKSPACE, f"{op_name}.pt")) clean_module_weight(m) set_module(model, op_name, m) if recipe_cfgs.get("layer_wise_quant", False): @@ -4613,6 +4610,19 @@ def gptq_quantize(self, model, tune_cfg, dataloader): ... } """ + # for layer_wise quant mode + recipe_cfgs = tune_cfg.get("recipe_cfgs", None) + model_path = None + layer_wise = False + if recipe_cfgs.get("layer_wise_quant", False): + layer_wise = True + from .torch_utils.layer_wise_quant.utils import _get_path, LWQ_WORKSPACE, register_weight_hooks + os.makedirs(LWQ_WORKSPACE, exist_ok=True) + model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + assert model_path, "model_path should specify in layer_wise_quant_args." 
+ model_path = _get_path(model_path) + lwq_handles = register_weight_hooks(model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE) + weight_config = {} for key, config in tune_cfg["op"].items(): op_name, op_type = key @@ -4637,7 +4647,7 @@ def gptq_quantize(self, model, tune_cfg, dataloader): ) # tune_cfg => weight_config model, quantization_perm = gptq_quantize( - model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, self.device + model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, self.device, layer_wise, model_path ) return model, quantization_perm diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py index 28fa79766a9..3ed003150eb 100644 --- a/neural_compressor/adaptor/torch_utils/gptq.py +++ b/neural_compressor/adaptor/torch_utils/gptq.py @@ -175,6 +175,7 @@ def __init__( use_max_length=True, pad_max_length=2048, device=None, + layer_wise=False ): """ Args: @@ -215,9 +216,11 @@ def __init__( self.check_layer_config() # device - self.device = model.device + self.device = device self.is_ready = False + self.layer_wise = layer_wise + # dataloader self.use_max_length = use_max_length self.pad_max_length = pad_max_length @@ -226,6 +229,7 @@ def __init__( self.nsamples = nsamples self.prepare_dataloader() + def prepare_dataloader(self): if self.use_max_length: # (Recommend) only take sequence whose length exceeds self.pad_max_length, @@ -429,11 +433,13 @@ def forward(layer, hidden_states, *args, **kwargs): raise ValueError # Step1: fetch the embeddings and other layers before the transformer stack. - for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): - embedding_layer = embedding_layer.to(self.device) + if not self.layer_wise: + for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): + embedding_layer = embedding_layer.to(self.device) # Step2: modify the first transformer block's forward function to obtain inputs for calibration - self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device) + if not self.layer_wise: + self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device) forward_cache = self.gptq_related_blocks["transformers"][0].forward self.gptq_related_blocks["transformers"][0].forward = partial( forward, self.gptq_related_blocks["transformers"][0] @@ -460,9 +466,10 @@ def forward(layer, hidden_states, *args, **kwargs): # Step 4: restore original forward function, relocate layers back to cpu. 
self.gptq_related_blocks["transformers"][0].forward = forward_cache - self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu() - for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): - embedding_layer.to(self.device) + if not self.layer_wise: + self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu() + for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): + embedding_layer.to(self.device) torch.cuda.empty_cache() # end logger.info("GPTQ quantization prepared.") @@ -482,7 +489,7 @@ def gather_single_batch_from_list(self, data_list, idx): return single_batch @torch.no_grad() - def execute_quantization(self, means=None, stds=None): + def execute_quantization(self, means=None, stds=None, model_path=None): """Run quantization.""" # Step1: prepare quantization (calibration datasets) logger.info("Begin ====>") @@ -492,7 +499,7 @@ def execute_quantization(self, means=None, stds=None): tblock_length = len(self.gptq_related_blocks["transformers"]) for block_idx in range(tblock_length): logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..") - transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + transformer_block = self.gptq_related_blocks["transformers"][block_idx]#.to(self.device) # Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized. sub_layers = find_layers(transformer_block) sub_layers_to_quant = {} @@ -513,8 +520,15 @@ def execute_quantization(self, means=None, stds=None): # weight_config_this_layer = self.weight_config.get( # self.get_full_layer_name(layer_name, block_idx), None # ) - weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) - gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name]) + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + weight_config_this_layer = self.get_layer_config(full_layer_name) + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import load_value + W = load_value(self.model, full_layer_name + '.weight', model_path) + else: + W = sub_layers[layer_name].weight.data.clone() + + gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device) # gptq_for_this_block[layer_name].quantizer = Quantizer() gptq_for_this_block[layer_name].quantizer.configure( weight_config_this_layer["wbits"], @@ -549,12 +563,32 @@ def tmp(_, inp, out): # ) weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) logger.info(f"Quantizing layer {layer_name}") - scale, zp = gptq_for_this_block[layer_name].fasterquant( + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import load_value + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + W = load_value(self.model, full_layer_name + '.weight', model_path) + else: + W = sub_layers[layer_name].weight.data.clone() + scale, zp, Q = gptq_for_this_block[layer_name].fasterquant( + W, blocksize=weight_config_this_layer["block_size"], percdamp=weight_config_this_layer["percdamp"], groupsize=weight_config_this_layer["group_size"], act_order=weight_config_this_layer["act_order"], ) + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, load_value, set_module_tensor_to_device, clean_module_weight + sub_layer = sub_layers[layer_name] + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + for n, p in 
sub_layer.named_parameters(): + param_name = full_layer_name + '.' + n + value = load_value(self.model, param_name, model_path) + set_module_tensor_to_device(self.model, param_name, self.device, value) + sub_layer.weight.data = Q + torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f'/{full_layer_name}.pt') + clean_module_weight(sub_layer) + else: + sub_layers[layer_name].weight.data = Q gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale} if not weight_config_this_layer["sym"]: gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp @@ -572,7 +606,10 @@ def tmp(_, inp, out): cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) self.out[j] = transformer_block(self.inp[j], *cache_positional_batch, **cache_batch)[0] self.cache["i"] = idx - self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() + if self.layer_wise: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block + else: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() del gptq_for_this_block torch.cuda.empty_cache() # iteratively replace the input with output, thus layerwise quantization can continue. @@ -595,10 +632,10 @@ class GPTQ: GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers (https://arxiv.org/abs/2210.17323) """ - def __init__(self, layer): + def __init__(self, layer, W, device='cpu'): self.layer = layer - self.device = self.layer.weight.device - W = layer.weight.data.clone() + self.device = device + # W = layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d): W = W.flatten(1) if isinstance(self.layer, transformers.Conv1D): @@ -639,8 +676,9 @@ def add_batch(self, inp, out): # self.H += 2 / self.nsamples * inp.matmul(inp.t()) self.H += inp.matmul(inp.t()) # H = X*X, which should be a sysm matrix - def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False): - W = self.layer.weight.data.clone() + def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False): + # W = self.layer.weight.data.clone() + weight_shape, weight_dtype = W.shape, W.data.dtype if isinstance(self.layer, nn.Conv2d): W = W.flatten(1) if isinstance(self.layer, transformers.Conv1D): @@ -718,7 +756,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals # logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}") # logger.info(f"{torch.sum(Losses)}") - if self.device != torch.device("cpu"): + if str(self.device) != "cpu": torch.cuda.synchronize() logger.info(f"time {(time.time() - tick)}") logger.info(f"error {torch.sum(Losses).item()}") @@ -729,7 +767,8 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals if isinstance(self.layer, transformers.Conv1D): Q = Q.t() - self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + # self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + Q = Q.reshape(weight_shape).to(weight_dtype) if DEBUG: logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}") @@ -738,7 +777,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals zero.append(self.quantizer.zero) scale = torch.cat(scale, dim=1) zero = torch.cat(zero, dim=1) - return scale, zero + return scale, zero, Q def free(self): if DEBUG: diff --git 
a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py index 1746ad82140..052ed502c17 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py @@ -40,7 +40,7 @@ update_module, ) -TMP_DIR = os.path.join(default_workspace, "layer_wise_quant_tmp_dir") +TMP_DIR = os.path.join(default_workspace, "lwq_tmpdir") def mk_tmp_dir(): diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py index e932c40480a..1bd1d856738 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py @@ -223,7 +223,10 @@ def load_module(model, module_name, path, device="cpu"): set_module_tensor_to_device(model, param_name, device, value) -def register_weight_hooks(model, path, device="cpu", clean_weight=True): +def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_path=None): + if saved_path: + os.makedirs(saved_path, exist_ok=True) + def forward_pre_hook(name): def hook(module, input): state_dict = None @@ -241,8 +244,10 @@ def hook(module, input): def forward_hook(name): def hook(module, input, output): + if saved_path: + file_path = os.path.join(saved_path, f"{name}.pt") + torch.save(module.state_dict(), file_path) clean_module_weight(module) - return hook handle = {} diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index b9376a72c7f..7a9ef3da628 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -470,15 +470,17 @@ def rtn_quantize( def gptq_quantize( - model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, pad_max_length=2048, device=None + model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, pad_max_length=2048, device=None, layer_wise=False, model_path=None ): """Run weight-only quantization with.""" # TODO: unify weight_config keys, add docstring, and support default config assert isinstance(model, torch.nn.Module), "only support torch module" + if layer_wise: + assert model_path is not None, 'model_path should not be None when use layer_wise mode' from .gptq import GPTQuantizer - gptq_quantizer = GPTQuantizer(model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device) - fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization() + gptq_quantizer = GPTQuantizer(model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device, layer_wise=layer_wise) + fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization(model_path=model_path) logger.info("GPTQ quantizing done.") return fp32_modified_model, gptq_config diff --git a/test/algorithm/test_lwq_weight_only.py b/test/algorithm/test_lwq_weight_only.py index c8840123382..c6e4eefa975 100644 --- a/test/algorithm/test_lwq_weight_only.py +++ b/test/algorithm/test_lwq_weight_only.py @@ -13,9 +13,10 @@ class TestLayerWise(unittest.TestCase): - def test_layer_wise(self): - model_name_or_path = "facebook/opt-125m" - fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) + @classmethod + def setUpClass(self): + self.model_name_or_path = "facebook/opt-125m" + self.fp32_model = load_shell(self.model_name_or_path, 
AutoModelForCausalLM, torchscript=True) class TestDataset(Dataset): def __init__(self, size=5, shape=128): @@ -29,8 +30,9 @@ def __len__(self): return self.len eval_dataset = TestDataset() - eval_dataloader = DataLoader(eval_dataset, batch_size=8) + self.eval_dataloader = DataLoader(eval_dataset, batch_size=8) + def test_rtn_lwq(self): conf = PostTrainingQuantConfig( approach="weight_only", recipes={ @@ -43,14 +45,46 @@ def __len__(self): ) q_model = quantization.fit( - fp32_model, + self.fp32_model, conf, - calib_dataloader=eval_dataloader, + calib_dataloader=self.eval_dataloader, eval_func=lambda x: 0.1, ) ouput_dir = "./saved_model" q_model.save(ouput_dir) - load_model = load(ouput_dir, fp32_model, weight_only=True) + load_model = load(ouput_dir, self.fp32_model, weight_only=True) + self.assertNotEqual(load_model.lm_head.weight.device.type, "meta") + shutil.rmtree(ouput_dir) + + def test_gptq_lwq(self): + conf = PostTrainingQuantConfig( + approach="weight_only", + op_type_dict={ + ".*": { # re.match + "weight": { + "bits": 4, # 1-8 bits + "group_size": 32, + "scheme": "sym", + "algorithm": "GPTQ", + }, + }, + }, + recipes={ + "gptq_args": {"actorder": True, "mse": True, "perchannel": False}, + "layer_wise_quant": True, + "layer_wise_quant_args": { + "model_path": "facebook/opt-125m", + } + }, + ) + q_model = quantization.fit( + self.fp32_model, + conf, + calib_dataloader=self.eval_dataloader + ) + ouput_dir = "./saved_model" + q_model.save(ouput_dir) + load_model = load(ouput_dir, self.fp32_model, weight_only=True) self.assertNotEqual(load_model.lm_head.weight.device.type, "meta") shutil.rmtree(ouput_dir) From 897cee3e212be8ac819579112329600dd65580ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Oct 2023 08:00:15 +0000 Subject: [PATCH 02/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/adaptor/pytorch.py | 19 ++++++++++--- neural_compressor/adaptor/torch_utils/gptq.py | 27 ++++++++++++------- .../torch_utils/layer_wise_quant/utils.py | 1 + .../adaptor/torch_utils/weight_only.py | 16 ++++++++--- test/algorithm/test_lwq_weight_only.py | 8 ++---- 5 files changed, 48 insertions(+), 23 deletions(-) diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 26537f04332..25faf1881af 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -4535,7 +4535,7 @@ def rtn_quantize(self, model, tune_cfg): # for layer_wise quant mode recipe_cfgs = tune_cfg.get("recipe_cfgs", None) if recipe_cfgs.get("layer_wise_quant", False): - from .torch_utils.layer_wise_quant.utils import _get_path, load_module, LWQ_WORKSPACE + from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, load_module os.makedirs(LWQ_WORKSPACE, exist_ok=True) model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) @@ -4616,12 +4616,15 @@ def gptq_quantize(self, model, tune_cfg, dataloader): layer_wise = False if recipe_cfgs.get("layer_wise_quant", False): layer_wise = True - from .torch_utils.layer_wise_quant.utils import _get_path, LWQ_WORKSPACE, register_weight_hooks + from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, register_weight_hooks + os.makedirs(LWQ_WORKSPACE, exist_ok=True) model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) assert model_path, "model_path should specify in layer_wise_quant_args." 
model_path = _get_path(model_path) - lwq_handles = register_weight_hooks(model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE) + lwq_handles = register_weight_hooks( + model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE + ) weight_config = {} for key, config in tune_cfg["op"].items(): @@ -4647,7 +4650,15 @@ def gptq_quantize(self, model, tune_cfg, dataloader): ) # tune_cfg => weight_config model, quantization_perm = gptq_quantize( - model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, self.device, layer_wise, model_path + model, + weight_config, + dataloader, + nsamples, + use_max_length, + pad_max_length, + self.device, + layer_wise, + model_path, ) return model, quantization_perm diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py index 3ed003150eb..2727af772a9 100644 --- a/neural_compressor/adaptor/torch_utils/gptq.py +++ b/neural_compressor/adaptor/torch_utils/gptq.py @@ -175,7 +175,7 @@ def __init__( use_max_length=True, pad_max_length=2048, device=None, - layer_wise=False + layer_wise=False, ): """ Args: @@ -229,7 +229,6 @@ def __init__( self.nsamples = nsamples self.prepare_dataloader() - def prepare_dataloader(self): if self.use_max_length: # (Recommend) only take sequence whose length exceeds self.pad_max_length, @@ -499,7 +498,7 @@ def execute_quantization(self, means=None, stds=None, model_path=None): tblock_length = len(self.gptq_related_blocks["transformers"]) for block_idx in range(tblock_length): logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..") - transformer_block = self.gptq_related_blocks["transformers"][block_idx]#.to(self.device) + transformer_block = self.gptq_related_blocks["transformers"][block_idx] # .to(self.device) # Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized. 
sub_layers = find_layers(transformer_block) sub_layers_to_quant = {} @@ -524,10 +523,11 @@ def execute_quantization(self, means=None, stds=None, model_path=None): weight_config_this_layer = self.get_layer_config(full_layer_name) if self.layer_wise: from ..torch_utils.layer_wise_quant.utils import load_value - W = load_value(self.model, full_layer_name + '.weight', model_path) + + W = load_value(self.model, full_layer_name + ".weight", model_path) else: W = sub_layers[layer_name].weight.data.clone() - + gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device) # gptq_for_this_block[layer_name].quantizer = Quantizer() gptq_for_this_block[layer_name].quantizer.configure( @@ -565,8 +565,9 @@ def tmp(_, inp, out): logger.info(f"Quantizing layer {layer_name}") if self.layer_wise: from ..torch_utils.layer_wise_quant.utils import load_value + full_layer_name = self.get_full_layer_name(layer_name, block_idx) - W = load_value(self.model, full_layer_name + '.weight', model_path) + W = load_value(self.model, full_layer_name + ".weight", model_path) else: W = sub_layers[layer_name].weight.data.clone() scale, zp, Q = gptq_for_this_block[layer_name].fasterquant( @@ -577,15 +578,21 @@ def tmp(_, inp, out): act_order=weight_config_this_layer["act_order"], ) if self.layer_wise: - from ..torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, load_value, set_module_tensor_to_device, clean_module_weight + from ..torch_utils.layer_wise_quant.utils import ( + LWQ_WORKSPACE, + clean_module_weight, + load_value, + set_module_tensor_to_device, + ) + sub_layer = sub_layers[layer_name] full_layer_name = self.get_full_layer_name(layer_name, block_idx) for n, p in sub_layer.named_parameters(): - param_name = full_layer_name + '.' + n + param_name = full_layer_name + "." 
+ n value = load_value(self.model, param_name, model_path) set_module_tensor_to_device(self.model, param_name, self.device, value) sub_layer.weight.data = Q - torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f'/{full_layer_name}.pt') + torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt") clean_module_weight(sub_layer) else: sub_layers[layer_name].weight.data = Q @@ -632,7 +639,7 @@ class GPTQ: GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers (https://arxiv.org/abs/2210.17323) """ - def __init__(self, layer, W, device='cpu'): + def __init__(self, layer, W, device="cpu"): self.layer = layer self.device = device # W = layer.weight.data.clone() diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py index 1bd1d856738..819e69fbc70 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py @@ -248,6 +248,7 @@ def hook(module, input, output): file_path = os.path.join(saved_path, f"{name}.pt") torch.save(module.state_dict(), file_path) clean_module_weight(module) + return hook handle = {} diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index 7a9ef3da628..7ba86eaa344 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -470,16 +470,26 @@ def rtn_quantize( def gptq_quantize( - model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, pad_max_length=2048, device=None, layer_wise=False, model_path=None + model, + weight_config={}, + dataloader=None, + nsamples=128, + use_max_length=True, + pad_max_length=2048, + device=None, + layer_wise=False, + model_path=None, ): """Run weight-only quantization with.""" # TODO: unify weight_config keys, add docstring, and support default config assert isinstance(model, torch.nn.Module), "only support torch module" if layer_wise: - assert model_path is not None, 'model_path should not be None when use layer_wise mode' + assert model_path is not None, "model_path should not be None when use layer_wise mode" from .gptq import GPTQuantizer - gptq_quantizer = GPTQuantizer(model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device, layer_wise=layer_wise) + gptq_quantizer = GPTQuantizer( + model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device, layer_wise=layer_wise + ) fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization(model_path=model_path) logger.info("GPTQ quantizing done.") return fp32_modified_model, gptq_config diff --git a/test/algorithm/test_lwq_weight_only.py b/test/algorithm/test_lwq_weight_only.py index c6e4eefa975..ee034aee2b9 100644 --- a/test/algorithm/test_lwq_weight_only.py +++ b/test/algorithm/test_lwq_weight_only.py @@ -74,14 +74,10 @@ def test_gptq_lwq(self): "layer_wise_quant": True, "layer_wise_quant_args": { "model_path": "facebook/opt-125m", - } + }, }, ) - q_model = quantization.fit( - self.fp32_model, - conf, - calib_dataloader=self.eval_dataloader - ) + q_model = quantization.fit(self.fp32_model, conf, calib_dataloader=self.eval_dataloader) ouput_dir = "./saved_model" q_model.save(ouput_dir) load_model = load(ouput_dir, self.fp32_model, weight_only=True) From adfc8a2fc19d4d4a385b3d32dcdb4877e3ee4b65 Mon Sep 17 00:00:00 2001 From: "Guo, Heng" Date: Mon, 16 Oct 2023 13:21:43 +0800 
Subject: [PATCH 03/15] fix cuda error Signed-off-by: Guo, Heng --- neural_compressor/adaptor/torch_utils/gptq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py index 2727af772a9..c077c19eb27 100644 --- a/neural_compressor/adaptor/torch_utils/gptq.py +++ b/neural_compressor/adaptor/torch_utils/gptq.py @@ -763,7 +763,7 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F # logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}") # logger.info(f"{torch.sum(Losses)}") - if str(self.device) != "cpu": + if str(self.device).startswith('cuda'): torch.cuda.synchronize() logger.info(f"time {(time.time() - tick)}") logger.info(f"error {torch.sum(Losses).item()}") From dae9d50c71f8df4c353cbf9e8fdb162669d7d3a4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 05:22:51 +0000 Subject: [PATCH 04/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/adaptor/torch_utils/gptq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py index c077c19eb27..939c35620ea 100644 --- a/neural_compressor/adaptor/torch_utils/gptq.py +++ b/neural_compressor/adaptor/torch_utils/gptq.py @@ -763,7 +763,7 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F # logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}") # logger.info(f"{torch.sum(Losses)}") - if str(self.device).startswith('cuda'): + if str(self.device).startswith("cuda"): torch.cuda.synchronize() logger.info(f"time {(time.time() - tick)}") logger.info(f"error {torch.sum(Losses).item()}") From faeb54d3f9ef3a26247b1d6b71adc40c78f8f531 Mon Sep 17 00:00:00 2001 From: "Guo, Heng" Date: Tue, 17 Oct 2023 13:47:28 +0800 Subject: [PATCH 05/15] new api Signed-off-by: Guo, Heng --- neural_compressor/adaptor/pytorch.py | 18 ++++++++----- .../torch_utils/layer_wise_quant/utils.py | 5 ++-- neural_compressor/model/torch_model.py | 3 ++- test/algorithm/test_lwq_weight_only.py | 26 +++++++++---------- 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 25faf1881af..c9bc8c572d5 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -3496,13 +3496,15 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): ): from .torch_utils.layer_wise_quant import LayerWiseQuant - model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + model_path = model._model.path smooth_quant = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant", False) alpha = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant_alpha", 0.5) + # device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu") assert ( model_path is not None - ), "the layer_wise_quant_args should have args model_path to load the weight of model." - device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu") + ), "The model_path should not be None." 
+ device = self.device lw_quant = LayerWiseQuant( q_model._model, model_path, @@ -4538,8 +4540,9 @@ def rtn_quantize(self, model, tune_cfg): from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, load_module os.makedirs(LWQ_WORKSPACE, exist_ok=True) - model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) - assert model_path, "model_path should specify in layer_wise_quant_args." + # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + model_path = model.path + assert model_path, "model_path should not be None." model_path = _get_path(model_path) for key, config in tune_cfg["op"].items(): @@ -4619,8 +4622,9 @@ def gptq_quantize(self, model, tune_cfg, dataloader): from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, register_weight_hooks os.makedirs(LWQ_WORKSPACE, exist_ok=True) - model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) - assert model_path, "model_path should specify in layer_wise_quant_args." + # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + model_path = model.path + assert model_path, "model_path should not be None." model_path = _get_path(model_path) lwq_handles = register_weight_hooks( model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py index 819e69fbc70..306e4d27ecf 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py @@ -25,7 +25,7 @@ torch = LazyImport("torch") from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device -from transformers import AutoConfig +from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.auto.auto_factory import _BaseAutoModelClass from ....config import options @@ -107,7 +107,7 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None): return file_path -def load_shell(pretrained_model_name_or_path, cls, **kwargs): +def load_shell(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs): """Load a empty model.""" is_local = os.path.isdir(pretrained_model_name_or_path) if is_local: # pragma: no cover @@ -124,6 +124,7 @@ def load_shell(pretrained_model_name_or_path, cls, **kwargs): model = cls(config) model.tie_weights() model.eval() + model.path = pretrained_model_name_or_path return model diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py index 95f273eeab5..eeada402f35 100644 --- a/neural_compressor/model/torch_model.py +++ b/neural_compressor/model/torch_model.py @@ -356,7 +356,8 @@ def save(self, root=None): if os.path.exists(os.path.join(LWQ_WORKSPACE, f"{name}.pt")): state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{name}.pt")) model_path = _get_path( - self.q_config["recipe_cfgs"]["layer_wise_quant_args"].get("model_path") + # self.q_config["recipe_cfgs"]["layer_wise_quant_args"].get("model_path") + self._model.path ) for n, p in module.named_parameters(): param_name = name + "." 
+ n diff --git a/test/algorithm/test_lwq_weight_only.py b/test/algorithm/test_lwq_weight_only.py index ee034aee2b9..7af3aefca94 100644 --- a/test/algorithm/test_lwq_weight_only.py +++ b/test/algorithm/test_lwq_weight_only.py @@ -1,6 +1,7 @@ import shutil import sys import unittest +from copy import deepcopy sys.path.insert(0, "./") import torch @@ -16,7 +17,7 @@ class TestLayerWise(unittest.TestCase): @classmethod def setUpClass(self): self.model_name_or_path = "facebook/opt-125m" - self.fp32_model = load_shell(self.model_name_or_path, AutoModelForCausalLM, torchscript=True) + self.fp32_model = load_shell(self.model_name_or_path, torchscript=True) class TestDataset(Dataset): def __init__(self, size=5, shape=128): @@ -32,29 +33,32 @@ def __len__(self): eval_dataset = TestDataset() self.eval_dataloader = DataLoader(eval_dataset, batch_size=8) + @classmethod + def tearDownClass(cls): + shutil.rmtree("./saved_model", ignore_errors=True) + def test_rtn_lwq(self): conf = PostTrainingQuantConfig( approach="weight_only", recipes={ "layer_wise_quant": True, - "layer_wise_quant_args": { - "model_path": "facebook/opt-125m", - }, + # "layer_wise_quant_args": { + # "model_path": "facebook/opt-125m", + # }, "rtn_args": {"enable_full_range": True}, }, ) q_model = quantization.fit( - self.fp32_model, + deepcopy(self.fp32_model), conf, calib_dataloader=self.eval_dataloader, eval_func=lambda x: 0.1, ) ouput_dir = "./saved_model" q_model.save(ouput_dir) - load_model = load(ouput_dir, self.fp32_model, weight_only=True) + load_model = load(ouput_dir, deepcopy(self.empty_model), weight_only=True) self.assertNotEqual(load_model.lm_head.weight.device.type, "meta") - shutil.rmtree(ouput_dir) def test_gptq_lwq(self): conf = PostTrainingQuantConfig( @@ -72,17 +76,13 @@ def test_gptq_lwq(self): recipes={ "gptq_args": {"actorder": True, "mse": True, "perchannel": False}, "layer_wise_quant": True, - "layer_wise_quant_args": { - "model_path": "facebook/opt-125m", - }, }, ) - q_model = quantization.fit(self.fp32_model, conf, calib_dataloader=self.eval_dataloader) + q_model = quantization.fit(deepcopy(self.fp32_model), conf, calib_dataloader=self.eval_dataloader) ouput_dir = "./saved_model" q_model.save(ouput_dir) - load_model = load(ouput_dir, self.fp32_model, weight_only=True) + load_model = load(ouput_dir, deepcopy(self.fp32_model), weight_only=True, layer_wise=True) self.assertNotEqual(load_model.lm_head.weight.device.type, "meta") - shutil.rmtree(ouput_dir) if __name__ == "__main__": From 4b1d49f606749c896aa7fca5e80af4d49c9eede4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Oct 2023 05:48:37 +0000 Subject: [PATCH 06/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/adaptor/pytorch.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index c9bc8c572d5..4952035e484 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -3501,9 +3501,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): smooth_quant = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant", False) alpha = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant_alpha", 0.5) # device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu") - assert ( - model_path is not None - ), "The model_path should not be None." 
+ assert model_path is not None, "The model_path should not be None." device = self.device lw_quant = LayerWiseQuant( q_model._model, From 6e84288c6650f2306303d696996cb03fe4f7e825 Mon Sep 17 00:00:00 2001 From: "Guo, Heng" Date: Tue, 17 Oct 2023 13:50:45 +0800 Subject: [PATCH 07/15] update doc Signed-off-by: Guo, Heng --- docs/source/quantization_weight_only.md | 6 ++---- test/algorithm/test_lwq_weight_only.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md index b26c5194aa3..c42367671da 100644 --- a/docs/source/quantization_weight_only.md +++ b/docs/source/quantization_weight_only.md @@ -151,14 +151,11 @@ Large language models (LLMs) have shown exceptional performance across various t from neural_compressor import PostTrainingQuantConfig, quantization from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell -fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) +fp32_model = load_shell(model_name_or_path, torchscript=True) conf = PostTrainingQuantConfig( approach="weight_only", recipes={ "layer_wise_quant": True, - "layer_wise_quant_args": { - "model_path": "facebook/opt-125m", - }, "rtn_args": {"enable_full_range": True}, }, ) @@ -171,6 +168,7 @@ q_model = quantization.fit( ) ouput_dir = "./saved_model" q_model.save(ouput_dir) +q_model = load(ouput_dir, deepcopy(self.empty_model), weight_only=True, layer_wise=True) ``` ## Reference diff --git a/test/algorithm/test_lwq_weight_only.py b/test/algorithm/test_lwq_weight_only.py index 7af3aefca94..1b27eb5ff20 100644 --- a/test/algorithm/test_lwq_weight_only.py +++ b/test/algorithm/test_lwq_weight_only.py @@ -57,7 +57,7 @@ def test_rtn_lwq(self): ) ouput_dir = "./saved_model" q_model.save(ouput_dir) - load_model = load(ouput_dir, deepcopy(self.empty_model), weight_only=True) + load_model = load(ouput_dir, deepcopy(self.fp32_model), weight_only=True) self.assertNotEqual(load_model.lm_head.weight.device.type, "meta") def test_gptq_lwq(self): From e23dca5ff02d9080325225098ba1176039a0edd7 Mon Sep 17 00:00:00 2001 From: "Guo, Heng" Date: Tue, 17 Oct 2023 16:29:08 +0800 Subject: [PATCH 08/15] fix Signed-off-by: Guo, Heng --- docs/source/quantization_weight_only.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md index c42367671da..1c26d22e97e 100644 --- a/docs/source/quantization_weight_only.md +++ b/docs/source/quantization_weight_only.md @@ -168,7 +168,7 @@ q_model = quantization.fit( ) ouput_dir = "./saved_model" q_model.save(ouput_dir) -q_model = load(ouput_dir, deepcopy(self.empty_model), weight_only=True, layer_wise=True) +q_model = load(ouput_dir, fp32_model, weight_only=True, layer_wise=True) ``` ## Reference From 0f19ce26db59c4314947c87a494c07cac17554dd Mon Sep 17 00:00:00 2001 From: YIYANGCAI Date: Wed, 18 Oct 2023 17:43:32 +0800 Subject: [PATCH 09/15] Modify device setting method. Test on different models. 
Signed-off-by: YIYANGCAI --- .../quantization/ptq_weight_only/run-gptq-llm.py | 9 +++++---- neural_compressor/adaptor/torch_utils/gptq.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py index 149a030af09..18d6ad17997 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py @@ -219,12 +219,12 @@ def skip(*args, **kwargs): tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True) else: - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True) model = model.eval() - calib_dataset = load_dataset(args.dataset, split="train") # default - # calib_dataset = datasets.load_from_disk('/your/local/pile-10k/') # use this if trouble with connecting to HF + # calib_dataset = load_dataset(args.dataset, split="train") # default + calib_dataset = datasets.load_from_disk('/data4/cyy/gptq_inc/pile-10k/') # use this if trouble with connecting to HF calib_dataset = calib_dataset.shuffle(seed=args.seed) calib_evaluator = Evaluator(calib_dataset, tokenizer, args.calib_size, is_calib=True) calib_dataloader = DataLoader( @@ -294,7 +294,8 @@ def skip(*args, **kwargs): dataloader=calib_dataloader, nsamples = args.nsamples, use_max_length = args.use_max_length, - pad_max_length = args.pad_max_length + pad_max_length = args.pad_max_length, + device = DEV, ) results = lm_evaluate( diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py index 9d438d1581d..133862a2eeb 100644 --- a/neural_compressor/adaptor/torch_utils/gptq.py +++ b/neural_compressor/adaptor/torch_utils/gptq.py @@ -217,6 +217,8 @@ def __init__( # device self.device = device + if str(self.model.device).startswith("cuda"): + self.device = self.model.device self.is_ready = False self.layer_wise = layer_wise @@ -456,7 +458,8 @@ def forward(layer, *args, **kwargs): # Step3: run forward to obtain calibration datasets logger.info("Collecting calibration inputs...") for batch in tqdm(self.dataloader): - batch = move_input_to_device(batch, self.device) + if not self.layer_wise: + batch = move_input_to_device(batch, self.device) try: if isinstance(batch, tuple) or isinstance(batch, list): self.model(batch[0]) @@ -519,7 +522,11 @@ def execute_quantization(self, means=None, stds=None, model_path=None): tblock_length = len(self.gptq_related_blocks["transformers"]) for block_idx in range(tblock_length): logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..") - transformer_block = self.gptq_related_blocks["transformers"][block_idx] # .to(self.device) + if not self.layer_wise: + # if we do not apply layer-wise feature, we still place the entire block on the GPU + transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + else: + transformer_block = self.gptq_related_blocks["transformers"][block_idx] # 
.to(self.device) # Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized. sub_layers = find_layers(transformer_block) sub_layers_to_quant = {} @@ -569,7 +576,6 @@ def tmp(_, inp, out): for layer_name in sub_layers: handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name))) idx = self.cache_key_arguments.pop("i") - # import pdb;pdb.set_trace() for j in range(len(self.dataloader)): cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) From 1ddd1861ed5456d0101984b6d168b1bc31e83477 Mon Sep 17 00:00:00 2001 From: YIYANGCAI Date: Wed, 18 Oct 2023 17:46:22 +0800 Subject: [PATCH 10/15] fix hardcodes. Signed-off-by: YIYANGCAI --- .../quantization/ptq_weight_only/run-gptq-llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py index 18d6ad17997..82443f73492 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py @@ -223,8 +223,8 @@ def skip(*args, **kwargs): model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True) model = model.eval() - # calib_dataset = load_dataset(args.dataset, split="train") # default - calib_dataset = datasets.load_from_disk('/data4/cyy/gptq_inc/pile-10k/') # use this if trouble with connecting to HF + calib_dataset = load_dataset(args.dataset, split="train") # default + # calib_dataset = datasets.load_from_disk('/your/local/pile-10k/') # use this if trouble with connecting to HF calib_dataset = calib_dataset.shuffle(seed=args.seed) calib_evaluator = Evaluator(calib_dataset, tokenizer, args.calib_size, is_calib=True) calib_dataloader = DataLoader( From 1e4eecf2c44be9e8059f9bd38b2d2f57ae060899 Mon Sep 17 00:00:00 2001 From: "Guo, Heng" Date: Thu, 19 Oct 2023 09:02:05 +0800 Subject: [PATCH 11/15] change api name Signed-off-by: Guo, Heng --- docs/source/quantization_weight_only.md | 4 ++-- .../adaptor/torch_utils/layer_wise_quant/__init__.py | 2 +- .../adaptor/torch_utils/layer_wise_quant/quantize.py | 2 +- .../adaptor/torch_utils/layer_wise_quant/utils.py | 2 +- test/algorithm/test_layer_wise_quant.py | 6 +++--- test/algorithm/test_lwq_weight_only.py | 4 ++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md index 1c26d22e97e..e9206074809 100644 --- a/docs/source/quantization_weight_only.md +++ b/docs/source/quantization_weight_only.md @@ -149,9 +149,9 @@ Large language models (LLMs) have shown exceptional performance across various t ### Example ```python from neural_compressor import PostTrainingQuantConfig, quantization -from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell +from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model -fp32_model = load_shell(model_name_or_path, torchscript=True) +fp32_model = load_empty_model(model_name_or_path, torchscript=True) conf = PostTrainingQuantConfig( approach="weight_only", recipes={ diff --git 
a/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py index 6f01b1288bf..f347da31bb9 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py @@ -15,5 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. """Torch layer-wise quantization module.""" -from .utils import load_shell +from .utils import load_empty_model from .quantize import LayerWiseQuant diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py index 052ed502c17..9e3f8789dad 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py @@ -92,7 +92,7 @@ def __init__( alpha=0.5, ): """Init LayerWiseQuant.""" - # self.q_model = load_shell(pretrained_model_name_or_path, cls) + # self.q_model = load_empty_model(pretrained_model_name_or_path, cls) self.q_model = q_model self.fp32_model = deepcopy(self.q_model) self.path = _get_path(pretrained_model_name_or_path) diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py index 306e4d27ecf..8bd3d32d320 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py @@ -107,7 +107,7 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None): return file_path -def load_shell(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs): +def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs): """Load a empty model.""" is_local = os.path.isdir(pretrained_model_name_or_path) if is_local: # pragma: no cover diff --git a/test/algorithm/test_layer_wise_quant.py b/test/algorithm/test_layer_wise_quant.py index e25036bb021..72a70fa76ae 100644 --- a/test/algorithm/test_layer_wise_quant.py +++ b/test/algorithm/test_layer_wise_quant.py @@ -8,14 +8,14 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from neural_compressor import PostTrainingQuantConfig, quantization -from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell +from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model from neural_compressor.utils.pytorch import load class TestLayerWise(unittest.TestCase): def test_layer_wise(self): model_name_or_path = "facebook/opt-125m" - fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) + fp32_model = load_empty_model(model_name_or_path, AutoModelForCausalLM, torchscript=True) class TestDataset(Dataset): def __init__(self, size=5, shape=128): @@ -65,7 +65,7 @@ def test_util(self): ) model_name_or_path = "facebook/opt-125m" - model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) + model = load_empty_model(model_name_or_path, AutoModelForCausalLM, torchscript=True) children = get_children(model) named_children = get_named_children(model) self.assertEqual(children, [v for k, v in named_children]) diff --git a/test/algorithm/test_lwq_weight_only.py b/test/algorithm/test_lwq_weight_only.py index 1b27eb5ff20..6bfea5e8d48 100644 --- a/test/algorithm/test_lwq_weight_only.py +++ b/test/algorithm/test_lwq_weight_only.py @@ -9,7 +9,7 @@ from transformers 
import AutoModelForCausalLM, AutoTokenizer from neural_compressor import PostTrainingQuantConfig, quantization -from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell +from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model from neural_compressor.utils.pytorch import load @@ -17,7 +17,7 @@ class TestLayerWise(unittest.TestCase): @classmethod def setUpClass(self): self.model_name_or_path = "facebook/opt-125m" - self.fp32_model = load_shell(self.model_name_or_path, torchscript=True) + self.fp32_model = load_empty_model(self.model_name_or_path, torchscript=True) class TestDataset(Dataset): def __init__(self, size=5, shape=128): From 23c02fc0c2aa991581378194aea6ba34d197bae2 Mon Sep 17 00:00:00 2001 From: Guo Date: Thu, 19 Oct 2023 15:44:37 -0700 Subject: [PATCH 12/15] using hf functiont to set Q and clean memo Signed-off-by: Guo --- neural_compressor/adaptor/torch_utils/gptq.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py index 133862a2eeb..19764aa6f22 100644 --- a/neural_compressor/adaptor/torch_utils/gptq.py +++ b/neural_compressor/adaptor/torch_utils/gptq.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import math import random import re @@ -616,11 +617,16 @@ def tmp(_, inp, out): full_layer_name = self.get_full_layer_name(layer_name, block_idx) for n, p in sub_layer.named_parameters(): param_name = full_layer_name + "." + n - value = load_value(self.model, param_name, model_path) - set_module_tensor_to_device(self.model, param_name, self.device, value) - sub_layer.weight.data = Q + if n == 'weight': + set_module_tensor_to_device(self.model, param_name, self.device, Q) + else: + value = load_value(self.model, param_name, model_path) + set_module_tensor_to_device(self.model, param_name, self.device, value) + # sub_layer.weight.data = Q torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt") clean_module_weight(sub_layer) + del Q + gc.collect() else: sub_layers[layer_name].weight.data = Q gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale} From f79c4e3208c62b5fee6d2d8b3b2565208c8eabde Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 05:47:22 +0000 Subject: [PATCH 13/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/adaptor/torch_utils/gptq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py index 19764aa6f22..daaa08df8e3 100644 --- a/neural_compressor/adaptor/torch_utils/gptq.py +++ b/neural_compressor/adaptor/torch_utils/gptq.py @@ -617,7 +617,7 @@ def tmp(_, inp, out): full_layer_name = self.get_full_layer_name(layer_name, block_idx) for n, p in sub_layer.named_parameters(): param_name = full_layer_name + "." 
+ n - if n == 'weight': + if n == "weight": set_module_tensor_to_device(self.model, param_name, self.device, Q) else: value = load_value(self.model, param_name, model_path) From 5e34781215b21f6123e987fe3d4e12b7f448a3e7 Mon Sep 17 00:00:00 2001 From: "Guo, Heng" Date: Tue, 24 Oct 2023 12:21:02 +0800 Subject: [PATCH 14/15] clean ut Signed-off-by: Guo, Heng --- test/algorithm/test_layer_wise_quant.py | 4 ++-- test/algorithm/test_lwq_weight_only.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/algorithm/test_layer_wise_quant.py b/test/algorithm/test_layer_wise_quant.py index 72a70fa76ae..2eef1e89bd9 100644 --- a/test/algorithm/test_layer_wise_quant.py +++ b/test/algorithm/test_layer_wise_quant.py @@ -15,7 +15,7 @@ class TestLayerWise(unittest.TestCase): def test_layer_wise(self): model_name_or_path = "facebook/opt-125m" - fp32_model = load_empty_model(model_name_or_path, AutoModelForCausalLM, torchscript=True) + fp32_model = load_empty_model(model_name_or_path, torchscript=True) class TestDataset(Dataset): def __init__(self, size=5, shape=128): @@ -65,7 +65,7 @@ def test_util(self): ) model_name_or_path = "facebook/opt-125m" - model = load_empty_model(model_name_or_path, AutoModelForCausalLM, torchscript=True) + model = load_empty_model(model_name_or_path, torchscript=True) children = get_children(model) named_children = get_named_children(model) self.assertEqual(children, [v for k, v in named_children]) diff --git a/test/algorithm/test_lwq_weight_only.py b/test/algorithm/test_lwq_weight_only.py index 6bfea5e8d48..3735f0b2581 100644 --- a/test/algorithm/test_lwq_weight_only.py +++ b/test/algorithm/test_lwq_weight_only.py @@ -6,7 +6,6 @@ sys.path.insert(0, "./") import torch from torch.utils.data import DataLoader, Dataset -from transformers import AutoModelForCausalLM, AutoTokenizer from neural_compressor import PostTrainingQuantConfig, quantization from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model From 74388b31e62c7e0de87be3ac0e58363bc7d8ba0e Mon Sep 17 00:00:00 2001 From: "Guo, Heng" Date: Tue, 24 Oct 2023 13:28:52 +0800 Subject: [PATCH 15/15] update support matrix Signed-off-by: Guo, Heng --- docs/source/quantization_weight_only.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md index e9206074809..69275c587d2 100644 --- a/docs/source/quantization_weight_only.md +++ b/docs/source/quantization_weight_only.md @@ -143,7 +143,7 @@ Large language models (LLMs) have shown exceptional performance across various t |:--------------:|:----------:| | RTN | ✔ | | AWQ | ✕ | -| GPTQ | ✕ | +| GPTQ | ✔ | | TEQ | ✕ | ### Example
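
For reference, a minimal end-to-end sketch of the GPTQ layer-wise flow enabled by this series, assembled from the `test_gptq_lwq` unit test and the documentation example above. The `CalibDataset` class and its random token ids are stand-ins for a real calibration set (the test's dataset body is not shown in the diff), and `facebook/opt-125m` is the model used in the tests; everything else mirrors the APIs touched by these patches (`load_empty_model`, `layer_wise_quant` recipe, `load(..., layer_wise=True)`).

```python
import torch
from torch.utils.data import DataLoader, Dataset

from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model
from neural_compressor.utils.pytorch import load


class CalibDataset(Dataset):
    """Dummy calibration data: random token ids as a placeholder for a real dataset."""

    def __init__(self, size=5, seq_len=128):
        self.input_ids = torch.randint(0, 30000, (size, seq_len), dtype=torch.long)

    def __getitem__(self, index):
        return self.input_ids[index]

    def __len__(self):
        return len(self.input_ids)


# Load an empty (meta-device) shell; real weights are streamed layer by layer during quantization.
fp32_model = load_empty_model("facebook/opt-125m", torchscript=True)

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # re.match
            "weight": {
                "bits": 4,
                "group_size": 32,
                "scheme": "sym",
                "algorithm": "GPTQ",
            },
        },
    },
    recipes={
        "gptq_args": {"actorder": True, "mse": True, "perchannel": False},
        "layer_wise_quant": True,  # model_path is now taken from fp32_model.path
    },
)

q_model = quantization.fit(
    fp32_model,
    conf,
    calib_dataloader=DataLoader(CalibDataset(), batch_size=8),
)
q_model.save("./saved_model")
# Reload with layer_wise=True so the per-layer weights saved to the LWQ workspace
# are restored and no parameter is left on the meta device.
q_model = load("./saved_model", fp32_model, weight_only=True, layer_wise=True)
```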