1 change: 1 addition & 0 deletions neural_compressor/common/utils/constants.py
@@ -36,6 +36,7 @@
TEQ = "teq" # pragma: no cover
AUTOROUND = "autoround"
FP8_QUANT = "fp8_quant"
MIX_PRECISION = "mix_precision"

# options
import datetime
19 changes: 19 additions & 0 deletions neural_compressor/torch/algorithms/mix_precision/__init__.py
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from neural_compressor.torch.algorithms.mix_precision.half_precision_convert import HalfPrecisionConverter
from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper
88 changes: 88 additions & 0 deletions neural_compressor/torch/algorithms/mix_precision/half_precision_convert.py
@@ -0,0 +1,88 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Half-precision Convert for Torch Modules."""

from typing import Dict, Tuple

import torch

from neural_compressor.common import logger
from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper
from neural_compressor.torch.utils import get_device


class HalfPrecisionConverter:
    """Converter Class for FP16 and BF16."""

    dtype_mapping = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    def __init__(self, configs_mapping: Dict[Tuple[str], object], *args, **kwargs):
        """Initialize the Half-precision Converter with config.

        Args:
            configs_mapping (Dict): config class for mix-precision.
        """
        self.configs_mapping = configs_mapping
        self.device = get_device()

    def convert(self, model: torch.nn.Module):
        """Convert to FP16 or BF16 model.

        Args:
            model (torch.nn.Module): the input model.

        Returns:
            mix_precision_model (torch.nn.Module): model with mix-precision.
        """
        if len(self.configs_mapping) > 0:
            logger.info("Convert operators to half-precision")

        if next(model.parameters()).is_cuda:
            self.device = "cuda"
        elif next(model.parameters()).is_cpu:
            self.device = "cpu"

        mix_precision_model = self._wrap_half_precision_model(model)
        mix_precision_model.to(self.device)

        return mix_precision_model

    def _wrap_half_precision_model(self, model: torch.nn.Module, prefix=""):
        """Wrap and replace half-precision target modules.

        Args:
            model (torch.nn.Module): the input module.
            prefix (str): the name prefix for named children.

        Returns:
            model (torch.nn.Module): the model whose target modules have been wrapped.
        """
        for name, child in model.named_children():
            op_name = prefix + "." + name if prefix != "" else name
            for op_info, config in self.configs_mapping.items():
                if op_name == op_info[0] and config.dtype in ("fp16", "bf16"):
                    child = HalfPrecisionModuleWrapper(
                        module=child, device=self.device, dtype=self.dtype_mapping[config.dtype]
                    )
                else:
                    self._wrap_half_precision_model(child, op_name)
            setattr(model, name, child)

        return model
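
For readers skimming the diff, a minimal usage sketch of the converter (not part of the PR; the `("0", "Linear")` key layout is an assumption mirroring the `op_info[0]` lookup in `_wrap_half_precision_model`):

```python
import torch

from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionConverter
from neural_compressor.torch.quantization import MixPrecisionConfig

# toy model; its only child module is named "0"
model = torch.nn.Sequential(torch.nn.Linear(8, 8))
configs_mapping = {("0", "Linear"): MixPrecisionConfig(dtype="bf16")}

converter = HalfPrecisionConverter(configs_mapping)
converted = converter.convert(model)  # wraps matching ops, moves model to the detected device
```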
38 changes: 38 additions & 0 deletions neural_compressor/torch/algorithms/mix_precision/module_wrappers.py
@@ -0,0 +1,38 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Half-precision Wrapper for Torch Modules."""

import torch


class HalfPrecisionModuleWrapper(torch.nn.Module):
    """FP16 or BF16 Module Wrapper Class."""

    def __init__(self, module, device="cpu", dtype=torch.float16):
        """Init a HalfPrecisionModuleWrapper object."""
        super(HalfPrecisionModuleWrapper, self).__init__()
        self.add_module("module", module)
        self.device = device
        self.dtype = dtype
        self.weight = self.module.weight if hasattr(self.module, "weight") else None
        self.bias = self.module.bias if hasattr(self.module, "bias") else None

    def forward(self, X):
        """Convert dtype."""
        with torch.autocast(device_type=self.device, dtype=self.dtype):
            X = self.module(X)
        return X.float()
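
A quick illustrative check of the wrapper's contract (an aside, not from the PR; assumes CPU autocast with bfloat16, which recent PyTorch releases support): compute runs under `torch.autocast` in the target dtype, but the output is cast back to float32 before returning.

```python
import torch

from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionModuleWrapper

wrapped = HalfPrecisionModuleWrapper(torch.nn.Linear(4, 2), device="cpu", dtype=torch.bfloat16)
out = wrapped(torch.randn(1, 4))
# the matmul ran in bfloat16 under autocast; X.float() restores float32
assert out.dtype == torch.float32
```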
3 changes: 3 additions & 0 deletions neural_compressor/torch/quantization/__init__.py
@@ -34,6 +34,9 @@
    FP8Config,
    get_default_fp8_config,
    get_default_fp8_config_set,
    MixPrecisionConfig,
    get_default_mix_precision_config,
    get_default_mix_precision_config_set,
    get_woq_tuning_config,
    DynamicQuantConfig,
    get_default_dynamic_config,
16 changes: 16 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -24,6 +24,7 @@
    FP8_QUANT,
    GPTQ,
    HQQ,
    MIX_PRECISION,
    RTN,
    SMOOTH_QUANT,
    STATIC_QUANT,
@@ -36,6 +37,7 @@
    FP8Config,
    GPTQConfig,
    HQQConfig,
    MixPrecisionConfig,
    RTNConfig,
    SmoothQuantConfig,
    StaticQuantConfig,
@@ -518,3 +520,17 @@ def fp8_quant_entry(
    model.qconfig = configs_mapping
    model.save = MethodType(save, model)
    return model


###################### Mixed Precision Algo Entry ##################################
@register_algo(MIX_PRECISION)
def mix_precision_entry(
    model: torch.nn.Module, configs_mapping: Dict[Tuple[str], MixPrecisionConfig], *args, **kwargs
) -> torch.nn.Module:
    # only support fp16 and bf16 now, more types might be added later
    from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionConverter

    half_precision_converter = HalfPrecisionConverter(configs_mapping, *args, **kwargs)
    mix_precision_model = half_precision_converter.convert(model)

    return mix_precision_model
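
Since the entry is registered under `MIX_PRECISION`, the generic one-shot API should route a `MixPrecisionConfig` here. A hedged sketch, assuming the 3.x `quantize()` front end dispatches registered algorithm entries by config name:

```python
import torch

from neural_compressor.torch.quantization import MixPrecisionConfig, quantize

model = torch.nn.Sequential(torch.nn.Linear(8, 8))
# all supported ops (here the single Linear) get converted per the global dtype
converted = quantize(model, quant_config=MixPrecisionConfig(dtype="bf16"))
```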
76 changes: 76 additions & 0 deletions neural_compressor/torch/quantization/config.py
@@ -36,6 +36,7 @@
    FP8_QUANT,
    GPTQ,
    HQQ,
    MIX_PRECISION,
    OP_NAME_OR_MODULE_TYPE,
    RTN,
    SMOOTH_QUANT,
@@ -1209,6 +1210,81 @@ def get_default_fp8_config_set() -> FP8Config:
    return FP8Config.get_config_set_for_tuning()


######################## MixPrecision Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION)
class MixPrecisionConfig(BaseConfig):
    """Config class for mix-precision."""

    name = MIX_PRECISION
    supported_configs: List[OperatorConfig] = []
    params_list = [
        "dtype",
    ]
    supported_half_precision_ops = (
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
    )

    def __init__(
        self,
        dtype: Union[str, List[str]] = "fp16",
        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
    ):
        """Init MixPrecision config.

        Args:
            dtype (str or list): the half-precision data type to convert to, e.g. "fp16" or "bf16". Default is "fp16".
            white_list (list): op name or module type white list. Default is DEFAULT_WHITE_LIST.
        """
        super().__init__(white_list=white_list)
        self.dtype = dtype
        self._post_init()

    @classmethod
    def register_supported_configs(cls) -> List[OperatorConfig]:
        supported_configs = []
        mix_precision_config = MixPrecisionConfig(
            dtype=["fp16", "bf16", "fp32"],
        )
        operators = cls.supported_half_precision_ops
        supported_configs.append(OperatorConfig(config=mix_precision_config, operators=operators))
        cls.supported_configs = supported_configs

    @staticmethod
    def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
        white_list = tuple(MixPrecisionConfig.supported_half_precision_ops)
        filter_result = []
        for op_name, module in model.named_modules():
            if isinstance(module, white_list):
                pair = (op_name, type(module).__name__)
                filter_result.append(pair)
        logger.debug(f"Get model info: {filter_result}")
        return filter_result

    @classmethod
    def get_config_set_for_tuning(cls) -> Union[None, "MixPrecisionConfig", List["MixPrecisionConfig"]]:
        # TODO fwk owner needs to update it.
        return MixPrecisionConfig(dtype=["fp16", "bf16", "fp32"])


def get_default_mix_precision_config() -> MixPrecisionConfig:
    """Generate the default mix-precision config.

    Returns:
        the default mix-precision config.
    """
    return MixPrecisionConfig()


def get_default_mix_precision_config_set() -> MixPrecisionConfig:
    """Generate the default mix-precision config set.

    Returns:
        the default mix-precision config set.
    """
    return MixPrecisionConfig.get_config_set_for_tuning()
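
As an aside, the per-op dicts in the tests below correspond to `BaseConfig`'s local-override mechanism; a sketch assuming the 3.x `set_local()` helper:

```python
from neural_compressor.torch.quantization import MixPrecisionConfig

config = MixPrecisionConfig(dtype="bf16")                   # global dtype for all supported ops
config.set_local("fc2", MixPrecisionConfig(dtype="fp32"))   # keep fc2 in full precision
```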


##################### Algo Configs End ###################################


76 changes: 75 additions & 1 deletion test/3x/torch/test_autotune.py
@@ -7,7 +7,13 @@
import transformers

from neural_compressor.common import logger
from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune, get_all_config_set
from neural_compressor.torch.quantization import (
    MixPrecisionConfig,
    RTNConfig,
    TuningConfig,
    autotune,
    get_all_config_set,
)
from neural_compressor.torch.utils import constants

FAKE_DOUBLE_QUANT_CONFIGS = {
@@ -332,6 +338,74 @@ def eval_acc_fn(model):
        )
        self.assertIsNone(best_model)

    @reset_tuning_target
    def test_autotune_mix_precision_default(self):
        from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionModuleWrapper

        baseline = [1]
        acc_res_lst = baseline + [0.9, 0.99, 1]

        def eval_acc_fn(model):
            res = acc_res_lst.pop(0)
            return res

        custom_tune_config = TuningConfig(config_set=[MixPrecisionConfig(dtype=["fp16", "bf16", "fp32"])], max_trials=3)
        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)

        self.assertIsNotNone(best_model)
        self.assertTrue(isinstance(best_model.fc1, HalfPrecisionModuleWrapper))
        self.assertTrue(isinstance(best_model.fc2, HalfPrecisionModuleWrapper))
        self.assertTrue(isinstance(best_model.fc3, HalfPrecisionModuleWrapper))

    @reset_tuning_target
    def test_autotune_mix_precision_set_op_name(self):
        from neural_compressor.common.base_config import ComposableConfig, config_registry
        from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionModuleWrapper

        baseline = [1]
        acc_res_lst = baseline + [0.9, 1.1]

        def eval_acc_fn(model):
            res = acc_res_lst.pop(0)
            return res

        config1 = {
            "mix_precision": {
                "global": {
                    "dtype": "bf16",
                },
                "local": {
                    "fc2": {
                        "dtype": "fp32",
                    }
                },
            }
        }
        config2 = {
            "mix_precision": {
                "global": {
                    "dtype": "fp16",
                },
                "local": {
                    "fc1": {
                        "dtype": "fp32",
                    }
                },
            }
        }

        registered_configs = config_registry.get_cls_configs()
        config1 = ComposableConfig.from_dict(config1, config_registry=registered_configs["torch"])
        config2 = ComposableConfig.from_dict(config2, config_registry=registered_configs["torch"])

        custom_tune_config = TuningConfig(config_set=[config1, config2], max_trials=2)
        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)

        self.assertIsNotNone(best_model)
        self.assertTrue(isinstance(best_model.fc1, torch.nn.Linear))
        self.assertTrue(isinstance(best_model.fc2, HalfPrecisionModuleWrapper))
        self.assertTrue(isinstance(best_model.fc3, HalfPrecisionModuleWrapper))


if __name__ == "__main__":
    unittest.main()
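
`build_simple_torch_model()` is defined elsewhere in this test file and is not shown in the diff; a hypothetical stand-in consistent with the `fc1`/`fc2`/`fc3` assertions above might look like:

```python
import torch

class SimpleModel(torch.nn.Module):
    # hypothetical stand-in; the layer sizes are arbitrary
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(30, 50)
        self.fc2 = torch.nn.Linear(50, 30)
        self.fc3 = torch.nn.Linear(30, 5)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))

def build_simple_torch_model():
    return SimpleModel()
```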