11import torch
22import torch .nn as nn
33
4- import os
5- import sys
6- # append the path to the naive_intNwo.py file
7- sys .path .append (os .path .join (os .path .dirname (os .path .dirname (os .path .dirname (os .path .abspath (__file__ )))), "torchao/quantization/prototype/mixed_precision/scripts" ))
8- from naive_intNwo import intN_weight_only
4+ from torchao .quantization .prototype .mixed_precision .scripts .naive_intNwo import intN_weight_only
95
106from torchao .quantization import quantize_ , int8_weight_only , int4_weight_only
117
1814)
1915
2016def test_weight_only_quant (quantization_bit = 2 , symmetric = False ):
21- for x_shape in [[32 , 64 ], [80 , 80 , 80 , 64 ], [16 , 64 , 64 ]]:
22- x = torch .randn (* x_shape )
23- m = nn .Sequential (nn .Linear (64 , 80 ))
17+ for x_shape in [[64 , 32 ], [80 , 80 , 80 , 32 ], [16 , 64 , 32 ]]:
18+ x = torch .randn (* x_shape , dtype = torch . bfloat16 )
19+ m = nn .Sequential (nn .Linear (32 , 80 )). bfloat16 ( )
2420 y_ref = m (x )
25- quantize_ (m , intN_weight_only (n = quantization_bit , group_size = 16 , symmetric = symmetric ))
21+ quantize_ (m , intN_weight_only (n = quantization_bit , group_size = 32 , symmetric = symmetric ))
2622 y_wo = m (x )
2723 sqnr = compute_error (y_ref , y_wo )
2824 #SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization
@@ -31,7 +27,7 @@ def test_weight_only_quant(quantization_bit=2, symmetric=False):
3127
3228
3329# test if the asymmetric and symmetric quantization API works with different bit widths
34- for i in [2 ,3 , 5 , 6 , 8 ]:
30+ for i in [2 , 3 , 4 , 5 , 6 , 8 ]:
3531 #test for asymmetric quantization
3632 try :
3733 test_weight_only_quant (i , False )
0 commit comments