Merged

22 commits
b70312a
support lwq for gptq
n1ck-guo Oct 13, 2023
897cee3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2023
adfc8a2
fix cuda error
n1ck-guo Oct 16, 2023
dae9d50
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 16, 2023
faeb54d
new api
n1ck-guo Oct 17, 2023
7f0282e
Merge branch 'master' into hengguo/lwq_weight_only
n1ck-guo Oct 17, 2023
4b1d49f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 17, 2023
6e84288
update doc
n1ck-guo Oct 17, 2023
cddf10b
Merge branch 'hengguo/lwq_weight_only' of https://github.com/intel/ne…
n1ck-guo Oct 17, 2023
e23dca5
fix
n1ck-guo Oct 17, 2023
f9b1c63
Merge branch 'master' into hengguo/lwq_weight_only
YIYANGCAI Oct 18, 2023
0f19ce2
Modify device setting method. Test on different models.
YIYANGCAI Oct 18, 2023
1ddd186
fix hardcodes.
YIYANGCAI Oct 18, 2023
1e4eecf
change api name
n1ck-guo Oct 19, 2023
dba4cf0
Merge branch 'hengguo/lwq_weight_only' of https://github.com/intel/ne…
n1ck-guo Oct 19, 2023
23c02fc
using hf functiont to set Q and clean memo
n1ck-guo Oct 19, 2023
f79c4e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2023
5e34781
clean ut
n1ck-guo Oct 24, 2023
e38e9bc
Merge branch 'hengguo/lwq_weight_only' of https://github.com/intel/ne…
n1ck-guo Oct 24, 2023
74388b3
update support matrix
n1ck-guo Oct 24, 2023
744d685
Merge branch 'master' into hengguo/lwq_weight_only
chensuyue Oct 26, 2023
c1c0a92
Merge branch 'master' into hengguo/lwq_weight_only
n1ck-guo Oct 30, 2023
10 changes: 4 additions & 6 deletions docs/source/quantization_weight_only.md
@@ -173,22 +173,19 @@ Large language models (LLMs) have shown exceptional performance across various t
|:--------------:|:----------:|
| RTN | ✔ |
| AWQ | ✕ |
| GPTQ | ✕ |
| GPTQ | ✔ |
| TEQ | ✕ |

### Example
```python
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model

fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True)
fp32_model = load_empty_model(model_name_or_path, torchscript=True)
conf = PostTrainingQuantConfig(
approach="weight_only",
recipes={
"layer_wise_quant": True,
"layer_wise_quant_args": {
"model_path": "facebook/opt-125m",
},
"rtn_args": {"enable_full_range": True},
},
)
@@ -201,6 +198,7 @@ q_model = quantization.fit(
)
output_dir = "./saved_model"
q_model.save(output_dir)
q_model = load(output_dir, fp32_model, weight_only=True, layer_wise=True)
```

## Reference
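The support matrix change above is the user-visible part of this PR: GPTQ can now be combined with layer-wise quantization. The documented example still uses RTN (`rtn_args`), so below is a minimal sketch of what a GPTQ variant could look like. The `op_type_dict` weight options follow the existing weight-only API, while the exact `gptq_args` recipe keys and the calibration dataloader are assumptions for illustration, not text taken from this PR.

```python
# Hedged sketch: layer-wise quantization combined with the GPTQ algorithm.
# The "gptq_args" keys and `calib_dataloader` are assumptions, not from this PR.
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model

fp32_model = load_empty_model("facebook/opt-125m", torchscript=True)
conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {"weight": {"bits": 4, "group_size": 128, "scheme": "asym", "algorithm": "GPTQ"}},
    },
    recipes={
        "layer_wise_quant": True,
        "gptq_args": {"use_max_length": True, "pad_max_length": 2048},  # assumed keys
    },
)
q_model = quantization.fit(
    fp32_model,
    conf,
    calib_dataloader=calib_dataloader,  # user-provided calibration dataloader
)
```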
@@ -219,7 +219,7 @@ def skip(*args, **kwargs):
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True)
model = model.eval()

@@ -294,7 +294,8 @@ def skip(*args, **kwargs):
dataloader=calib_dataloader,
nsamples = args.nsamples,
use_max_length = args.use_max_length,
pad_max_length = args.pad_max_length
pad_max_length = args.pad_max_length,
device = DEV,
)

results = lm_evaluate(
51 changes: 37 additions & 14 deletions neural_compressor/adaptor/pytorch.py
@@ -3502,13 +3502,13 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
):
from .torch_utils.layer_wise_quant import LayerWiseQuant

model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
# model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
model_path = model._model.path
smooth_quant = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant", False)
alpha = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant_alpha", 0.5)
assert (
model_path is not None
), "the layer_wise_quant_args should have args model_path to load the weight of model."
device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu")
# device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu")
assert model_path is not None, "The model_path should not be None."
device = self.device
lw_quant = LayerWiseQuant(
q_model._model,
model_path,
@@ -4541,14 +4541,12 @@ def rtn_quantize(self, model, tune_cfg):
# for layer_wise quant mode
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if recipe_cfgs.get("layer_wise_quant", False):
from neural_compressor.config import options

from .torch_utils.layer_wise_quant.utils import _get_path, load_module
from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, load_module

lwq_workspace = os.path.join(options.workspace, "lwq_tmpdir")
os.makedirs(lwq_workspace, exist_ok=True)
model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
assert model_path, "model_path should specify in layer_wise_quant_args."
os.makedirs(LWQ_WORKSPACE, exist_ok=True)
# model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
model_path = model.path
assert model_path, "model_path should not be None."
model_path = _get_path(model_path)

for key, config in tune_cfg["op"].items():
@@ -4584,7 +4582,7 @@ def rtn_quantize(self, model, tune_cfg):
# save and clean weight
from .torch_utils.layer_wise_quant.utils import clean_module_weight

torch.save(m.state_dict(), os.path.join(lwq_workspace, f"{op_name}.pt"))
torch.save(m.state_dict(), os.path.join(LWQ_WORKSPACE, f"{op_name}.pt"))
clean_module_weight(m)
set_module(model, op_name, m)
if recipe_cfgs.get("layer_wise_quant", False):
@@ -4619,6 +4617,23 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
...
}
"""
# for layer_wise quant mode
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
model_path = None
layer_wise = False
if recipe_cfgs.get("layer_wise_quant", False):
layer_wise = True
from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, register_weight_hooks

os.makedirs(LWQ_WORKSPACE, exist_ok=True)
# model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
model_path = model.path
assert model_path, "model_path should not be None."
model_path = _get_path(model_path)
lwq_handles = register_weight_hooks(
model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE
)

weight_config = {}
for key, config in tune_cfg["op"].items():
op_name, op_type = key
@@ -4643,7 +4658,15 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
)
# tune_cfg => weight_config
model, quantization_perm = gptq_quantize(
model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, self.device
model,
weight_config,
dataloader,
nsamples,
use_max_length,
pad_max_length,
self.device,
layer_wise,
model_path,
)
return model, quantization_perm

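In the `gptq_quantize` path above, `register_weight_hooks` is what lets an empty (`load_empty_model`) model run calibration: each module's weights are pulled from the checkpoint just before that module executes and released again afterwards. The snippet below is only a conceptual sketch of that hook pattern with stand-in `load_fn`/`clean_fn` helpers; it is not the actual implementation in `torch_utils.layer_wise_quant`.

```python
import torch.nn as nn

def attach_lazy_weight_hooks(module: nn.Module, load_fn, clean_fn):
    """Illustrative only. load_fn(module) materializes the real weights (e.g. from a
    checkpoint shard on disk); clean_fn(module) swaps them back to empty tensors."""

    def pre_hook(mod, inputs):
        load_fn(mod)  # weights become real just before this module's forward pass

    def post_hook(mod, inputs, output):
        clean_fn(mod)  # free them again so only one layer is resident at a time

    handles = (
        module.register_forward_pre_hook(pre_hook),
        module.register_forward_hook(post_hook),
    )
    return handles  # keep the handles so the hooks can be removed after quantization
```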
104 changes: 81 additions & 23 deletions neural_compressor/adaptor/torch_utils/gptq.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import random
import re
@@ -175,6 +176,7 @@ def __init__(
use_max_length=True,
pad_max_length=2048,
device=None,
layer_wise=False,
):
"""
Args:
@@ -215,9 +217,13 @@ def __init__(
self.check_layer_config()

# device
self.device = model.device
self.device = device
if str(self.model.device).startswith("cuda"):
self.device = self.model.device
self.is_ready = False

self.layer_wise = layer_wise

# dataloader
self.use_max_length = use_max_length
self.pad_max_length = pad_max_length
@@ -438,11 +444,13 @@ def forward(layer, *args, **kwargs):
raise ValueError

# Step1: fetch the embeddings and other layers before the transformer stack.
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
embedding_layer = embedding_layer.to(self.device)
if not self.layer_wise:
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
embedding_layer = embedding_layer.to(self.device)

# Step2: modify the first transformer block's forward function to obtain inputs for calibration
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
if not self.layer_wise:
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
forward_cache = self.gptq_related_blocks["transformers"][0].forward
self.gptq_related_blocks["transformers"][0].forward = partial(
forward, self.gptq_related_blocks["transformers"][0]
@@ -451,7 +459,8 @@ def forward(layer, *args, **kwargs):
# Step3: run forward to obtain calibration datasets
logger.info("Collecting calibration inputs...")
for batch in tqdm(self.dataloader):
batch = move_input_to_device(batch, self.device)
if not self.layer_wise:
batch = move_input_to_device(batch, self.device)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
self.model(batch[0])
@@ -473,9 +482,10 @@

# Step 4: restore original forward function, relocate layers back to cpu.
self.gptq_related_blocks["transformers"][0].forward = forward_cache
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
embedding_layer.to(self.device)
if not self.layer_wise:
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
embedding_layer.to(self.device)
torch.cuda.empty_cache()
# end
logger.info("GPTQ quantization prepared.")
@@ -501,7 +511,7 @@ def update_blockwise_hidden_states(self, outs):
self.cache_positional_arguments[0] = outs[:]

@torch.no_grad()
def execute_quantization(self, means=None, stds=None):
def execute_quantization(self, means=None, stds=None, model_path=None):
"""Run quantization."""
# Step1: prepare quantization (calibration datasets)

@@ -513,7 +523,11 @@ def execute_quantization(self, means=None, stds=None):
tblock_length = len(self.gptq_related_blocks["transformers"])
for block_idx in range(tblock_length):
logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..")
transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device)
if not self.layer_wise:
# if we do not apply layer-wise feature, we still place the entire block on the GPU
transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device)
else:
transformer_block = self.gptq_related_blocks["transformers"][block_idx] # .to(self.device)
# Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized.
sub_layers = find_layers(transformer_block)
sub_layers_to_quant = {}
@@ -534,8 +548,16 @@
# weight_config_this_layer = self.weight_config.get(
# self.get_full_layer_name(layer_name, block_idx), None
# )
weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx))
gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name])
full_layer_name = self.get_full_layer_name(layer_name, block_idx)
weight_config_this_layer = self.get_layer_config(full_layer_name)
if self.layer_wise:
from ..torch_utils.layer_wise_quant.utils import load_value

W = load_value(self.model, full_layer_name + ".weight", model_path)
else:
W = sub_layers[layer_name].weight.data.clone()

gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device)
# gptq_for_this_block[layer_name].quantizer = Quantizer()
gptq_for_this_block[layer_name].quantizer.configure(
weight_config_this_layer["wbits"],
@@ -555,7 +577,6 @@ def tmp(_, inp, out):
for layer_name in sub_layers:
handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name)))
idx = self.cache_key_arguments.pop("i")
# import pdb;pdb.set_trace()
for j in range(len(self.dataloader)):
cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
@@ -570,12 +591,44 @@ def tmp(_, inp, out):
# )
weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx))
logger.info(f"Quantizing layer {layer_name}")
scale, zp = gptq_for_this_block[layer_name].fasterquant(
if self.layer_wise:
from ..torch_utils.layer_wise_quant.utils import load_value

full_layer_name = self.get_full_layer_name(layer_name, block_idx)
W = load_value(self.model, full_layer_name + ".weight", model_path)
else:
W = sub_layers[layer_name].weight.data.clone()
scale, zp, Q = gptq_for_this_block[layer_name].fasterquant(
W,
blocksize=weight_config_this_layer["block_size"],
percdamp=weight_config_this_layer["percdamp"],
groupsize=weight_config_this_layer["group_size"],
act_order=weight_config_this_layer["act_order"],
)
if self.layer_wise:
from ..torch_utils.layer_wise_quant.utils import (
LWQ_WORKSPACE,
clean_module_weight,
load_value,
set_module_tensor_to_device,
)

sub_layer = sub_layers[layer_name]
full_layer_name = self.get_full_layer_name(layer_name, block_idx)
for n, p in sub_layer.named_parameters():
param_name = full_layer_name + "." + n
if n == "weight":
set_module_tensor_to_device(self.model, param_name, self.device, Q)
else:
value = load_value(self.model, param_name, model_path)
set_module_tensor_to_device(self.model, param_name, self.device, value)
# sub_layer.weight.data = Q
torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
clean_module_weight(sub_layer)
del Q
gc.collect()
else:
sub_layers[layer_name].weight.data = Q
gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale}
if not weight_config_this_layer["sym"]:
gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp
@@ -594,7 +647,10 @@ def tmp(_, inp, out):
out = transformer_block(*cache_positional_batch, **cache_keyword_batch)[0]
outs.append(out)
self.cache_key_arguments["i"] = idx
self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
if self.layer_wise:
self.gptq_related_blocks["transformers"][block_idx] = transformer_block
else:
self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
del gptq_for_this_block
torch.cuda.empty_cache()
# iteratively replace the input with output, thus layerwise quantization can continue.
@@ -617,10 +673,10 @@ class GPTQ:
GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers (https://arxiv.org/abs/2210.17323)
"""

def __init__(self, layer):
def __init__(self, layer, W, device="cpu"):
self.layer = layer
self.device = self.layer.weight.device
W = layer.weight.data.clone()
self.device = device
# W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
@@ -661,8 +717,9 @@ def add_batch(self, inp, out):
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t()) # H = X*X^T, which should be a symmetric matrix

def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
W = self.layer.weight.data.clone()
def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
# W = self.layer.weight.data.clone()
weight_shape, weight_dtype = W.shape, W.data.dtype
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
@@ -740,7 +797,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals
# logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}")
# logger.info(f"{torch.sum(Losses)}")

if self.device != torch.device("cpu"):
if str(self.device).startswith("cuda"):
torch.cuda.synchronize()
logger.info(f"time {(time.time() - tick)}")
logger.info(f"error {torch.sum(Losses).item()}")
@@ -751,7 +808,8 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals

if isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
# self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
Q = Q.reshape(weight_shape).to(weight_dtype)
if DEBUG:
logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}")

@@ -760,7 +818,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals
zero.append(self.quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
return scale, zero
return scale, zero, Q

def free(self):
if DEBUG:
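Taken together, the `gptq.py` changes make each transformer block self-contained: in layer-wise mode the weight is read from disk, quantized, written back into the otherwise-empty module, dumped to `LWQ_WORKSPACE`, and cleaned again. The loop below is a self-contained restatement of that flow with stand-in helpers (`load_weight_from_disk`, `quantize_weight`, `clean_module`) in place of `load_value`, `GPTQ.fasterquant`, and `clean_module_weight`; it is illustrative, not the code that ships in this PR.

```python
import os
import torch
import torch.nn as nn

workspace = "./lwq_tmp"  # stands in for LWQ_WORKSPACE
os.makedirs(workspace, exist_ok=True)

def load_weight_from_disk(name):   # stand-in for load_value(model, name, model_path)
    return torch.randn(8, 8)

def quantize_weight(W):            # stand-in for fasterquant, which returns (scale, zp, Q)
    return torch.round(W)

def clean_module(module):          # stand-in for clean_module_weight
    module.weight = nn.Parameter(torch.empty(0))

block = nn.ModuleDict({"fc1": nn.Linear(8, 8), "fc2": nn.Linear(8, 8)})
for name, sub_layer in block.items():
    W = load_weight_from_disk(name)                        # 1. fetch this layer's weight
    Q = quantize_weight(W)                                  # 2. quantize it
    sub_layer.weight.data = Q                               # 3. place Q into the module
    torch.save(sub_layer.state_dict(), os.path.join(workspace, f"{name}.pt"))
    clean_module(sub_layer)                                 # 4. release it to bound memory
```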
@@ -15,5 +15,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Torch layer-wise quantization module."""
from .utils import load_shell
from .utils import load_empty_model
from .quantize import LayerWiseQuant
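The rename above is a small API break for anyone importing the helper directly: `load_shell` becomes `load_empty_model`, and per the docs change earlier in this PR the model class no longer needs to be passed in. A minimal before/after sketch (the model id is just an example):

```python
# Before this PR:
# from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
# fp32_model = load_shell("facebook/opt-125m", AutoModelForCausalLM, torchscript=True)

# After this PR:
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model

fp32_model = load_empty_model("facebook/opt-125m", torchscript=True)
```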
@@ -40,7 +40,7 @@
update_module,
)

TMP_DIR = os.path.join(default_workspace, "layer_wise_quant_tmp_dir")
TMP_DIR = os.path.join(default_workspace, "lwq_tmpdir")


def mk_tmp_dir():
@@ -92,7 +92,7 @@ def __init__(
alpha=0.5,
):
"""Init LayerWiseQuant."""
# self.q_model = load_shell(pretrained_model_name_or_path, cls)
# self.q_model = load_empty_model(pretrained_model_name_or_path, cls)
self.q_model = q_model
self.fp32_model = deepcopy(self.q_model)
self.path = _get_path(pretrained_model_name_or_path)