```diff
@@ -233,21 +233,21 @@ class PatchedParallelLMHead(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         # ParallelLMHead inherits from VocabParallelEmbedding (an nn.Module), which has a member called
-        # "linear_method" of type UnquantizedEmbeddingMethod that inherits from QuantizeMethodBase
+        # "quant_method" of type UnquantizedEmbeddingMethod that inherits from QuantizeMethodBase
         # (neither is an nn.Module) and implements an "apply" method using torch.nn.functional.linear
         # (apply the weights in layer to the input tensor).
         # ParallelLMHead's forward method should not be called, because the LMHead's weights should be
         # used in the sampler (the forward itself raises a RuntimeError).
-        # So in order to quantize that linear_method, we patch only its "apply" method.
+        # So in order to quantize that quant_method, we patch only its "apply" method.
         init_linear(self, mod_extra_config)
-        self.orig_linear_apply = types.MethodType(mod.linear_method.apply.__func__, self)
+        self.orig_linear_quant_apply = self.orig_mod.quant_method.apply
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             if self.use_qdq or self.fake_quant:
-                self.linear_method.apply = self.apply_qdq
+                self.quant_method.apply = self.apply_qdq
             else:
-                self.linear_method.apply = self.apply_quant
+                self.quant_method.apply = self.apply_quant
         elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
-            self.linear_method.apply = self.apply_measure
+            self.quant_method.apply = self.apply_measure

     def apply_quant(self, layer, x, bias):
         qinput = self.quant_input(x)
```
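The change swaps the bound `apply` on the module's `quant_method` object instead of wrapping `forward`, since callers (the sampler path) invoke `quant_method.apply` directly. Below is a minimal, self-contained sketch of that pattern, assuming nothing beyond plain PyTorch; every class and function name here (`UnquantizedMethod`, `Head`, `PatchedHead`) is a stand-in, not the project's actual API:

```python
import torch
import torch.nn.functional as F

class UnquantizedMethod:
    """Stand-in for UnquantizedEmbeddingMethod: applies layer weights via F.linear."""
    def apply(self, layer, x, bias=None):
        return F.linear(x, layer.weight, bias)

class Head(torch.nn.Module):
    """Stand-in for ParallelLMHead: forward is never called; callers use quant_method.apply."""
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(vocab_size, hidden_size))
        self.quant_method = UnquantizedMethod()

class PatchedHead:
    """Hypothetical patcher: replace only quant_method.apply, keeping the
    original bound method around so the measure path can still run real math."""
    def __init__(self, mod, measure=True):
        self.orig_mod = mod
        self.orig_quant_apply = mod.quant_method.apply  # original bound method
        mod.quant_method.apply = self.apply_measure if measure else self.apply_quant

    def apply_measure(self, layer, x, bias=None):
        # Observe the input, delegate to the original apply, observe the output.
        print("measured input:", tuple(x.shape))
        output = self.orig_quant_apply(layer, x, bias)
        print("measured output:", tuple(output.shape))
        return output

    def apply_quant(self, layer, x, bias=None):
        # A real implementation would quantize x and run a quantized matmul here.
        return self.orig_quant_apply(layer, x, bias)

head = Head(vocab_size=32, hidden_size=16)
PatchedHead(head, measure=True)
logits = head.quant_method.apply(head, torch.randn(4, 16))  # sampler-style call path
```

Because the replacement is installed as a plain instance attribute on the `quant_method` object, any existing caller that does `quant_method.apply(layer, x, bias)` is redirected without touching the module's class or its `forward`.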
```diff
@@ -273,7 +273,7 @@ def apply_qdq(self, layer, x, bias):

     def apply_measure(self, layer, x, bias):
         measure_input((x,), observer=self._mod_extra_config.inputs)
-        output = self.orig_linear_apply(layer, x, bias)
+        output = self.orig_linear_quant_apply(layer, x, bias)
         measure_output((output,), self._mod_extra_config.outputs)
         return output

```
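The replaced line used `types.MethodType(... .__func__, self)` to rebind the method's underlying function onto the patched wrapper, whereas the new code simply keeps the method bound to its original owner. A toy illustration of the difference (the names `Owner` and `Wrapper` are illustrative only):

```python
import types

class Owner:
    def apply(self, x):
        return (type(self).__name__, x)

class Wrapper:
    pass

owner, wrapper = Owner(), Wrapper()

# Old style: extract the plain function and rebind it, so `self` becomes the wrapper.
rebound = types.MethodType(owner.apply.__func__, wrapper)
assert rebound(1) == ("Wrapper", 1)

# New style: keep the already-bound method, so `self` stays the original owner.
kept = owner.apply
assert kept(2) == ("Owner", 2)
```

Keeping the original binding means `apply_measure` delegates to the untouched `quant_method` object, which is exactly what a calibration pass wants: real outputs, observed on the way in and out.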