diff --git a/neural_compressor/common/utils/constants.py b/neural_compressor/common/utils/constants.py
index f0ddc3b442b..3fbe9c50ac8 100644
--- a/neural_compressor/common/utils/constants.py
+++ b/neural_compressor/common/utils/constants.py
@@ -36,6 +36,7 @@
 TEQ = "teq"  # pragma: no cover
 AUTOROUND = "autoround"
 FP8_QUANT = "fp8_quant"
+MIX_PRECISION = "mix_precision"
 
 # options
 import datetime
diff --git a/neural_compressor/torch/algorithms/mix_precision/__init__.py b/neural_compressor/torch/algorithms/mix_precision/__init__.py
new file mode 100644
index 00000000000..084e1c44e0f
--- /dev/null
+++ b/neural_compressor/torch/algorithms/mix_precision/__init__.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from neural_compressor.torch.algorithms.mix_precision.half_precision_convert import HalfPrecisionConverter
+from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper
diff --git a/neural_compressor/torch/algorithms/mix_precision/half_precision_convert.py b/neural_compressor/torch/algorithms/mix_precision/half_precision_convert.py
new file mode 100644
index 00000000000..83fb197b6e8
--- /dev/null
+++ b/neural_compressor/torch/algorithms/mix_precision/half_precision_convert.py
@@ -0,0 +1,88 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Half-precision conversion for Torch modules."""
+
+from typing import Dict, Tuple
+
+import torch
+
+from neural_compressor.common import logger
+from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper
+from neural_compressor.torch.utils import get_device
+
+
+class HalfPrecisionConverter:
+    """Converter class for FP16 and BF16."""
+
+    dtype_mapping = {
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+    }
+
+    def __init__(self, configs_mapping: Dict[Tuple[str], object], *args, **kwargs):
+        """Initialize the half-precision converter with config.
+
+        Args:
+            configs_mapping (Dict): the operator-wise config mapping for mix-precision.
+        """
+        self.configs_mapping = configs_mapping
+        self.device = get_device()
+
+    def convert(self, model: torch.nn.Module):
+        """Convert to FP16 or BF16 model.
+
+        Args:
+            model (torch.nn.Module): the input model.
+
+        Returns:
+            mix_precision_model (torch.nn.Module): model with mix-precision.
+        """
+        if len(self.configs_mapping) > 0:
+            logger.info("Convert operators to half-precision")
+
+        if next(model.parameters()).is_cuda:
+            self.device = "cuda"
+        elif next(model.parameters()).is_cpu:
+            self.device = "cpu"
+
+        mix_precision_model = self._wrap_half_precision_model(model)
+        mix_precision_model.to(self.device)
+
+        return mix_precision_model
+
+    def _wrap_half_precision_model(self, model: torch.nn.Module, prefix=""):
+        """Wrap and replace half-precision target modules.
+
+        Args:
+            model (torch.nn.Module): the input module.
+            prefix (str): the name prefix for named children.
+
+        Returns:
+            model (torch.nn.Module): the model whose target modules have been wrapped.
+        """
+        for name, child in model.named_children():
+            op_name = prefix + "." + name if prefix != "" else name
+            for op_info, config in self.configs_mapping.items():
+                if op_name == op_info[0] and config.dtype in ("fp16", "bf16"):
+                    child = HalfPrecisionModuleWrapper(
+                        module=child, device=self.device, dtype=self.dtype_mapping[config.dtype]
+                    )
+                else:
+                    self._wrap_half_precision_model(child, op_name)
+            setattr(model, name, child)
+
+        return model
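For context, a minimal sketch of how the new converter is driven (illustrative, not part of the patch): the configs_mapping is normally resolved by the framework from a user-facing MixPrecisionConfig, and the {(op_name, op_type): config} key shape and the toy model below are assumptions made for the example.

    import torch

    from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionConverter
    from neural_compressor.torch.quantization import MixPrecisionConfig

    model = torch.nn.Sequential(torch.nn.Linear(8, 8))  # toy model for illustration

    # Assumed mapping shape: {(op_name, op_type): config}; in practice the framework
    # builds this from the MixPrecisionConfig before the converter ever sees it.
    configs_mapping = {("0", "Linear"): MixPrecisionConfig(dtype="bf16")}

    converter = HalfPrecisionConverter(configs_mapping)
    model = converter.convert(model)  # child "0" is now a HalfPrecisionModuleWrapper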
+ """ + if len(self.configs_mapping) > 0: + logger.info("Convert operators to half-precision") + + if next(model.parameters()).is_cuda: + self.device = "cuda" + elif next(model.parameters()).is_cpu: + self.device = "cpu" + + mix_precision_model = self._wrap_half_precision_model(model) + mix_precision_model.to(self.device) + + return mix_precision_model + + def _wrap_half_precision_model(self, model: torch.nn.Module, prefix=""): + """Wrap and replace half-precision target modules. + + Args: + model (torch.nn.Module): the input module. + prefix (str): the name prefix for named children. + + Returns: + model (torch.nn.Module): the model whose target modules have been wrapped. + """ + for name, child in model.named_children(): + op_name = prefix + "." + name if prefix != "" else name + for op_info, config in self.configs_mapping.items(): + if op_name == op_info[0] and config.dtype in ("fp16", "bf16"): + child = HalfPrecisionModuleWrapper( + module=child, device=self.device, dtype=self.dtype_mapping[config.dtype] + ) + else: + self._wrap_half_precision_model(child, op_name) + setattr(model, name, child) + + return model diff --git a/neural_compressor/torch/algorithms/mix_precision/module_wrappers.py b/neural_compressor/torch/algorithms/mix_precision/module_wrappers.py new file mode 100644 index 00000000000..7e8f0758515 --- /dev/null +++ b/neural_compressor/torch/algorithms/mix_precision/module_wrappers.py @@ -0,0 +1,38 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Half-precision Wrapper for Torch Modules.""" + +import torch + + +class HalfPrecisionModuleWrapper(torch.nn.Module): + """FP16 or BF16 Module Wrapper Class.""" + + def __init__(self, module, device="cpu", dtype=torch.float16): + """Init a HalfPrecisionModuleWrapper object.""" + super(HalfPrecisionModuleWrapper, self).__init__() + self.add_module("module", module) + self.device = device + self.dtype = dtype + self.weight = self.module.weight if hasattr(self.module, "weight") else None + self.bias = self.module.bias if hasattr(self.module, "bias") else None + + def forward(self, X): + """Convert dtype.""" + with torch.autocast(device_type=self.device, dtype=self.dtype): + X = self.module(X) + return X.float() diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py index b89bec51350..38cf0d3b903 100644 --- a/neural_compressor/torch/quantization/__init__.py +++ b/neural_compressor/torch/quantization/__init__.py @@ -34,6 +34,9 @@ FP8Config, get_default_fp8_config, get_default_fp8_config_set, + MixPrecisionConfig, + get_default_mix_precision_config, + get_default_mix_precision_config_set, get_woq_tuning_config, DynamicQuantConfig, get_default_dynamic_config, diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 0334121b5ae..c80a0e90934 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -24,6 +24,7 @@ FP8_QUANT, GPTQ, HQQ, + MIX_PRECISION, RTN, SMOOTH_QUANT, STATIC_QUANT, @@ -36,6 +37,7 @@ FP8Config, GPTQConfig, HQQConfig, + MixPrecisionConfig, RTNConfig, SmoothQuantConfig, StaticQuantConfig, @@ -518,3 +520,17 @@ def fp8_quant_entry( model.qconfig = configs_mapping model.save = MethodType(save, model) return model + + +###################### Mixed Precision Algo Entry ################################## +@register_algo(MIX_PRECISION) +def mix_precision_entry( + model: torch.nn.Module, configs_mapping: Dict[Tuple[str], MixPrecisionConfig], *args, **kwargs +) -> torch.nn.Module: + # only support fp16 and bf16 now, more types might be added later + from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionConverter + + half_precision_converter = HalfPrecisionConverter(configs_mapping, *args, **kwargs) + mix_precision_model = half_precision_converter.convert(model) + + return mix_precision_model diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 62b98b83a34..820103a5f3b 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -36,6 +36,7 @@ FP8_QUANT, GPTQ, HQQ, + MIX_PRECISION, OP_NAME_OR_MODULE_TYPE, RTN, SMOOTH_QUANT, @@ -1209,6 +1210,81 @@ def get_default_fp8_config_set() -> FP8Config: return FP8Config.get_config_set_for_tuning() +######################## MixPrecision Config ############################### +@register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION) +class MixPrecisionConfig(BaseConfig): + """Config class for mix-precision.""" + + name = MIX_PRECISION + supported_configs: List[OperatorConfig] = [] + params_list = [ + "dtype", + ] + supported_half_precision_ops = ( + torch.nn.Linear, + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + ) + + def __init__( + self, + dtype: Union[str, List[str]] = "fp16", + white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, + ): + """Init MixPrecision 
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 62b98b83a34..820103a5f3b 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -36,6 +36,7 @@
     FP8_QUANT,
     GPTQ,
     HQQ,
+    MIX_PRECISION,
     OP_NAME_OR_MODULE_TYPE,
     RTN,
     SMOOTH_QUANT,
@@ -1209,6 +1210,83 @@
 def get_default_fp8_config_set() -> FP8Config:
     return FP8Config.get_config_set_for_tuning()
 
 
+######################## MixPrecision Config ###############################
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION)
+class MixPrecisionConfig(BaseConfig):
+    """Config class for mix-precision."""
+
+    name = MIX_PRECISION
+    supported_configs: List[OperatorConfig] = []
+    params_list = [
+        "dtype",
+    ]
+    supported_half_precision_ops = (
+        torch.nn.Linear,
+        torch.nn.Conv1d,
+        torch.nn.Conv2d,
+        torch.nn.Conv3d,
+    )
+
+    def __init__(
+        self,
+        dtype: Union[str, List[str]] = "fp16",
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+    ):
+        """Init MixPrecision config.
+
+        Args:
+            dtype (str or list): the target data type(s), "fp16" or "bf16". Default is "fp16".
+            white_list (list): op name or module type white list. Default is DEFAULT_WHITE_LIST.
+        """
+        super().__init__(white_list=white_list)
+        self.dtype = dtype
+        self._post_init()
+
+    @classmethod
+    def register_supported_configs(cls) -> List[OperatorConfig]:
+        supported_configs = []
+        mix_precision_config = MixPrecisionConfig(
+            dtype=["fp16", "bf16", "fp32"],
+        )
+        operators = cls.supported_half_precision_ops
+        supported_configs.append(OperatorConfig(config=mix_precision_config, operators=operators))
+        cls.supported_configs = supported_configs
+
+    @staticmethod
+    def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
+        white_list = tuple(MixPrecisionConfig.supported_half_precision_ops)
+        filter_result = []
+        for op_name, module in model.named_modules():
+            if isinstance(module, white_list):
+                pair = (op_name, type(module).__name__)
+                filter_result.append(pair)
+        logger.debug(f"Get model info: {filter_result}")
+        return filter_result
+
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "MixPrecisionConfig", List["MixPrecisionConfig"]]:
+        # TODO: fwk owner needs to update it.
+        return MixPrecisionConfig(dtype=["fp16", "bf16", "fp32"])
+
+
+def get_default_mix_precision_config() -> MixPrecisionConfig:
+    """Generate the default mix-precision config.
+
+    Returns:
+        the default mix-precision config.
+    """
+    return MixPrecisionConfig()
+
+
+def get_default_mix_precision_config_set() -> MixPrecisionConfig:
+    """Generate the default mix-precision config set.
+
+    Returns:
+        the default mix-precision config set.
+    """
+    return MixPrecisionConfig.get_config_set_for_tuning()
+
+
 ##################### Algo Configs End ###################################
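The "global"/"local" dict form exercised by the tests below maps onto per-op overrides on the config object. Assuming set_local behaves as on the common BaseConfig (an assumption; that class is not touched by this patch), the first test config can also be written directly:

    from neural_compressor.torch.quantization import MixPrecisionConfig

    # Global BF16, but keep "fc2" in FP32 (mirrors config1 in the test below).
    # set_local is inherited from the common BaseConfig; its semantics here are assumed.
    config = MixPrecisionConfig(dtype="bf16")
    config.set_local("fc2", MixPrecisionConfig(dtype="fp32"))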
+ "dtype": "fp32", + } + }, + } + } + config2 = { + "mix_precision": { + "global": { + "dtype": "fp16", + }, + "local": { + "fc1": { + "dtype": "fp32", + } + }, + } + } + + registered_configs = config_registry.get_cls_configs() + config1 = ComposableConfig.from_dict(config1, config_registry=registered_configs["torch"]) + config2 = ComposableConfig.from_dict(config2, config_registry=registered_configs["torch"]) + + custom_tune_config = TuningConfig(config_set=[config1, config2], max_trials=2) + best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn) + + self.assertIsNotNone(best_model) + self.assertTrue(isinstance(best_model.fc1, torch.nn.Linear)) + self.assertTrue(isinstance(best_model.fc2, HalfPrecisionModuleWrapper)) + self.assertTrue(isinstance(best_model.fc3, HalfPrecisionModuleWrapper)) + if __name__ == "__main__": unittest.main()