20 changes: 15 additions & 5 deletions docs/source/quantization_weight_only.md
@@ -96,6 +96,14 @@ To support low memory inference, Neural Compressor implemented WeightOnlyLinear,
| compression_dtype | torch.int32 | Data type for compressed dtype, select from [torch.int8\|16\|32\|64] |
| compression_dim | 1 | 0 means output channel while 1 means input channel |
| scale_dtype | torch.float32 | Data type for scale and bias |
| use_hf_format | False | Whether to use the popular format found on the HuggingFace Hub |

**Note:** The HuggingFace format is handled specially; the main differences from the default layout, illustrated in the sketch below, are:

> 1: Compression dimension: weight = 1, zero = 0, and both tensors are stored transposed.
> 2: Zero point: `zero_point -= 1` is applied before compression, and a zero point is always required, even for symmetric quantization.
> 3: Group index: every channel in a group records the same group number, instead of recording the channel order.
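
For illustration, here is a minimal sketch (plain PyTorch, not a Neural Compressor API call) of how these three differences show up in the stored tensors; the tensor names follow the `scales`/`qzeros`/`g_idx` naming used in the wrapper code below, and the shapes are made up for the example:

```python
import torch

out_features, in_features, group_size, bits = 4, 8, 4, 4
n_groups = in_features // group_size

# Default layout: scales stored as (out_features, n_groups).
scales_default = torch.rand(out_features, n_groups)

# 1. The HuggingFace layout keeps the same data transposed, e.g. scales -> (n_groups, out_features).
scales_hf = scales_default.T.contiguous()

# 2. zero_point -= 1 before compression, and a zero point is stored even for
#    symmetric quantization (here filled with 2**(bits - 1) for unsigned storage).
zp_sym = torch.full((out_features, n_groups), 2 ** (bits - 1), dtype=torch.int32)
qzeros_unpacked = (zp_sym - 1).T.contiguous()

# 3. g_idx records the group number of every input channel rather than a
#    reordered channel index.
g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32)

print(scales_hf.shape, qzeros_unpacked.shape, g_idx.tolist())
```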


### **User Code Example**
```python
@@ -119,12 +127,14 @@ conf = PostTrainingQuantConfig(
)
q_model = quantization.fit(model, conf, eval_func=eval_func)
q_model.save("saved_results")
compressed_model = q_model.export_compressed_model(
    compression_dtype=torch.int32,
    compression_dim=1,
    scale_dtype=torch.float16,
)
compressed_model = q_model.export_compressed_model()
torch.save(compressed_model.state_dict(), "compressed_model.pt")
# or
model = Model()
compressed_model = export_compressed_model(
    model,
    saved_dir="saved_results",
)
```

The `saved_results` folder contains two files, `best_model.pt` and `qconfig.json`; the generated `q_model` is a fake-quantized model.
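
To reload the compressed `state_dict` saved above, one hedged option (a sketch that assumes the same quantization and export steps are re-run so the packed `WeightOnlyLinear` buffers have matching shapes) is:

```python
import torch

# Rebuild a model containing packed WeightOnlyLinear modules, then restore the saved buffers.
compressed_model = q_model.export_compressed_model()
state_dict = torch.load("compressed_model.pt", map_location="cpu")
compressed_model.load_state_dict(state_dict)
compressed_model.eval()
```
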
236 changes: 136 additions & 100 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -215,10 +215,12 @@ def __init__(
scale_dtype=torch.float32,
compression_dtype=torch.int32,
compression_dim=1,
gptq_perm=False,
g_idx=False,
device="cpu",
use_hf_format=False,
):
super().__init__()
self.use_hf_format = use_hf_format
self.dtype = dtype
if "int" not in self.dtype: # for nf4, fp4
from neural_compressor.adaptor.torch_utils.weight_only import FLOAT_MAPPING, INT_MAPPING
@@ -249,69 +251,105 @@ def __init__(
assert compression_dim in [0, 1], (
"Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel."
)
self.register_buffer(
"scale",
torch.zeros(
(out_features, math.ceil(in_features / self.groupsize)),
dtype=self.float_type,
).to(device),
)
if compression_dim == 1:
if self.use_hf_format:
self.register_buffer(
"packed_weight",
"scales",
torch.zeros(
(out_features, math.ceil(in_features / self.n_pack)),
(math.ceil(in_features / self.groupsize), out_features),
dtype=self.float_type,
).to(device),
)
self.scales = self.scales.T
self.register_buffer(
"qweight",
torch.zeros(
(math.ceil(in_features / self.n_pack), out_features),
dtype=self.compressed_dtype,
).to(device),
)
if zp:
self.register_buffer(
"packed_zp",
torch.zeros(
(self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
dtype=self.compressed_dtype,
).to(device),
)
else:
self.qweight = self.qweight.T
self.register_buffer(
"packed_weight",
"qzeros",
torch.zeros(
(math.ceil(out_features / self.n_pack), in_features),
(math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)),
dtype=self.compressed_dtype,
).to(device),
)
if zp:
self.qzeros = self.qzeros.T
else:
self.register_buffer(
"scales",
torch.zeros(
(out_features, math.ceil(in_features / self.groupsize)),
dtype=self.float_type,
).to(device),
)
if compression_dim == 1:
self.register_buffer(
"packed_zp",
"qweight",
torch.zeros(
(math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)),
(out_features, math.ceil(in_features / self.n_pack)),
dtype=self.compressed_dtype,
).to(device),
)
if zp:
self.register_buffer(
"qzeros",
torch.zeros(
(self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
dtype=self.compressed_dtype,
).to(device),
)
else:
self.register_buffer(
"qweight",
torch.zeros(
(math.ceil(out_features / self.n_pack), in_features),
dtype=self.compressed_dtype,
).to(device),
)
if zp:
self.register_buffer(
"qzeros",
torch.zeros(
(math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)),
dtype=self.compressed_dtype,
).to(device),
)
if g_idx:
self.register_buffer("g_idx", torch.zeros(in_features, dtype=torch.int32).to(device))
else:
self.g_idx = None
if bias:
self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
else:
self.bias = None
if gptq_perm:
self.register_buffer("gptq_perm", torch.zeros(in_features, dtype=torch.int32).to(device))
else:
self.gptq_perm = None

def pack(self, int_weight, scale, zp, bias, gptq_perm=None):
def pack(self, int_weight, scale, zp, bias, g_idx=None):
int_weight = int_weight.to(self.device)
if self.use_hf_format and zp is None:
# to avoid overflow
int_weight = int_weight.type(torch.int32)
shift_bias = 2 ** (self.bits - 1)
int_weight += shift_bias
zp = torch.zeros_like(scale, dtype=torch.uint8) + shift_bias
if bias is not None:
assert hasattr(self, "bias"), "bias is not set when initializing."
self.bias = bias.type(self.float_type).to(self.device)
if gptq_perm is not None:
assert hasattr(self, "gptq_perm"), "gptq_perm is not set when initializing."
self.gptq_perm = gptq_perm.type(torch.int32).to(self.device)
assert scale.shape == self.scale.shape, "Scale shape is mismatched."
self.scale = scale.type(self.float_type).to(self.device)
if self.compression_dim == 0:
if g_idx is not None:
assert hasattr(self, "g_idx"), "g_idx is not set when initializing."
self.g_idx = g_idx.type(torch.int32).to(self.device)
if self.use_hf_format:
invperm = torch.argsort(self.g_idx)
self.g_idx = invperm // self.groupsize
self.g_idx = self.g_idx.type(torch.int32).to(self.device)
assert scale.shape == self.scales.shape, "Scale shape is mismatched."
self.scales = scale.type(self.float_type).to(self.device)
if not self.use_hf_format and self.compression_dim == 0:
int_weight = int_weight.T
self.packed_weight = self.packed_weight.T
self.qweight = self.qweight.T
origin_shape = int_weight.shape
target_shape = self.packed_weight.shape
target_shape = self.qweight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(self.device)

@@ -323,121 +361,112 @@ def pack(self, int_weight, scale, zp, bias, gptq_perm=None):
for e in range(tmp.shape[1]):
tmp[:, e] &= mask
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.packed_weight[:, j] |= tmp[:, e]
if self.compression_dim == 0:
self.packed_weight = self.packed_weight.T
self.qweight[:, j] |= tmp[:, e]
if not self.use_hf_format and self.compression_dim == 0:
self.qweight = self.qweight.T

if zp is not None:
zp = zp.to(self.device)
if self.compression_dim == 0:
if self.use_hf_format:
zp -= 1
if self.use_hf_format or self.compression_dim == 0:
zp = zp.T
self.packed_zp = self.packed_zp.T
assert hasattr(self, "packed_zp"), "zp is not set when initializing."
target_shape = self.packed_zp.shape
self.qzeros = self.qzeros.T
assert hasattr(self, "qzeros"), "zp is not set when initializing."
target_shape = self.qzeros.shape
for j in range(target_shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = zp[:, start:end].type(self.compressed_dtype)
for e in range(tmp.shape[1]):
tmp[:, e] &= mask
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.packed_zp[:, j] |= tmp[:, e]
if self.compression_dim == 0:
self.packed_zp = self.packed_zp.T
self.qzeros[:, j] |= tmp[:, e]
if self.use_hf_format or self.compression_dim == 0:
self.qzeros = self.qzeros.T
if self.use_hf_format:
self.scales = self.scales.T
self.qweight = self.qweight.T
self.g_idx = self.g_idx
self.qzeros = self.qzeros.T

def recover(self):
logger.debug(f"Recovering {self} weight")
device = self.scale.device
if self.use_hf_format:
# Prevent broken id links of the transposed buffers (self.scales, self.qweight, self.qzeros)
self.scales = self.scales.T
self.qweight = self.qweight.T
self.g_idx = self.g_idx
self.qzeros = self.qzeros.T
device = self.scales.device
fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
if self.g_idx is None:
# used for recovering fp32_weight
self.g_idx = torch.tensor([i // self.groupsize for i in range(self.in_features)], dtype=torch.int32)
mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(device)
if hasattr(self, "packed_zp"):
if hasattr(self, "qzeros"):
weight_dtype = torch.uint8
else:
weight_dtype = torch.int8
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
packed_weight = self.packed_weight
if self.compression_dim == 0:
qweight = self.qweight
if not self.use_hf_format and self.compression_dim == 0:
weight = weight.T
packed_weight = packed_weight.T
qweight = qweight.T
origin_shape = weight.shape
target_shape = packed_weight.shape
target_shape = qweight.shape
for j in range(target_shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
if index >= origin_shape[1]:
continue
tmp = packed_weight[:, j]
tmp = qweight[:, j]
tmp = tmp << (self.compress_bits - self.bits * (e + 1))
tmp = tmp >> self.compress_bits - self.bits
if weight_dtype == torch.uint8:
tmp &= mask # remove sign bit
weight[:, index] = tmp.type(weight_dtype)
if self.compression_dim == 0:
if not self.use_hf_format and self.compression_dim == 0:
weight = weight.T
if "int" not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
new_weight += torch.where(weight == k, v, 0)
weight = new_weight
# unpack zero_point
if hasattr(self, "packed_zp"):
if hasattr(self, "qzeros"):
zp_dtype = self.compressed_dtype # to avoid overflow when weight-zp
zp = torch.zeros(self.scale.shape, dtype=zp_dtype).to(device)
packed_zp = self.packed_zp
if self.compression_dim == 0:
zp = torch.zeros(self.scales.shape, dtype=zp_dtype).to(device)
qzeros = self.qzeros
if self.use_hf_format or self.compression_dim == 0:
zp = zp.T
packed_zp = packed_zp.T
qzeros = qzeros.T
origin_shape = zp.shape
target_shape = packed_zp.shape
target_shape = qzeros.shape
for j in range(target_shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
if index >= origin_shape[1]:
continue
tmp = packed_zp[:, j]
tmp = qzeros[:, j]
tmp = tmp << (self.compress_bits - self.bits * (e + 1))
tmp = tmp >> self.compress_bits - self.bits
tmp &= mask
zp[:, index] = tmp.type(zp_dtype)
if self.compression_dim == 0:
if self.use_hf_format or self.compression_dim == 0:
zp = zp.T
if self.use_hf_format:
# zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
zp += 1
zp = torch.where(zp > (2**self.bits - 1), 0, zp)
# recover fp32 weight with int_weight, scale, and zero_point
left_element = self.in_features % self.groupsize
if left_element != 0:
split_index = self.in_features // self.groupsize * self.groupsize
weight1 = weight[:, :-split_index].reshape(-1, self.groupsize)
scale1 = self.scale[:, :-1].reshape(-1, 1)
zp1 = zp[:, :-1].reshape(-1, 1)
weight1 = ((weight1 - zp1) * scale1).reshape(self.out_features, -1)
weight2 = weight[:, -split_index:]
scale2 = self.scale[:, -1:]
zp2 = zp[:, -1].reshape(-1, 1)
weight2 = (weight2 - zp2) * scale2
fp32_weight = torch.cat((weight1, weight2), dim=1)
else:
weight = weight.reshape(-1, self.groupsize)
scale = self.scale.reshape(-1, 1)
zp = zp.reshape(-1, 1)
fp32_weight = ((weight - zp) * scale).reshape(self.out_features, -1)
for idx in range(self.in_features):
fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * self.scales[:, self.g_idx[idx]]
else:
# recover fp32 weight with int_weight, scale
left_element = self.in_features % self.groupsize
if left_element != 0:
split_index = self.in_features // self.groupsize * self.groupsize
weight1 = weight[:, :split_index].reshape(-1, self.groupsize)
scale1 = self.scale[:, :-1].reshape(-1, 1)
weight1 = (weight1 * scale1).reshape(self.out_features, -1)
weight2 = weight[:, split_index:]
scale2 = self.scale[:, -1:]
weight2 = weight2 * scale2
fp32_weight = torch.cat((weight1, weight2), dim=1)
else:
weight = weight.reshape(-1, self.groupsize)
scale = self.scale.reshape(-1, 1)
fp32_weight = (weight * scale).reshape(self.out_features, -1)
if self.gptq_perm is not None:
invperm = torch.argsort(self.gptq_perm)
fp32_weight = fp32_weight[:, invperm]
for idx in range(self.in_features):
fp32_weight[:, idx] = weight[:, idx] * self.scales[:, self.g_idx[idx]]
return fp32_weight

def forward(self, input):
@@ -453,9 +482,16 @@ def forward(self, input):
return F.linear(input, weight, self.bias)

def extra_repr(self) -> str:
return "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
tmp_str = "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
self.in_features,
self.out_features,
self.bits,
self.groupsize,
self.bias is not None,
)
if self.use_hf_format:
tmp_str += ", use_hf_format=True"
return tmp_str


class FakeAffineTensorQuantFunction(Function):
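
Both `pack()` and `recover()` above rely on the same shift-and-or packing of `n_pack = compress_bits // bits` fields per compressed element; only the buffer names change for the HuggingFace format. A standalone sketch of that arithmetic (illustrative values only, unsigned 4-bit fields packed into an int32 container, mirroring the shifts in the wrapper):

```python
import torch

bits, compress_bits = 4, 32
n_pack = compress_bits // bits  # 8 fields per int32
mask = torch.tensor(2**bits - 1, dtype=torch.int32)

values = torch.tensor([3, 8, 15, 0, 7, 12, 1, 9], dtype=torch.int32)  # unsigned 4-bit values
packed = torch.tensor(0, dtype=torch.int32)
for e in range(n_pack):
    packed |= (values[e] & mask) << (bits * e)  # mask, then shift each field into place

recovered = torch.zeros(n_pack, dtype=torch.int32)
for e in range(n_pack):
    tmp = packed << (compress_bits - bits * (e + 1))  # move field e to the top bits
    tmp = tmp >> (compress_bits - bits)               # arithmetic shift back down
    recovered[e] = tmp & mask                         # drop any sign-extension bits
assert torch.equal(values, recovered)
```
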
2 changes: 2 additions & 0 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -396,6 +396,7 @@ def rtn_quantize(
compression_dim = kwargs.get("compression_dim", 1)
scale_dtype = kwargs.get("scale_dtype", torch.float32)
device = kwargs.get("device", "cpu")
use_hf_format = kwargs.get("use_hf_format", False)
for name, m in model.named_modules():
if m.__class__.__name__ not in supported_layers:
continue
@@ -448,6 +449,7 @@ def rtn_quantize(
compression_dim=compression_dim,
scale_dtype=scale_dtype,
device=device,
use_hf_format=use_hf_format,
)
new_module.pack(int_weight, scale, zp, m.bias)
if name == "":