1- """Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs
2- """
1+ """Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs"""
2+
3+ import copy
4+
35import torch
6+
7+ from torchao .quantization .quant_api import (
8+ _replace_with_custom_fn_if_matches_filter ,
9+ int4_weight_only ,
10+ int8_dynamic_activation_int8_weight ,
11+ int8_weight_only ,
12+ quantize_ ,
13+ )
414from torchao .quantization .subclass import (
5- Int8WeightOnlyQuantizedLinearWeight ,
615 Int4WeightOnlyQuantizedLinearWeight ,
16+ Int8WeightOnlyQuantizedLinearWeight ,
717)
818from torchao .utils import (
919 TORCH_VERSION_AT_LEAST_2_4 ,
1020 TORCH_VERSION_AT_LEAST_2_5 ,
21+ unwrap_tensor_subclass ,
1122)
12- from torchao .quantization .quant_api import (
13- int4_weight_only ,
14- int8_weight_only ,
15- int8_dynamic_activation_int8_weight ,
16- quantize_ ,
17- _replace_with_custom_fn_if_matches_filter ,
18- )
19- import copy
20- from torchao .utils import unwrap_tensor_subclass
23+
2124
2225def _int8wo_api (mod , ** kwargs ):
2326 if TORCH_VERSION_AT_LEAST_2_4 :
@@ -27,14 +30,20 @@ def _int8wo_api(mod, **kwargs):
     else:
         change_linear_weights_to_int8_woqtensors(mod, **kwargs)
 
+
 def _int8da_int8w_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False)
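+        # quantize_ rewrites the Linear weights in place with the
+        # int8 dynamic-activation / int8-weight subclass configuration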
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(**kwargs),
+            set_inductor_config=False,
+        )
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod, **kwargs)
 
+
 def _int4wo_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
         kwargs_copy = kwargs.copy()
@@ -47,31 +56,43 @@ def _int4wo_api(mod, **kwargs):
     else:
         change_linear_weights_to_int4_woqtensors(mod, **kwargs)
 
+
 class ToyLinearModel(torch.nn.Module):
-    """Single linear for m * k * n problem size
-    """
-    def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
+    """Single linear for m * k * n problem size"""
+
+    def __init__(
+        self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"
+    ):
         super().__init__()
         self.m = m
         self.dtype = dtype
         self.device = device
-        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device)
+        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(
+            dtype=self.dtype, device=self.device
+        )
 
     def example_inputs(self):
-        return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),)
+        return (
+            torch.randn(
+                self.m, self.linear.in_features, dtype=self.dtype, device=self.device
+            ),
+        )
 
     def forward(self, x):
         x = self.linear(x)
         return x
 
+
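+# ToyLinearModel example (hypothetical sizes): ToyLinearModel(m=20, n=2048, k=2048)
+# holds a Linear(2048, 2048) layer, and example_inputs() returns one (20, 2048)
+# tensor matching the linear's in_features=k.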
 def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
     """
     The deprecated implementation for int8 dynamic quant API, used as a reference for
     numerics and performance
     """
-    from torchao.quantization.quant_api import _in_features_greater_than_16
-    from torchao.quantization.quant_api import _is_linear
-    from torchao.quantization.quant_api import _get_subclass_inserter
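+    # underscore-prefixed helpers are private to quant_api; they are imported
+    # here only to drive this deprecated reference path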
+    from torchao.quantization.quant_api import (
+        _get_subclass_inserter,
+        _in_features_greater_than_16,
+        _is_linear,
+    )
     from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
 
     if filter_fn is None:
@@ -80,40 +101,54 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
         )
 
     _replace_with_custom_fn_if_matches_filter(
-        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
+        model,
+        _get_subclass_inserter(
+            Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
+        ),
+        filter_fn,
     )
 
+
 def _get_ref_change_linear_weights_to_woqtensors(deprecated_tensor_subclass):
     def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
         """
         The deprecated implementation for weight only quant API, used as a reference for
         numerics and performance
         """
-        from torchao.quantization.quant_api import _is_linear
-        from torchao.quantization.quant_api import _get_subclass_inserter
+        from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear
 
         filter_fn = kwargs.pop("filter_fn", _is_linear)
 
         _replace_with_custom_fn_if_matches_filter(
             model,
-            _get_subclass_inserter(deprecated_tensor_subclass, enable_parametrization=True, **kwargs),
+            _get_subclass_inserter(
+                deprecated_tensor_subclass, enable_parametrization=True, **kwargs
+            ),
             filter_fn,
         )
 
     return _ref_change_linear_weights_to_woqtensors
 
-_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
-_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+
+_ref_change_linear_weights_to_int8_woqtensors = (
+    _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
+)
+_ref_change_linear_weights_to_int4_woqtensors = (
+    _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+)
 
 
 torch._dynamo.config.cache_size_limit = 50000
 
+
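+# Times three compiled variants per (M, N, K) shape: the current quantize_ API
+# (api), the deprecated tensor-subclass implementation (ref_api), and the plain
+# bf16 model, all under torch.compile with max-autotune.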
 @torch.no_grad
 def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     if kwargs is None:
         kwargs = {}
 
-    m = ToyLinearModel(M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda").eval()
+    m = ToyLinearModel(
+        M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
+    ).eval()
     m_bf16 = copy.deepcopy(m)
     m_ref = copy.deepcopy(m)
     example_inputs = m.example_inputs()
@@ -130,26 +165,30 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
 
     # perf comparison
     from torchao.utils import benchmark_model
+
     # warmup
     WARMUP = 20
     RUNS = 100
 
     torch._dynamo.reset()
-    m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
+    m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True)
     benchmark_model(m_ref, WARMUP, example_inputs)
     ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
 
     torch._dynamo.reset()
-    m = torch.compile(m, mode='max-autotune', fullgraph=True)
+    m = torch.compile(m, mode="max-autotune", fullgraph=True)
     benchmark_model(m, WARMUP, example_inputs)
     elapsed_time = benchmark_model(m, RUNS, example_inputs)
 
     torch._dynamo.reset()
-    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
     benchmark_model(m_bf16, WARMUP, example_inputs)
     bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
 
-    print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")
+    print(
+        f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
+    )
+
 
 if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
     all_shapes = [
@@ -158,16 +197,25 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     ]
     print("_int8da_int8w_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
+
     for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K)
+        _bench_quantized_tensor_subclass_perf(
+            _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K
+        )
 
     print("_int8wo_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+
     for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K)
+        _bench_quantized_tensor_subclass_perf(
+            _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K
+        )
 
     print("_int4wo_api")
     kwargs = {"groupsize": 32}
     from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+
     for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs)
+        _bench_quantized_tensor_subclass_perf(
+            _int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs
+        )
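+# Minimal sketch (hypothetical shape; requires CUDA and torch >= 2.4) for calling
+# the benchmark helper directly instead of running the __main__ sweep:
+#
+#   _bench_quantized_tensor_subclass_perf(
+#       _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, 20, 2048, 2048
+#   )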