 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
 )
+from torchao.quantization.quant_primitives import TorchAODType
+from .api import FakeQuantizeConfig
+from .fake_quantizer import FakeQuantizer
 from .utils import (
     _fake_quantize_per_channel_group,
     _get_qmin_qmax,
 )


+class FakeQuantizedEmbedding(torch.nn.Embedding):
+    """
+    General embedding layer with fake quantized weights.
+
+    Specific target dtypes, granularity, schemes etc. are specified
+    through separate configs for weights and activations.
+
+    Example usage::
+
+        weight_config = FakeQuantizeConfig(
+            dtype=torch.int4,
+            group_size=8,
+            symmetric=True,
+        )
+        fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config)
+        fq_embedding(torch.LongTensor([3]))
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        weight_config: Optional[FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            *args,
+            **kwargs,
+        )
+        if weight_config is not None:
+            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+        else:
+            self.weight_fake_quantizer = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.weight_fake_quantizer is not None:
+            w = self.weight_fake_quantizer(self.weight)
+        else:
+            w = self.weight
+        return F.embedding(
+            x, w, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        )
+
+
 # ======================================
 # |   Embedding int4 weight-only QAT   |
 # ======================================
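Before the int4-specific pieces below, a minimal usage sketch of the new FakeQuantizedEmbedding base class added above. The import paths are assumptions (adjust to wherever this module lives in torchao); the config mirrors the one the Int4WeightOnlyQATEmbedding constructor builds later in this diff.

import torch

# Assumed import locations, for illustration only.
from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.quant_primitives import TorchAODType

# No weight_config: weight_fake_quantizer is None, so the layer behaves
# exactly like a plain nn.Embedding.
plain = FakeQuantizedEmbedding(5, 16)

# Int4 symmetric per-group config (same settings the Int4WeightOnlyQATEmbedding
# constructor uses below); embedding_dim should be divisible by group_size.
weight_config = FakeQuantizeConfig(
    dtype=TorchAODType.INT4,
    group_size=8,
    is_symmetric=True,
    is_dynamic=True,
)
fq = FakeQuantizedEmbedding(5, 16, weight_config=weight_config)

x = torch.LongTensor([3])
print(plain(x).shape, fq(x).shape)  # both (1, 16); fq weights snap to the int4 grid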
@@ -40,7 +101,7 @@ def __init__(
         self.bit_width = 4
         self.group_size: int = group_size
         self.scale_precision: torch.dtype = scale_precision
-        self.zero_point_precision: torch.dtype = zero_point_precision,
+        self.zero_point_precision: torch.dtype = zero_point_precision

     def prepare(
         self,
@@ -56,16 +117,18 @@ def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:

         def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
             new_embedding = Int4WeightOnlyQATEmbedding(
-                group_size=self.group_size,
-
-                # other nn.Embedding args
+                # nn.Embedding args
                 num_embeddings=child.num_embeddings,
                 embedding_dim=child.embedding_dim,
                 padding_idx=child.padding_idx,
                 max_norm=child.max_norm,
                 norm_type=child.norm_type,
                 scale_grad_by_freq=child.scale_grad_by_freq,
                 sparse=child.sparse,
+                # quantization args
+                group_size=self.group_size,
+                scale_precision=self.scale_precision,
+                zero_point_precision=self.zero_point_precision,
                 device=child.weight.device,
             )
             # In distributed training, the model may be instantiated
@@ -98,28 +161,31 @@ def _convert_helper(self, module: torch.nn.Module):
         from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
         for name, child in module.named_children():
             if isinstance(child, Int4WeightOnlyQATEmbedding):
+                group_size = child.weight_fake_quantizer.config.group_size
+                scale_precision = child.weight_fake_quantizer.config.scale_precision
+                zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
                 quantized_embedding = Int4WeightOnlyEmbedding(
-                    group_size=child.group_size,
-                    scale_precision=child.scale_precision,
-                    zero_point_precision=child.zero_point_precision,
-
-                    # other nn.Embedding args
+                    # nn.Embedding args
                     num_embeddings=child.num_embeddings,
                     embedding_dim=child.embedding_dim,
                     padding_idx=child.padding_idx,
                     max_norm=child.max_norm,
                     norm_type=child.norm_type,
                     scale_grad_by_freq=child.scale_grad_by_freq,
                     sparse=child.sparse,
+                    # quantization args
+                    group_size=group_size,
+                    scale_precision=scale_precision,
+                    zero_point_precision=zero_point_precision,
                     device=child.weight.device,
                 )
                 setattr(module, name, quantized_embedding)

                 # Load weights and qparams into quantized embedding
                 (qmin, qmax) = _get_qmin_qmax(self.bit_width)
-                (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, child.group_size)
+                (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, group_size)
                 q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight, s, zp, qmin, qmax, torch.int8, child.group_size,
+                    child.weight, s, zp, qmin, qmax, torch.int8, group_size,
                 )
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scales = s
@@ -128,7 +194,7 @@ def _convert_helper(self, module: torch.nn.Module):
                 self._convert_helper(child)


-class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
+class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
     """
     This module implements an embedding layer with int4 fake quantized
     grouped per channel weights.
@@ -141,47 +207,42 @@ class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):

     def __init__(
         self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
         group_size: int = 32,
         scale_precision: torch.dtype = torch.float32,
         zero_point_precision: torch.dtype = torch.int32,
         *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
-        self._fake_quant_enabled = True
-
-    def forward(self, x):
-        weight = self.weight
-
-        if self._fake_quant_enabled:
-            (weight_scales, weight_zp) = get_group_qparams_symmetric(
-                self.weight, self.bit_width, self.group_size, self.scale_precision,
-            )
-            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
-            weight_zp = weight_zp.to(self.zero_point_precision)
-            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
-            w_fq = _fake_quantize_per_channel_group(
-                self.weight,
-                weight_scales,
-                weight_zp,
-                weight_qmin,
-                weight_qmax,
-                self.group_size,
-            )
-        else:
-            w_fq = self.weight
-
-        return F.embedding(
-            x, w_fq, self.padding_idx, self.max_norm,
-            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        weight_config = FakeQuantizeConfig(
+            dtype=TorchAODType.INT4,
+            group_size=group_size,
+            is_symmetric=True,
+            is_dynamic=True,
+            scale_precision=scale_precision,
+            zero_point_precision=zero_point_precision,
+        )
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            weight_config,
+            *args,
+            **kwargs,
         )

     def enable_fake_quant(self, enabled: bool = True):
-        self._fake_quant_enabled = enabled
+        self.weight_fake_quantizer.enabled = enabled

     def disable_fake_quant(self):
         self.enable_fake_quant(False)
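Since enable_fake_quant now just flips the enabled flag on the underlying weight FakeQuantizer, toggling it across a prepared model stays a one-liner per layer. A rough sketch, assuming the import path below and a model already containing Int4WeightOnlyQATEmbedding layers:

import torch

# Assumed import location, for illustration only.
from torchao.quantization.prototype.qat.embedding import Int4WeightOnlyQATEmbedding

def set_embedding_fake_quant(model: torch.nn.Module, enabled: bool) -> None:
    # Walk the module tree and flip fake quantization on every QAT embedding,
    # e.g. to compare fake-quantized vs. float numerics on the same weights.
    for module in model.modules():
        if isinstance(module, Int4WeightOnlyQATEmbedding):
            module.enable_fake_quant(enabled)

set_embedding_fake_quant(model, False)  # plain float embedding forward
set_embedding_fake_quant(model, True)   # restore fake quantization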
@@ -194,25 +255,21 @@ class Int4WeightOnlyEmbedding(torch.nn.Module):
     """
     def __init__(
         self,
-        group_size: int,
-        scale_precision: torch.dtype,
-        zero_point_precision: torch.dtype,
-
-        # nn.Embedding args
         num_embeddings: int,
         embedding_dim: int,
         padding_idx: Optional[int] = None,
         max_norm: Optional[float] = None,
         norm_type: float = 2.0,
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
+        group_size: int = 32,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
         device: torch.device = None,
     ):
         super().__init__()
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
+
+        # nn.Embedding args
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
@@ -221,6 +278,12 @@ def __init__(
         self.scale_grad_by_freq = scale_grad_by_freq
         self.sparse = sparse

+        # quantization args
+        self.bit_width = 4
+        self.group_size = group_size
+        self.scale_precision = scale_precision
+        self.zero_point_precision = zero_point_precision
+
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",
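The overall quantizer flow is unchanged by this refactor. A rough end-to-end sketch for reference; the quantizer class name, its convert() entry point, and the import path are assumptions inferred from the prepare() and _convert_helper() methods shown above.

import torch

# Assumed name and location of the quantizer whose methods appear in this diff.
from torchao.quantization.prototype.qat.embedding import (
    Int4WeightOnlyEmbeddingQATQuantizer,
)

model = torch.nn.Sequential(torch.nn.Embedding(128, 256))

# prepare(): swaps nn.Embedding -> Int4WeightOnlyQATEmbedding (fake quantized)
qat_quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size=32)
model = qat_quantizer.prepare(model)

# ... fine-tune with fake quantization in the training loop ...

# convert(): swaps Int4WeightOnlyQATEmbedding -> Int4WeightOnlyEmbedding,
# quantizing the trained weights to grouped int4 (stored as unpacked int8)
model = qat_quantizer.convert(model)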