Commit b6e5aea

changwangss and nirda7 authored
[SW-219274] - Changing the quant method name in lm-head (#150) (#2132)
* [SW-219274] - Changing the quant method name in lm-head (#150)

* Update helper_modules.py

---------

Co-authored-by: Nir David <[email protected]>
1 parent c121e4e commit b6e5aea

1 file changed (+7, -7)

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 7 additions & 7 deletions
@@ -233,21 +233,21 @@ class PatchedParallelLMHead(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         # ParallelLMHead inherits from VocabParallelEmbedding (nn.module) which has a member called
-        # "linear_method" of type UnquantizedEmbeddingMethod that inherits from QuantizeMethodBase
+        # "quant_method" of type UnquantizedEmbeddingMethod that inherits from QuantizeMethodBase
         # (both are not nn.module) and implement an "apply" method by using torch.nn.functional.linear
         # (Apply the weights in layer to the input tensor.)
         # ParallelLMHead's forward method should not be called because LMHead's weights should be used
         # in the sampler. (The forward itself throws RuntimeError exception)
-        # So in order to quantize that linear_method we patch only the "apply" method.
+        # So in order to quantize that quant_method we patch only the "apply" method.
         init_linear(self, mod_extra_config)
-        self.orig_linear_apply = types.MethodType(mod.linear_method.apply.__func__, self)
+        self.orig_linear_quant_apply = self.orig_mod.quant_method.apply
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             if self.use_qdq or self.fake_quant:
-                self.linear_method.apply = self.apply_qdq
+                self.quant_method.apply = self.apply_qdq
             else:
-                self.linear_method.apply = self.apply_quant
+                self.quant_method.apply = self.apply_quant
         elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
-            self.linear_method.apply = self.apply_measure
+            self.quant_method.apply = self.apply_measure

     def apply_quant(self, layer, x, bias):
         qinput = self.quant_input(x)
@@ -273,7 +273,7 @@ def apply_qdq(self, layer, x, bias):

     def apply_measure(self, layer, x, bias):
         measure_input((x,), observer=self._mod_extra_config.inputs)
-        output = self.orig_linear_apply(layer, x, bias)
+        output = self.orig_linear_quant_apply(layer, x, bias)
         measure_output((output,), self._mod_extra_config.outputs)
         return output

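The change relies on monkey-patching: the original bound quant_method.apply is saved as orig_linear_quant_apply, and quant_method.apply is then redirected to the quantizing or measuring variant. Below is a minimal, runnable sketch of that pattern; the Toy* classes and the print-based observers are illustrative stand-ins, not code from neural_compressor or vLLM.

import torch
import torch.nn.functional as F

class ToyQuantMethod:
    """Stand-in for UnquantizedEmbeddingMethod: applies a layer's weight via F.linear."""
    def apply(self, layer, x, bias=None):
        return F.linear(x, layer.weight, bias)

class ToyLMHead(torch.nn.Module):
    """Stand-in for ParallelLMHead: owns the weight and a quant_method object."""
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(vocab_size, hidden_size))
        self.quant_method = ToyQuantMethod()

class ToyPatchedLMHead:
    """Patches only quant_method.apply, mirroring the pattern in the diff above."""
    def __init__(self, mod):
        self.orig_mod = mod
        # Capture the bound apply BEFORE patching, so the measurement path
        # can still reach the original implementation.
        self.orig_linear_quant_apply = mod.quant_method.apply
        # Redirect: anyone calling mod.quant_method.apply(...) now lands here.
        mod.quant_method.apply = self.apply_measure

    def apply_measure(self, layer, x, bias=None):
        print("measured input shape:", tuple(x.shape))        # stand-in for measure_input
        output = self.orig_linear_quant_apply(layer, x, bias)
        print("measured output shape:", tuple(output.shape))  # stand-in for measure_output
        return output

head = ToyLMHead(vocab_size=8, hidden_size=4)
ToyPatchedLMHead(head)
logits = head.quant_method.apply(head, torch.randn(2, 4))
print(logits.shape)  # torch.Size([2, 8])

Because only apply is replaced, the module's forward (which raises in ParallelLMHead) is never touched, yet every call that goes through quant_method.apply is intercepted.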
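The replaced assignment also changes how the saved handle binds: the old code rebound the raw function onto the patched module with types.MethodType(mod.linear_method.apply.__func__, self), whereas the new code stores quant_method.apply still bound to the original module. A small sketch of what that rebinding difference means in plain Python (class names are hypothetical):

import types

class Original:
    def __init__(self):
        self.name = "original"

    def apply(self):
        return f"self.name is {self.name}"

class Patched:
    def __init__(self, mod):
        self.name = "patched"
        # Old style: rebind the underlying function, so `self` inside
        # apply() becomes this Patched instance.
        self.rebound_apply = types.MethodType(mod.apply.__func__, self)
        # New style: keep the method bound to the original object.
        self.bound_apply = mod.apply

p = Patched(Original())
print(p.rebound_apply())  # self.name is patched
print(p.bound_apply())    # self.name is original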