
Commit 19ab16c

add mse_range for weight_only RTN algo (#1157)
* add mse_range for weight_only RTN algo
* minor fix
* fix mse calculation
* minor fix
* fix for UT coverage
* fix code

Signed-off-by: Cheng, Zixuan <[email protected]>
1 parent 66f7c10 commit 19ab16c

File tree

4 files changed: +68 -5 lines changed

docs/source/quantization_weight_only.md

Lines changed: 2 additions & 1 deletion
@@ -44,13 +44,14 @@ There are many excellent works for weight only quantization to improve its accuracy
 | rtn_args | default value | comments |
 |:----------:|:-------------:|:-------------------------------------------------------------------:|
 | sym_full_range | False | Whether use -2**(bits-1) in sym scheme, for example, |
+| mse_range | False | Whether search for the best clip range from range [0.805, 1.0, 0.005] |
 | return_int | False | Whether return compressed model with int data type |

 **AWQ arguments**:
 | awq_args | default value | comments |
 |:----------:|:-------------:|:-------------------------------------------------------------------:|
 | auto_scale | True | Whether search for best scales based on activation distribution |
-| mse_range | True | Whether search for the best clip range from range [0.89, 1.0, 0.01] |
+| mse_range | True | Whether search for the best clip range from range [0.91, 1.0, 0.01] |
 | folding | False | False will allow insert mul before linear when the scale cannot be absorbed by last layer, else won't |

 **GPTQ arguments**:
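The new `mse_range` recipe is surfaced through `rtn_args` of `PostTrainingQuantConfig`, exactly as the unit test added by this commit does. A minimal sketch of enabling it (the hypothetical `Model` stands in for any torch module with linear layers; the imports assume the standard neural_compressor 2.x entry points):

import torch
from neural_compressor import PostTrainingQuantConfig, quantization

model = Model()  # placeholder: any torch.nn.Module containing nn.Linear layers
conf = PostTrainingQuantConfig(
    approach='weight_only',
    # mse_range=True makes RTN grid-search the clip ratio over [0.805, 1.0]
    recipes={'rtn_args': {'mse_range': True}},
)
q_model = quantization.fit(model, conf)  # RTN needs no calibration dataloader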

neural_compressor/adaptor/pytorch.py

Lines changed: 5 additions & 2 deletions
@@ -4338,8 +4338,10 @@ def rtn_quantize(self, model, tune_cfg):
         logger.info("quantizing with the round-to-nearest algorithm")
         if 'rtn_args' in self.recipes:
             sym_full_range = self.recipes['rtn_args'].get('sym_full_range', False)
-        else:
+            mse_range = self.recipes['rtn_args'].get('mse_range', False)
+        else:  # pragma: no cover
             sym_full_range=False
+            mse_range=False
         from .torch_utils.weight_only import rtn_quantize
         from .torch_utils.util import fetch_module, set_module
         for key, config in tune_cfg['op'].items():
@@ -4356,7 +4358,8 @@ def rtn_quantize(self, model, tune_cfg):
             m = fetch_module(model, op_name)
             m = rtn_quantize(m, num_bits, group_size, scheme,
                              return_int=False,
-                             sym_full_range=sym_full_range)
+                             sym_full_range=sym_full_range,
+                             mse_range=mse_range)
             set_module(model, op_name, m)
         return model
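For reference, the per-op call the adaptor now makes can also be issued directly against the torch_utils helper. A hedged sketch (the module path follows this repo's layout; the num_bits/group_size/scheme values match the helper's defaults rather than anything tuned):

import torch
from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize

m = torch.nn.Linear(1024, 1024)  # stand-in for a module fetched via fetch_module
m = rtn_quantize(m, num_bits=4, group_size=32, scheme="asym",
                 return_int=False,
                 sym_full_range=False,
                 mse_range=True)  # new flag: search the clip ratio before quantizing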

neural_compressor/adaptor/torch_utils/weight_only.py

Lines changed: 46 additions & 1 deletion
@@ -208,9 +208,50 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0,
         return weight


+def search_clip(m, num_bits, group_size, scheme, sym_full_range):
+    """Search the best clip range for each linear in the current block.
+
+    Args:
+        m (torch.nn.Module): torch module.
+        num_bits (int, optional): num bits.
+        group_size (int, optional): how many elements share one scale/zp.
+        scheme (str, optional): sym or asym.
+        sym_full_range (bool, optional): whether the sym range uses -2**(bits-1).
+
+    Returns:
+        best_clip_ratio (float): best percentile for clipping.
+    """
+    org_weight = m.weight.data
+    logger.info("Searching the best clip range with RTN algorithm")
+    best_error = float('inf')
+    best_clip_ratio = None
+    n_grid = 200
+    max_shrink = 0.2
+    history = []
+    for i_s in range(int(max_shrink * n_grid)):
+        ratio = 1 - i_s / n_grid  # sweeps 1.0 down to 0.805 in steps of 0.005
+        cur_weight = quant_weight(
+            m.weight.data,
+            num_bits=num_bits,
+            group_size=group_size,
+            scheme=scheme,
+            full_range=sym_full_range,
+            quantile=ratio,
+        )
+        loss = (org_weight - cur_weight).float().pow(2).mean().item()
+        history.append(loss)
+        is_best = loss < best_error
+        if is_best:
+            best_error = loss
+            best_clip_ratio = ratio
+    logger.debug("The loss history of different clip ranges: {}".format(history))
+    logger.debug("The best clip ratio is {}".format(best_clip_ratio))
+    return best_clip_ratio
+
 def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
                  quantile=1.0, weight_config={}, return_int=False,
-                 sym_full_range=False, **kwargs):
+                 sym_full_range=False, mse_range=False, **kwargs):
     """Quantize the model with the round-to-nearest method.

     Args:
@@ -234,6 +275,8 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
                                   Defaults to False.
         sym_full_range (bool, optional): whether the sym range uses -2**(bits-1).
                                          Defaults to False.
+        mse_range (bool, optional): whether to search the clip range.
+                                    Defaults to False.

     Returns:
         model: fake quantized torch module
@@ -264,6 +307,8 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
             logger.info(f"Skip {name}")
             continue
         weight = m.weight
+        if mse_range:
+            quantile = search_clip(m, num_bits, group_size, scheme, sym_full_range)
         if return_int:
             from .model_wrapper import WeightOnlyLinear
             int_weight, scale, zp = quant_weight(
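With n_grid = 200 and max_shrink = 0.2, the loop above tries int(0.2 * 200) = 40 ratios, 1 - i_s/200 for i_s = 0..39, i.e. 1.000, 0.995, ..., 0.805; that is where the documented search range [0.805, 1.0, 0.005] comes from. Below is a standalone toy version of the same search, with a simplified per-tensor asymmetric quantizer standing in for quant_weight (the real helper also handles group_size, sym/asym schemes, and full_range):

import torch

def fake_quant(w, num_bits=4, quantile=1.0):
    # Shrink the quantization range to `quantile` of min/max, then round (RTN).
    wmin, wmax = w.min() * quantile, w.max() * quantile
    scale = (wmax - wmin) / (2 ** num_bits - 1)
    zp = torch.round(-wmin / scale)
    q = torch.clamp(torch.round(w / scale) + zp, 0, 2 ** num_bits - 1)
    return (q - zp) * scale  # dequantize back to float

w = torch.randn(256, 256)
best_err, best_ratio = float('inf'), None
for i_s in range(int(0.2 * 200)):  # 40 candidate ratios
    ratio = 1 - i_s / 200          # 1.000, 0.995, ..., 0.805
    err = (w - fake_quant(w, quantile=ratio)).pow(2).mean().item()
    if err < best_err:
        best_err, best_ratio = err, ratio
print(best_ratio)  # the clip ratio with the lowest weight-reconstruction MSE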

test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py

Lines changed: 15 additions & 1 deletion
@@ -89,7 +89,6 @@ def test_RTN_quant(self):
         compressed_model = q_model.export_compressed_model()
         out3 = compressed_model(input)
         self.assertTrue(torch.all(out3==out2))
-
         model = Model()
         out1 = model(input)
@@ -108,6 +107,21 @@ def test_RTN_quant(self):
         out3 = compressed_model(input)
         self.assertTrue(torch.all(out3==out2))

+        model = Model()
+        out1 = model(input)
+        conf = PostTrainingQuantConfig(
+            approach='weight_only',
+            recipes={
+                # By default, sym_full_range is False and 4-bit sym only uses the range [-7, 7].
+                # When mse_range is True, weight clipping is enabled by checking the MSE.
+                'rtn_args': {'sym_full_range': True, 'mse_range': True}
+            }
+        )
+        q_model = quantization.fit(model, conf)
+        out2 = q_model(input)
+        self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1)))
+        self.assertFalse(torch.all(out1 == out2))
+
         model = Model()
         out1 = model(input)
         conf = PostTrainingQuantConfig(