+import tempfile
+import unittest
+
+import torch
+from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
 )
+
+from torchao.dtypes import SemiSparseLayout
 from torchao.quantization import (
+    float8_weight_only,
     int4_weight_only,
-    int8_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
-    int8_dynamic_activation_int8_semi_sparse_weight,
-    float8_weight_only,
+    int8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.dtypes import SemiSparseLayout
-from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
-import torch
-import unittest
-import tempfile
-
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
@@ -33,7 +33,9 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         base_functions.append(int4_weight_only(group_size=32))
 
     if do_sparse:
-        base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
+        base_functions.append(
+            int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
+        )
 
     if is_cuda_8_9:
         base_functions.append(float8_weight_only())
@@ -44,11 +46,11 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_tensor_core_layout_transpose(self):
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        t = l.weight
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        t = linear.weight
         shape = t.shape
         apply_int4_weight_only_quant = int4_weight_only(group_size=32)
-        ql = apply_int4_weight_only_quant(l)
+        ql = apply_int4_weight_only_quant(linear)
         aqt = ql.weight
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
@@ -64,8 +66,8 @@ def test_tensor_core_layout_transpose(self):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     def test_weights_only(self, apply_quant):
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        ql = apply_quant(linear)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(ql.state_dict(), f)
             f.seek(0)
@@ -78,33 +80,32 @@ def test_weights_only(self, apply_quant):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(linear)
         ql.to("cuda")
 
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(linear)
         ql.to(device="cuda")
 
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(linear)
         ql.cuda()
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_register_new_dispatch(self):
+        from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
         from torchao.dtypes.affine_quantized_tensor_ops import (
-            register_aqt_quantized_linear_dispatch,
             deregister_aqt_quantized_linear_dispatch,
+            register_aqt_quantized_linear_dispatch,
         )
-        from torchao.dtypes import to_affine_quantized_intx
-        from torchao.dtypes import AffineQuantizedTensor
         from torchao.quantization.quant_primitives import MappingType
 
         def dispatch_condition(input_tensor, weight_tensor, bias):
             return (
-                isinstance(weight_tensor, AffineQuantizedTensor) and
-                weight_tensor.quant_min == 0 and
-                weight_tensor.quant_max == 2**6 - 1
+                isinstance(weight_tensor, AffineQuantizedTensor)
+                and weight_tensor.quant_min == 0
+                and weight_tensor.quant_max == 2**6 - 1
             )
 
         def impl(input_tensor, weight_tensor, bias):
@@ -115,23 +116,35 @@ def impl(input_tensor, weight_tensor, bias):
         register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
 
         def apply_uint6_weight_only_quant(linear):
-            linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6 - 1), requires_grad=False)
+            linear.weight = torch.nn.Parameter(
+                to_affine_quantized_intx(
+                    linear.weight,
+                    MappingType.ASYMMETRIC,
+                    (1, linear.weight.shape[-1]),
+                    torch.uint8,
+                    0,
+                    2**6 - 1,
+                ),
+                requires_grad=False,
+            )
             return linear
 
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        apply_uint6_weight_only_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        apply_uint6_weight_only_quant(linear)
 
         example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
-        with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"):
-            l(example_input)
+        with self.assertRaisesRegex(
+            AssertionError, "dispatching to my impl for uint6 weight only quant"
+        ):
+            linear(example_input)
 
         deregister_aqt_quantized_linear_dispatch(dispatch_condition)
 
     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        ql = apply_quant(linear)
         assert "AffineQuantizedTensor" in str(ql)
 
 
@@ -143,20 +156,25 @@ class TestAffineQuantizedBasic(TestCase):
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_flatten_unflatten(self, apply_quant, device, dtype):
-        l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        ql = apply_quant(linear)
         lp_tensor = ql.weight
         tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
-        tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
+        tensor_data_dict = {
+            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+        }
         outer_size = lp_tensor.size()
         outer_stride = lp_tensor.stride()
-        reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
+        reconstructed = type(lp_tensor).__tensor_unflatten__(
+            tensor_data_dict, tensor_attributes, outer_size, outer_stride
+        )
         example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
         ref = ql(*example_inputs)
         ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
         reconstruct_res = ql(*example_inputs)
         self.assertEqual(reconstruct_res, ref)
 
+
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
 
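Note on the dispatch-registration pattern exercised by test_register_new_dispatch above: a condition/impl pair is registered, consulted when an AffineQuantizedTensor weight hits a linear, and later removed with the same condition object. The sketch below restates that flow outside the test harness. The register/deregister calls and the uint6 bounds (0, 2**6 - 1) come straight from the diff; the is_uint6/uint6_impl names are hypothetical, and the fallback body assumes AffineQuantizedTensor exposes a dequantize() method.

import torch
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.affine_quantized_tensor_ops import (
    deregister_aqt_quantized_linear_dispatch,
    register_aqt_quantized_linear_dispatch,
)

def is_uint6(input_tensor, weight_tensor, bias):
    # Match only weights quantized to the uint6 range [0, 2**6 - 1].
    return (
        isinstance(weight_tensor, AffineQuantizedTensor)
        and weight_tensor.quant_min == 0
        and weight_tensor.quant_max == 2**6 - 1
    )

def uint6_impl(input_tensor, weight_tensor, bias):
    # Reference fallback: dequantize the weight, then run a plain linear.
    # Assumes dequantize() exists on the tensor subclass.
    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)

register_aqt_quantized_linear_dispatch(is_uint6, uint6_impl)
# ... run quantized linears that hit the custom path ...
# Deregister with the same condition object, as the test does.
deregister_aqt_quantized_linear_dispatch(is_uint6)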
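Similarly, the test_flatten_unflatten hunk boils down to the round trip below, restated as a standalone helper. This is a sketch with a hypothetical roundtrip name; it assumes lp_tensor is a tensor subclass implementing the flatten protocol, such as the ql.weight used in the test.

def roundtrip(lp_tensor):
    # __tensor_flatten__ names the inner plain tensors and returns the
    # non-tensor metadata needed to rebuild the subclass.
    tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__()
    tensor_data = {name: getattr(lp_tensor, name) for name in tensor_data_names}
    # __tensor_unflatten__ reassembles an equivalent instance; the outer
    # size and stride must match the original wrapper tensor.
    return type(lp_tensor).__tensor_unflatten__(
        tensor_data, tensor_attributes, lp_tensor.size(), lp_tensor.stride()
    )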