|
18 | 18 | ) |
19 | 19 |
|
20 | 20 | def test_weight_only_quant(quantization_bit=2, symmetric=False): |
21 | | - for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: |
| 21 | + for x_shape in [[32, 64], [80, 80, 80, 64], [16, 64, 64]]: |
22 | 22 | x = torch.randn(*x_shape) |
23 | | - m = nn.Sequential(nn.Linear(4, 5)) |
| 23 | + m = nn.Sequential(nn.Linear(64, 80)) |
24 | 24 | y_ref = m(x) |
25 | | - quantize_(m, intN_weight_only(n=quantization_bit, group_size=2, symmetric=symmetric)) |
| 25 | + quantize_(m, intN_weight_only(n=quantization_bit, group_size=16, symmetric=symmetric)) |
26 | 26 | y_wo = m(x) |
27 | 27 | sqnr = compute_error(y_ref, y_wo) |
28 | | - print(sqnr) |
29 | | - assert sqnr > 44.0, "sqnr: {} is too low".format(sqnr) |
| 28 | +# SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization
| 29 | +# e.g., for 8-bit we set the sqnr threshold to 44, since 6.02 * 8 = 48.16 fulfills it
| 30 | + assert sqnr > 44.0-(8-quantization_bit)*6.02, "sqnr: {} is too low".format(sqnr) |
30 | 31 |
|
31 | 32 |
|
32 | 33 | # test if the asymmetric and symmetric quantization API works with different bit widths |
33 | | -for i in range(2, 9): |
| 34 | +for i in [2,3,5,6,8]: |
34 | 35 | #test for asymmetric quantization |
35 | 36 | try: |
36 | 37 | test_weight_only_quant(i, False) |
|
0 commit comments