* then we apply equalization scale to linear activation with to_weight_tensor_with_linear_activation_scale_metadata (input activation will be divided by equalization_scale), and then call F.linear with
  scaled input activation and quantized weight (so we can reuse the efficient quantized linear kernels used by quantized weight)
"""
-import torch
+
import copy

+import torch
import torch.nn.functional as F
from torch import Tensor
+
from torchao.dtypes import (
-    to_affine_quantized_intx_static,
-    to_affine_quantized_floatx_static,
    Float8Layout,
+    to_affine_quantized_floatx_static,
+    to_affine_quantized_intx_static,
)
-from torchao.quantization.utils import compute_error
-from torchao.quantization import quantize_
-from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-from torchao.quantization.observer import (
-    AffineQuantizedMinMaxObserver,
+from torchao.quantization import (
+    quantize_,
+    to_weight_tensor_with_linear_activation_scale_metadata,
)
from torchao.quantization.granularity import (
    PerAxis,
    PerTensor,
)
+from torchao.quantization.observer import (
+    AffineQuantizedMinMaxObserver,
+)
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.quant_primitives import (
    MappingType,
-    FP8_TYPES,
)
+from torchao.quantization.utils import compute_error


class ObservedLinear(torch.nn.Linear):
-    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        act_obs: torch.nn.Module,
+        weight_obs: torch.nn.Module,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs
@@ -46,11 +58,20 @@ def forward(self, input: Tensor):

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
-        observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
+        observed_linear = cls(
+            float_linear.in_features,
+            float_linear.out_features,
+            act_obs,
+            weight_obs,
+            False,
+            device=float_linear.weight.device,
+            dtype=float_linear.weight.dtype,
+        )
        observed_linear.weight = float_linear.weight
        observed_linear.bias = float_linear.bias
        return observed_linear

+
def insert_observers_(model, act_obs, weight_obs):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)

@@ -61,22 +82,39 @@ def replacement_fn(m):

    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

+
# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
def apply_awq(target_dtype: torch.dtype):
    # target_dtype = torch.uint8
    def _apply_awq_to_linear(observed_linear):
        # weight quantization
        weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
+
        def weight_quant_func(weight):
            block_size = (1, weight.shape[1])
            if target_dtype == torch.uint8:
-                return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
+                return to_affine_quantized_intx_static(
+                    weight, weight_scale, weight_zero_point, block_size, target_dtype
+                )
            elif target_dtype == torch.float8_e4m3fn:
-                return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8Layout(mm_config=None))
+                return to_affine_quantized_floatx_static(
+                    weight,
+                    weight_scale,
+                    block_size,
+                    target_dtype,
+                    Float8Layout(mm_config=None),
+                )
            else:
                raise ValueError(f"Unsupported target dtype {target_dtype}")
-        linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
+
+        linear = torch.nn.Linear(
+            observed_linear.in_features,
+            observed_linear.out_features,
+            False,
+            device=observed_linear.weight.device,
+            dtype=observed_linear.weight.dtype,
+        )
        linear.weight = observed_linear.weight
        linear.bias = observed_linear.bias

@@ -86,16 +124,22 @@ def weight_quant_func(weight):
        equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
        equalization_scale = torch.ones_like(equalization_scale)

-        linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight * equalization_scale), requires_grad=False)
+        linear.weight = torch.nn.Parameter(
+            weight_quant_func(linear.weight * equalization_scale), requires_grad=False
+        )

-        linear.weight = torch.nn.Parameter(to_weight_tensor_with_linear_activation_scale_metadata(linear.weight, equalization_scale), requires_grad=False)
+        linear.weight = torch.nn.Parameter(
+            to_weight_tensor_with_linear_activation_scale_metadata(
+                linear.weight, equalization_scale
+            ),
+            requires_grad=False,
+        )

        return linear

    return _apply_awq_to_linear


-
######## Test ##########
class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
@@ -104,7 +148,11 @@ def __init__(self, m=64, n=32, k=64):
        self.linear2 = torch.nn.Linear(k, n, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
-        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
+        return (
+            torch.randn(
+                batch_size, self.linear1.in_features, dtype=dtype, device=device
+            ),
+        )

    def forward(self, x):
        x = self.linear1(x)
@@ -119,16 +167,24 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):
    dtype = torch.bfloat16
    m = ToyLinearModel().eval().to(dtype).to("cuda")

-    m_for_test = copy.deepcopy(m)
-
    m_bf16 = copy.deepcopy(m)
    example_inputs = m.example_inputs(dtype=dtype, device="cuda")
    print("example inputs shape:", example_inputs[0].shape)

-    m_bf16 = torch.compile(m_bf16, mode='max-autotune')
-
-    act_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps)
-    weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps)
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune")
+
+    act_obs = AffineQuantizedMinMaxObserver(
+        mapping_type,
+        target_dtype,
+        granularity_type=PerTensor(),
+        eps=torch.finfo(torch.float32).eps,
+    )
+    weight_obs = AffineQuantizedMinMaxObserver(
+        mapping_type,
+        target_dtype,
+        granularity_type=PerAxis(axis=0),
+        eps=torch.finfo(torch.float32).eps,
+    )

    before_quant = m(*example_inputs)

@@ -137,9 +193,9 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):
    for _ in range(10):
        m(*example_inputs)

-    after_obs = m(*example_inputs)
+    m(*example_inputs)

-    m2 = copy.deepcopy(m)
+    copy.deepcopy(m)

    is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

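For reference, a minimal sketch of how the pieces touched by this diff are typically driven end to end. The conversion call at the end of test_awq is not shown in this hunk, so the quantize_(model, conversion_fn, filter_fn) calling convention and the MappingType.ASYMMETRIC choice below are assumptions, not part of the diff:

# Sketch only: wiring insert_observers_, calibration, and apply_awq together.
# Assumes quantize_ accepts the module conversion function returned by apply_awq
# plus a filter function (as in the surrounding calibration-flow tutorials).
target_dtype = torch.uint8
mapping_type = MappingType.ASYMMETRIC  # assumption, not shown in this hunk

model = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
example_inputs = model.example_inputs(dtype=torch.bfloat16, device="cuda")

act_obs = AffineQuantizedMinMaxObserver(
    mapping_type, target_dtype,
    granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps,
)
weight_obs = AffineQuantizedMinMaxObserver(
    mapping_type, target_dtype,
    granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps,
)

# Swap every nn.Linear for an ObservedLinear, then run a few calibration passes
# so the observers record activation and weight statistics.
insert_observers_(model, act_obs, weight_obs)
for _ in range(10):
    model(*example_inputs)

# Convert only the observed linears to linears with quantized weights and
# the activation equalization-scale metadata.
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
quantize_(model, apply_awq(target_dtype), is_observed_linear)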