|  | 
|  | 1 | +# | 
|  | 2 | +# -*- coding: utf-8 -*- | 
|  | 3 | +# | 
|  | 4 | +# Copyright (c) 2021 Intel Corporation | 
|  | 5 | +# | 
|  | 6 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 7 | +# you may not use this file except in compliance with the License. | 
|  | 8 | +# You may obtain a copy of the License at | 
|  | 9 | +# | 
|  | 10 | +#   http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 11 | +# | 
|  | 12 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 13 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 14 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 15 | +# See the License for the specific language governing permissions and | 
|  | 16 | +# limitations under the License. | 
|  | 17 | + | 
|  | 18 | +"""Torch.nn.Module Class Defination.""" | 
|  | 19 | +# Note: Do not import this file unless you have already imported torch,  | 
|  | 20 | +# since the model classes inherit torch.nn.Module. | 
|  | 21 | +import torch | 
|  | 22 | +from packaging.version import Version | 
|  | 23 | + | 
|  | 24 | + | 
def get_torch_version():
    """Return the installed torch version as a ``packaging.version.Version``.

    Strips any local build suffix (e.g. ``"1.13.0+cpu"`` -> ``"1.13.0"``) so
    that ``.release`` tuple comparisons behave as expected.

    Returns:
        Version: the parsed torch version.

    Raises:
        AssertionError: if the torch version string cannot be processed.
    """
    try:
        torch_version = torch.__version__.split('+')[0]
    except ValueError as e:  # pragma: no cover
        # Raise explicitly rather than `assert False`: asserts are stripped
        # when Python runs with -O, which would silently skip this error path.
        raise AssertionError('Got an unknown version of torch: {}'.format(e)) from e
    return Version(torch_version)
|  | 32 | + | 
# Release tuple of the running torch build, e.g. (1, 13, 0); used below to
# gate version-dependent imports of the quantized-module namespace.
PT_VERSION = get_torch_version().release
|  | 34 | + | 
|  | 35 | + | 
class QDQLinear(torch.nn.Module):
    """Wrap a module with quantize/dequantize stubs to emulate quantization.

    Activations flow through ``Quantize -> DeQuantize -> module``, injecting
    activation quantization noise, while the wrapped module's weight is
    replaced once at construction time by its quant-dequant image.
    """

    def __init__(self, module, scale, zero_point, dtype):
        super().__init__()
        # The quantized-module namespace moved in torch 1.13.0.
        if PT_VERSION < Version("1.13.0").release:
            import torch.nn.quantized as nnq
        else:
            import torch.ao.nn.quantized as nnq
        self.add_module('quant', nnq.Quantize(scale, zero_point, dtype))
        self.add_module('dequant', nnq.DeQuantize())
        self.add_module('module', module)
        self.qdq_weight()

    def forward(self, X):
        """Quantize then dequantize the input before running the module."""
        quantized = self.quant(X)
        restored = self.dequant(quantized)
        return self.module(restored)

    def qdq_weight(self):
        """Replace the wrapped module's weight with its quant-dequant image."""
        from .smooth_quant import quant_dequant_w
        qdq_w = quant_dequant_w(self.module)
        self.module.weight = torch.nn.Parameter(qdq_w)
|  | 59 | + | 
|  | 60 | + | 
|  | 61 | +class SQLinearWrapper(torch.nn.Module): | 
|  | 62 | +    def __init__(self, module, input_scale, input_minmax, dtype=torch.quint8): | 
|  | 63 | +        super().__init__() | 
|  | 64 | +        self.input_scale = input_scale | 
|  | 65 | +        self.dtype = dtype | 
|  | 66 | +        # calculate and only save scale, zero_point to avoid memory usage | 
|  | 67 | +        self.scale, self.zero_point = self._calculate_qparams(input_scale, input_minmax, dtype) | 
|  | 68 | +        self.add_module('sq_linear', module) | 
|  | 69 | +        self.ipex = False  # a flag used for ipex inference | 
|  | 70 | + | 
|  | 71 | +    def forward(self, X): | 
|  | 72 | +        if self.ipex: | 
|  | 73 | +            X = self.sq_linear(X) | 
|  | 74 | +        else: | 
|  | 75 | +            X = torch.mul(X, self.input_scale) | 
|  | 76 | +            X = self.sq_linear(X) | 
|  | 77 | +        return X | 
|  | 78 | + | 
|  | 79 | +    def _calculate_qparams(self, input_scale, input_minmax, dtype=torch.quint8): | 
|  | 80 | +        # calculate scale and zero_point | 
|  | 81 | +        if dtype == torch.quint8: | 
|  | 82 | +            quant_min, quant_max = 0, 255 | 
|  | 83 | +        min_val = torch.min(input_minmax[0] * input_scale) | 
|  | 84 | +        max_val = torch.max(input_minmax[1] * input_scale) | 
|  | 85 | +        # work when min_val bigger than zero. | 
|  | 86 | +        min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) | 
|  | 87 | +        max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) | 
|  | 88 | +        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) | 
|  | 89 | +        scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps])) | 
|  | 90 | +        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) | 
|  | 91 | +        zero_point = torch.clamp(zero_point, quant_min, quant_max) | 
|  | 92 | +        return scale, zero_point | 
|  | 93 | + | 
|  | 94 | +    def _get_weight_scale(self): | 
|  | 95 | +        # get weight scale and zero_point | 
|  | 96 | +        from torch.ao.quantization.observer import default_per_channel_weight_observer | 
|  | 97 | +        obs = default_per_channel_weight_observer() | 
|  | 98 | +        obs(self.sq_linear.weight) | 
|  | 99 | +        scale, _ = obs.calculate_qparams() | 
|  | 100 | +        return scale | 
|  | 101 | + | 
    def _recover_sq_linear(self):
        # Fold the input scale back into the linear weight so the standalone
        # `sq_linear` can be used for ipex inference without the elementwise
        # mul performed in forward().
        # NOTE(review): the view assumes `input_scale` is 1-D (one entry per
        # input feature) — confirm against callers.
        scale = self.input_scale.view(1, self.input_scale.shape[0])
        with torch.no_grad():
            self.sq_linear.weight *= scale
0 commit comments