Skip to content

Commit 3dc739a

Browse files
Authored commit: [SW-219274] - Changing the quant method name in lm-head (#150)
1 parent ab265ef commit 3dc739a

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -244,21 +244,21 @@ class PatchedParallelLMHead(PatchedModuleBase):
244244
def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
245245
super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
246246
# ParallelLMHead inherits from VocabParallelEmbedding (nn.module) which has a member called
247-
# "linear_method" of type UnquantizedEmbeddingMethod that inherits from QuantizeMethodBase
247+
# "quant_method" of type UnquantizedEmbeddingMethod that inherits from QuantizeMethodBase
248248
# (both are not nn.module) and implement an "apply" method by using torch.nn.functional.linear
249249
# (Apply the weights in layer to the input tensor.)
250250
# ParallelLMHead's forward method should not be called because LMHead's weights should be used
251251
# in the sampler. (The forward itself throws RuntimeError exception)
252-
# So in order to quantize that linear_method we patch only the "apply" method.
252+
# So in order to quantize that quant_method we patch only the "apply" method.
253253
init_linear(self, mod_extra_config, False)
254-
self.orig_linear_apply = self.orig_mod.linear_method.apply
254+
self.orig_linear_quant_apply = self.orig_mod.quant_method.apply
255255
if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
256256
if self.use_qdq or self.fake_quant:
257-
self.linear_method.apply = self.apply_qdq
257+
self.quant_method.apply = self.apply_qdq
258258
else:
259-
self.linear_method.apply = self.apply_quant
259+
self.quant_method.apply = self.apply_quant
260260
elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
261-
self.linear_method.apply = self.apply_measure
261+
self.quant_method.apply = self.apply_measure
262262

263263
def apply_quant(self, layer, x, bias):
264264
qinput = self.quant_input(x)
@@ -284,7 +284,7 @@ def apply_qdq(self, layer, x, bias):
284284

285285
def apply_measure(self, layer, x, bias):
286286
measure_input((x,), observer=self._mod_extra_config.inputs)
287-
output = self.orig_linear_apply(layer, x, bias)
287+
output = self.orig_linear_quant_apply(layer, x, bias)
288288
measure_output((output,), self._mod_extra_config.outputs)
289289
return output
290290

Commit comments: 0