#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 | 17 | +"""Half-precision Convert for Torch Modules."""  | 

from typing import Dict, Tuple

import torch

from neural_compressor.common import logger
from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper
from neural_compressor.torch.utils import get_device


class HalfPrecisionConverter:
    """Converter class for FP16 and BF16."""

    # Maps the string dtypes used in configs to the corresponding torch dtypes.
    dtype_mapping = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    def __init__(self, configs_mapping: Dict[Tuple[str], object], *args, **kwargs):
        """Initialize the half-precision converter with its config.

        Args:
            configs_mapping (Dict): mapping from operator information to its mixed-precision config.
        """
        self.configs_mapping = configs_mapping
        self.device = get_device()

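    # Illustration only: a sketch of the `configs_mapping` shape this converter
    # assumes, inferred from how it is read below. Keys are tuples whose first
    # element is the operator name, and values are config objects exposing a
    # `dtype` attribute ("fp16" or "bf16"); the concrete config class is not
    # defined in this module, and `config_with_dtype` is a hypothetical helper.
    #
    #   configs_mapping = {
    #       ("fc1", "Linear"): config_with_dtype("fp16"),
    #       ("fc2", "Linear"): config_with_dtype("bf16"),
    #   }
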
    def convert(self, model: torch.nn.Module):
        """Convert a model to FP16 or BF16.

        Args:
            model (torch.nn.Module): the input model.

        Returns:
            mix_precision_model (torch.nn.Module): model with mixed-precision modules wrapped.
        """
        if len(self.configs_mapping) > 0:
            logger.info("Convert operators to half-precision")

        # Follow the device of the model's existing parameters; for a model
        # without parameters, keep the device detected at init time.
        first_param = next(model.parameters(), None)
        if first_param is not None and first_param.is_cuda:
            self.device = "cuda"
        elif first_param is not None and first_param.is_cpu:
            self.device = "cpu"

        mix_precision_model = self._wrap_half_precision_model(model)
        mix_precision_model.to(self.device)

        return mix_precision_model

    def _wrap_half_precision_model(self, model: torch.nn.Module, prefix=""):
        """Wrap and replace half-precision target modules.

        Args:
            model (torch.nn.Module): the input module.
            prefix (str): the name prefix for named children.

        Returns:
            model (torch.nn.Module): the model whose target modules have been wrapped.
        """
        for name, child in model.named_children():
            op_name = prefix + "." + name if prefix != "" else name
            for op_info, config in self.configs_mapping.items():
                if op_name == op_info[0] and config.dtype in ("fp16", "bf16"):
                    # Replace the matched child with a half-precision wrapper.
                    child = HalfPrecisionModuleWrapper(
                        module=child, device=self.device, dtype=self.dtype_mapping[config.dtype]
                    )
                    break
            else:
                # No config matched this op; recurse into its children instead.
                self._wrap_half_precision_model(child, op_name)
            setattr(model, name, child)

        return model
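

if __name__ == "__main__":
    # Minimal usage sketch, for illustration only. `SimpleNamespace` stands in
    # for the real mixed-precision config object, since the converter only
    # reads its `dtype` attribute; the key shape follows the sketch above.
    from types import SimpleNamespace

    demo_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    demo_configs = {("0", "Linear"): SimpleNamespace(dtype="bf16")}
    converter = HalfPrecisionConverter(demo_configs)
    converted = converter.convert(demo_model)
    # The first child ("0") is now wrapped in a HalfPrecisionModuleWrapper.
    logger.info(f"Converted module: {converted[0]}")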