@@ -69,8 +69,6 @@ def new_test(self, value=value):
 
 
 class TorchAOBasicTestCase(common_utils.TestCase):
-    """Basic test case for tensor subclasses
-    """
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
@@ -142,6 +140,66 @@ def test_linear(self, device, dtype):
         lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
         self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
 
+
+class TorchAOCompileTestCase(common_utils.TestCase):
+    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
+
+    TENSOR_SUBCLASS = AffineQuantizedTensor
+    FACTORY_FN = to_affine_quantized_intx
+    kwargs = {
+        "mapping_type": MappingType.ASYMMETRIC,
+        "block_size": (1, 32),
+        "target_dtype": torch.uint8,
+    }
+    # minimum SQNR for the linear operation when the weight is quantized
+    # to low precision with the above settings
+    LINEAR_MIN_SQNR = 40
+    COMPILE_MIN_SQNR = 50
+
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_input_output_tensor_subclass(self, device, dtype):
+        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+        def f(tensor):
+            return tensor
+
+        ref = f(lp_tensor)
+        f = torch.compile(f)
+        compiled = f(lp_tensor)
+        self.assertTrue(isinstance(compiled, self.TENSOR_SUBCLASS))
+        self.assertEqual(ref.dequantize(), compiled.dequantize())
+
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_input_tensor_subclass(self, device, dtype):
+        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+        def f(tensor):
+            return tensor.dequantize()
+
+        ref = f(lp_tensor)
+        f = torch.compile(f)
+        compiled = f(lp_tensor)
+        self.assertFalse(isinstance(compiled, self.TENSOR_SUBCLASS))
+        self.assertEqual(ref, compiled)
+
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_output_tensor_subclass(self, device, dtype):
+        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+        def f(hp_tensor):
+            return self.FACTORY_FN(hp_tensor, **self.kwargs)
+
+        ref = f(hp_tensor)
+        f = torch.compile(f)
+        compiled = f(hp_tensor)
+        self.assertTrue(isinstance(compiled, self.TENSOR_SUBCLASS))
+        # bfloat16 seems to result in much larger numerical differences
+        if dtype != torch.bfloat16:
+            self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR)
+
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_linear_compile(self, device, dtype):
@@ -155,7 +213,10 @@ def test_linear_compile(self, device, dtype):
         lp_res = torch.compile(l)(hp_act_tensor)
         self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
 
+
+
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
+common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
 
 if __name__ == "__main__":
     unittest.main()
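Note: TorchAOCompileTestCase is written as a reusable template. The compile tests read everything subclass-specific from class attributes (TENSOR_SUBCLASS, FACTORY_FN, kwargs, and the SQNR thresholds), so tests for another tensor subclass can reuse them by overriding those attributes. A minimal sketch of that pattern follows; MyQuantizedTensor, to_my_quantized, and the kwargs values are hypothetical placeholders, not part of this commit:

# Hypothetical reuse of the template above: override the class attributes so
# the inherited compile tests run against a custom tensor subclass.
class MyTensorCompileTestCase(TorchAOCompileTestCase):
    TENSOR_SUBCLASS = MyQuantizedTensor  # hypothetical subclass under test
    # staticmethod() keeps Python from binding a plain factory function as a
    # method (which would pass `self` as the first argument)
    FACTORY_FN = staticmethod(to_my_quantized)  # hypothetical factory
    kwargs = {"block_size": (1, 64), "target_dtype": torch.uint8}  # example values
    LINEAR_MIN_SQNR = 40
    COMPILE_MIN_SQNR = 50

common_utils.instantiate_parametrized_tests(MyTensorCompileTestCase)

The staticmethod wrapper matters only for plain functions; the direct assignment `FACTORY_FN = to_affine_quantized_intx` in the diff works because that name refers to an already-bound classmethod of AffineQuantizedTensor, which the descriptor protocol does not rebind.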