 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
 )
+from torchao.quantization.quant_primitives import TorchAODType
+from .api import FakeQuantizeConfig
+from .fake_quantizer import FakeQuantizer
 from .utils import (
     _fake_quantize_per_channel_group,
     _get_qmin_qmax,
 )
 
 
+class FakeQuantizedEmbedding(torch.nn.Embedding):
+    """
+    General embedding layer with fake quantized weights.
+
+    Specific target dtypes, granularity, schemes etc. are specified
+    through separate configs for weights and activations.
+
+    Example usage::
+
+        weight_config = FakeQuantizeConfig(
+            dtype=torch.int4,
+            group_size=8,
+            is_symmetric=True,
+        )
+        fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config=weight_config)
+        fq_embedding(torch.LongTensor([3]))
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        weight_config: Optional[FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            *args,
+            **kwargs,
+        )
+        if weight_config is not None:
+            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+        else:
+            self.weight_fake_quantizer = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.weight_fake_quantizer is not None:
+            w = self.weight_fake_quantizer(self.weight)
+        else:
+            w = self.weight
+        return F.embedding(
+            x, w, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        )
+
+
 # ======================================
 # |   Embedding int4 weight-only QAT   |
 # ======================================
@@ -37,10 +98,9 @@ def __init__(
         zero_point_precision: torch.dtype = torch.int32,
     ) -> None:
         super().__init__()
-        self.bit_width = 4
         self.group_size: int = group_size
         self.scale_precision: torch.dtype = scale_precision
-        self.zero_point_precision: torch.dtype = zero_point_precision,
+        self.zero_point_precision: torch.dtype = zero_point_precision
 
     def prepare(
         self,
@@ -56,16 +116,18 @@ def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
 
         def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
             new_embedding = Int4WeightOnlyQATEmbedding(
-                group_size=self.group_size,
-
-                # other nn.Embedding args
+                # nn.Embedding args
                 num_embeddings=child.num_embeddings,
                 embedding_dim=child.embedding_dim,
                 padding_idx=child.padding_idx,
                 max_norm=child.max_norm,
                 norm_type=child.norm_type,
                 scale_grad_by_freq=child.scale_grad_by_freq,
                 sparse=child.sparse,
+                # quantization args
+                group_size=self.group_size,
+                scale_precision=self.scale_precision,
+                zero_point_precision=self.zero_point_precision,
                 device=child.weight.device,
             )
             # In distributed training, the model may be instantiated
@@ -98,28 +160,31 @@ def _convert_helper(self, module: torch.nn.Module):
         from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
         for name, child in module.named_children():
             if isinstance(child, Int4WeightOnlyQATEmbedding):
+                group_size = child.weight_fake_quantizer.config.group_size
+                scale_precision = child.weight_fake_quantizer.config.scale_precision
+                zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
                 quantized_embedding = Int4WeightOnlyEmbedding(
-                    group_size=child.group_size,
-                    scale_precision=child.scale_precision,
-                    zero_point_precision=child.zero_point_precision,
-
-                    # other nn.Embedding args
+                    # nn.Embedding args
                     num_embeddings=child.num_embeddings,
                     embedding_dim=child.embedding_dim,
                     padding_idx=child.padding_idx,
                     max_norm=child.max_norm,
                     norm_type=child.norm_type,
                     scale_grad_by_freq=child.scale_grad_by_freq,
                     sparse=child.sparse,
+                    # quantization args
+                    group_size=group_size,
+                    scale_precision=scale_precision,
+                    zero_point_precision=zero_point_precision,
                     device=child.weight.device,
                 )
                 setattr(module, name, quantized_embedding)
 
                 # Load weights and qparams into quantized embedding
-                (qmin, qmax) = _get_qmin_qmax(self.bit_width)
-                (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, child.group_size)
+                (qmin, qmax) = _get_qmin_qmax(4)
+                (s, zp) = get_group_qparams_symmetric(child.weight, 4, group_size)
                 q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight, s, zp, qmin, qmax, torch.int8, child.group_size,
+                    child.weight, s, zp, qmin, qmax, torch.int8, group_size,
                 )
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scales = s
@@ -128,7 +193,7 @@ def _convert_helper(self, module: torch.nn.Module):
                 self._convert_helper(child)
 
 
-class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
+class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
     """
     This module implements an embedding layer with int4 fake quantized
     grouped per channel weights.
@@ -141,47 +206,42 @@ class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
 
     def __init__(
         self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
         group_size: int = 32,
         scale_precision: torch.dtype = torch.float32,
         zero_point_precision: torch.dtype = torch.int32,
         *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
-        self._fake_quant_enabled = True
-
-    def forward(self, x):
-        weight = self.weight
-
-        if self._fake_quant_enabled:
-            (weight_scales, weight_zp) = get_group_qparams_symmetric(
-                self.weight, self.bit_width, self.group_size, self.scale_precision,
-            )
-            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
-            weight_zp = weight_zp.to(self.zero_point_precision)
-            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
-            w_fq = _fake_quantize_per_channel_group(
-                self.weight,
-                weight_scales,
-                weight_zp,
-                weight_qmin,
-                weight_qmax,
-                self.group_size,
-            )
-        else:
-            w_fq = self.weight
-
-        return F.embedding(
-            x, w_fq, self.padding_idx, self.max_norm,
-            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        weight_config = FakeQuantizeConfig(
+            dtype=TorchAODType.INT4,
+            group_size=group_size,
+            is_symmetric=True,
+            is_dynamic=True,
+            scale_precision=scale_precision,
+            zero_point_precision=zero_point_precision,
+        )
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            weight_config,
+            *args,
+            **kwargs,
         )
 
     def enable_fake_quant(self, enabled: bool = True):
-        self._fake_quant_enabled = enabled
+        self.weight_fake_quantizer.enabled = enabled
 
     def disable_fake_quant(self):
         self.enable_fake_quant(False)
@@ -194,25 +254,21 @@ class Int4WeightOnlyEmbedding(torch.nn.Module):
     """
     def __init__(
         self,
-        group_size: int,
-        scale_precision: torch.dtype,
-        zero_point_precision: torch.dtype,
-
-        # nn.Embedding args
         num_embeddings: int,
         embedding_dim: int,
         padding_idx: Optional[int] = None,
         max_norm: Optional[float] = None,
         norm_type: float = 2.0,
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
+        group_size: int = 32,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
         device: torch.device = None,
     ):
         super().__init__()
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
+
+        # nn.Embedding args
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
@@ -221,6 +277,11 @@ def __init__(
         self.scale_grad_by_freq = scale_grad_by_freq
         self.sparse = sparse
 
+        # quantization args
+        self.group_size = group_size
+        self.scale_precision = scale_precision
+        self.zero_point_precision = zero_point_precision
+
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",
@@ -245,7 +306,7 @@ def __init__(
 
     def forward(self, x):
         from torchao._executorch_ops import _quantized_decomposed_dequantize_per_channel_group_wrapper
-        qmin, qmax = _get_qmin_qmax(self.bit_width)
+        qmin, qmax = _get_qmin_qmax(4)
         w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper(
             self.weight,
             self.scale,
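
For reference, a minimal sketch of how the FakeQuantizedEmbedding introduced in this diff might be exercised. The class, config, and keyword names come from the diff itself; the import paths are assumptions that depend on the installed torchao version, so adjust them to match your release::

    import torch

    # Assumed import paths: the diff defines these classes in torchao's QAT
    # embedding module, but the exact package location varies by version.
    from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
    from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding
    from torchao.quantization.quant_primitives import TorchAODType

    # Weight config mirroring the one built inside Int4WeightOnlyQATEmbedding:
    # int4, symmetric, grouped per channel along the embedding dimension.
    weight_config = FakeQuantizeConfig(
        dtype=TorchAODType.INT4,
        group_size=8,
        is_symmetric=True,
    )

    # 128 embeddings of dimension 64; group_size should evenly divide embedding_dim.
    fq_embedding = FakeQuantizedEmbedding(128, 64, weight_config=weight_config)

    # Forward pass: weights are fake quantized on the fly, output stays in float.
    tokens = torch.randint(0, 128, (2, 16))
    out = fq_embedding(tokens)
    print(out.shape)  # torch.Size([2, 16, 64])

The Int4WeightOnlyQATEmbedding subclass builds an equivalent config internally, and its enable_fake_quant / disable_fake_quant methods now simply toggle weight_fake_quantizer.enabled, as shown in the diff above.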