@@ -244,21 +244,21 @@ class PatchedParallelLMHead(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         # ParallelLMHead inherits from VocabParallelEmbedding (nn.module) which has a member called
-        # "linear_method" of type UnquantizedEmbeddingMethod that inherits from QuantizeMethodBase
+        # "quant_method" of type UnquantizedEmbeddingMethod that inherits from QuantizeMethodBase
         # (both are not nn.module) and implement an "apply" method by using torch.nn.functional.linear
         # (Apply the weights in layer to the input tensor.)
         # ParallelLMHead's forward method should not be called because LMHead's weights should be used
         # in the sampler. (The forward itself throws RuntimeError exception)
-        # So in order to quantize that linear_method we patch only the "apply" method.
+        # So in order to quantize that quant_method we patch only the "apply" method.
         init_linear(self, mod_extra_config, False)
-        self.orig_linear_apply = self.orig_mod.linear_method.apply
+        self.orig_linear_quant_apply = self.orig_mod.quant_method.apply
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             if self.use_qdq or self.fake_quant:
-                self.linear_method.apply = self.apply_qdq
+                self.quant_method.apply = self.apply_qdq
             else:
-                self.linear_method.apply = self.apply_quant
+                self.quant_method.apply = self.apply_quant
         elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
-            self.linear_method.apply = self.apply_measure
+            self.quant_method.apply = self.apply_measure

     def apply_quant(self, layer, x, bias):
         qinput = self.quant_input(x)
@@ -284,7 +284,7 @@ def apply_qdq(self, layer, x, bias):

     def apply_measure(self, layer, x, bias):
         measure_input((x,), observer=self._mod_extra_config.inputs)
-        output = self.orig_linear_apply(layer, x, bias)
+        output = self.orig_linear_quant_apply(layer, x, bias)
         measure_output((output,), self._mod_extra_config.outputs)
         return output

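The diff above leaves the module's forward untouched and only rebinds quant_method.apply on the embedded quant method object. As a rough illustration of that pattern, here is a minimal self-contained sketch; FakeQuantMethod, FakeLMHead, and MeasuringPatch are invented names for illustration only (not the real vLLM or Intel Neural Compressor classes), and the measurement calls are reduced to comments.

import torch
import torch.nn.functional as F

class FakeQuantMethod:
    # Stand-in for UnquantizedEmbeddingMethod: apply() runs torch.nn.functional.linear.
    def apply(self, layer, x, bias=None):
        return F.linear(x, layer.weight, bias)

class FakeLMHead:
    # Stand-in for ParallelLMHead: holds the weight and a quant_method member.
    def __init__(self, vocab_size, hidden_size):
        self.weight = torch.randn(vocab_size, hidden_size)
        self.quant_method = FakeQuantMethod()

class MeasuringPatch:
    # Rebinds quant_method.apply so every caller is routed through apply_measure,
    # without ever calling or overriding the module's forward().
    def __init__(self, mod):
        self.orig_mod = mod
        # Keep the original bound apply so it can still be delegated to.
        self.orig_linear_quant_apply = mod.quant_method.apply
        # The monkey patch: replace apply on the quant_method instance.
        mod.quant_method.apply = self.apply_measure

    def apply_measure(self, layer, x, bias=None):
        # A real implementation would record input statistics here ...
        output = self.orig_linear_quant_apply(layer, x, bias)
        # ... and output statistics here, returning the result unchanged.
        return output

# Usage: any code that goes through quant_method.apply (e.g. the sampler path)
# is intercepted transparently.
head = FakeLMHead(vocab_size=32, hidden_size=8)
MeasuringPatch(head)
logits = head.quant_method.apply(head, torch.randn(4, 8))  # shape (4, 32)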