@@ -233,21 +233,21 @@ class PatchedParallelLMHead(PatchedModuleBase):
233233 def __init__ (self , mod , parent , mod_extra_config , * args , ** kwargs ):
234234 super ().__init__ (mod , parent , mod_extra_config , * args , ** kwargs )
235235 # ParallelLMHead inherits from VocabParallelEmbedding (nn.module) which has a member called
236- # "linear_method " of type UnquantizedEmbeddingMethod that inherits from QuantizeMethodBase
236+ # "quant_method " of type UnquantizedEmbeddingMethod that inherits from QuantizeMethodBase
237237 # (both are not nn.module) and implement an "apply" method by using torch.nn.functional.linear
238238 # (Apply the weights in layer to the input tensor.)
239239 # ParallelLMHead's forward method should not be called because LMHead's weights should be used
240240 # in the sampler. (The forward itself throws RuntimeError exception)
241- # So in order to quantize that linear_method we patch only the "apply" method.
241+ # So in order to quantize that quant_method we patch only the "apply" method.
242242 init_linear (self , mod_extra_config )
243- self .orig_linear_apply = types . MethodType ( mod . linear_method .apply . __func__ , self )
243+ self .orig_linear_quant_apply = self . orig_mod . quant_method .apply
244244 if self .quantization_mode in [QuantMode .QUANTIZE , QuantMode .LOAD ]:
245245 if self .use_qdq or self .fake_quant :
246- self .linear_method .apply = self .apply_qdq
246+ self .quant_method .apply = self .apply_qdq
247247 else :
248- self .linear_method .apply = self .apply_quant
248+ self .quant_method .apply = self .apply_quant
249249 elif (self .quantization_mode == QuantMode .MEASURE ) or (self .quantization_mode == QuantMode .SHAPE ):
250- self .linear_method .apply = self .apply_measure
250+ self .quant_method .apply = self .apply_measure
251251
252252 def apply_quant (self , layer , x , bias ):
253253 qinput = self .quant_input (x )
@@ -273,7 +273,7 @@ def apply_qdq(self, layer, x, bias):
273273
274274 def apply_measure (self , layer , x , bias ):
275275 measure_input ((x ,), observer = self ._mod_extra_config .inputs )
276- output = self .orig_linear_apply (layer , x , bias )
276+ output = self .orig_linear_quant_apply (layer , x , bias )
277277 measure_output ((output ,), self ._mod_extra_config .outputs )
278278 return output
279279
0 commit comments