diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index e34aae63e59..35bf5ae496a 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -1242,6 +1242,8 @@ def _combine_capability(self, bf16_ops, q_capability):
                 q_capability["opwise"][bf16_op] = [bf16_config, fp32_config]
                 if bf16_op[1] not in q_capability["optypewise"]:
                     q_capability["optypewise"][bf16_op[1]] = [bf16_config, fp32_config]
+                if bf16_op[1] in q_capability["optypewise"] and bf16_config not in q_capability["optypewise"][bf16_op[1]]:
+                    q_capability["optypewise"][bf16_op[1]].append(bf16_config)
         return q_capability

     def get_fused_list(self, model):
@@ -3579,6 +3581,16 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 return q_model

             self.tune_cfg["fx_sub_module_list"] = self.sub_module_list
+
+        # BF16 fallback
+        if (
+            len(self.tune_cfg["bf16_ops_list"]) > 0
+            and self.version.release >= Version("1.11.0").release
+            and self.use_bf16
+            and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1")
+        ):  # pragma: no cover
+            q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
+
         if self.approach == "quant_aware_training":
             q_model._model.train()
             if self.sub_module_list is None:
@@ -3665,14 +3677,6 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                     self.sub_module_list, q_model._model, prefix="", custom_config=self.prepare_custom_config_dict
                 )

-        if (
-            len(self.tune_cfg["bf16_ops_list"]) > 0
-            and self.version.release >= Version("1.11.0").release
-            and self.use_bf16
-            and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1")
-        ):  # pragma: no cover
-            q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
-
         self.fused_dict = self.get_fused_list(q_model.model)
         q_model.is_quantized = True
         q_model.q_config = copy.deepcopy(self.tune_cfg)
diff --git a/neural_compressor/adaptor/pytorch_cpu.yaml b/neural_compressor/adaptor/pytorch_cpu.yaml
index f815c5c7f18..fafad6f860b 100644
--- a/neural_compressor/adaptor/pytorch_cpu.yaml
+++ b/neural_compressor/adaptor/pytorch_cpu.yaml
@@ -19,7 +19,7 @@
     name: '1.11'

  bf16: ['Linear', 'bmm', 'mm', 'baddbmm', 'addmm', 'addbmm',
-        '_convolution', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell']
+        'Conv1d', 'Conv2d', 'Conv3d', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell']
  fp32: ['*'] # `*` means all op types.
  int8: &1_11_capabilities {
    'static': &cap_s8_1_11 {
diff --git a/neural_compressor/adaptor/torch_utils/bf16_convert.py b/neural_compressor/adaptor/torch_utils/bf16_convert.py
index 917976c2810..b6d5e6d01bd 100644
--- a/neural_compressor/adaptor/torch_utils/bf16_convert.py
+++ b/neural_compressor/adaptor/torch_utils/bf16_convert.py
@@ -17,7 +17,6 @@
 """Bf16 Convert for Torch Utils."""
 import torch
 import torch.nn as nn
-from torch.fx import symbolic_trace

 from ...utils import logger

@@ -28,6 +27,7 @@ class BF16ModuleWrapper(nn.Module):
     def __init__(self, module):
         """Init a BF16ModuleWrapper object."""
         super(BF16ModuleWrapper, self).__init__()
+        module = module.bfloat16()
         self.add_module("module", module)
         self.train(module.training)
         # WA for TransformerEncoder to access its Linear's weights and bias
@@ -38,7 +38,6 @@ def __init__(self, module):

     def forward(self, X):
         """Convert dtype."""
         X = X.to(torch.bfloat16)
-        self.module.bfloat16()
         X = self.module(X)
         return X.float()

@@ -54,12 +53,9 @@ def Convert(model, tune_cfg):
         mixed_precision_model (object): model with mixed precision.
""" bf16_ops_list = tune_cfg["bf16_ops_list"] - fx_sub_module_list = tune_cfg["fx_sub_module_list"] if "fx_sub_module_list" in tune_cfg.keys() else [] if len(bf16_ops_list) > 0: logger.info("Convert operators to bfloat16") mixed_precision_model = _bf16_wrapper_model(model, bf16_ops_list) - if fx_sub_module_list is not None and len(fx_sub_module_list) > 0: - mixed_precision_model = bf16_symbolic_trace(mixed_precision_model, fx_sub_module_list) return mixed_precision_model @@ -67,31 +63,8 @@ def _bf16_wrapper_model(model, bf16_ops_list, prefix=""): for name, child in model.named_children(): op_name = prefix + "." + name if prefix != "" else name for bf16_op_name in bf16_ops_list: - if op_name == bf16_op_name[0]: + if op_name == bf16_op_name[0] or op_name == bf16_op_name[0].split(".module")[0]: child = BF16ModuleWrapper(child) - else: - _bf16_wrapper_model(child, bf16_ops_list, op_name) - setattr(model, name, child) - return model - - -def bf16_symbolic_trace(model, fx_sub_module_list, prefix=""): - """Symbolic trace for bf16 models. - - Args: - model (object): the input model. - fx_sub_module_list (list): _description_ - prefix (str): prefix of op name. - - Returns: - model (object) - """ - for name, child in model.named_children(): - op_name = prefix + "." + name if prefix != "" else name - for fx_sub_module_name in fx_sub_module_list: - if op_name == fx_sub_module_name: - child = symbolic_trace(child) - else: - bf16_symbolic_trace(child, fx_sub_module_list, op_name) - setattr(model, name, child) + setattr(model, name, child) + _bf16_wrapper_model(child, bf16_ops_list, op_name) return model diff --git a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2x.py b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2x.py index ea8a18f424e..1bfa38a0bb7 100644 --- a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2x.py +++ b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2x.py @@ -392,21 +392,15 @@ def test_fx_sub_module_quant(self): "Please use PyTroch 1.11 or higher version for mixed precision with pytorch_fx or pytorch backend", ) def test_mix_precision(self): + os.environ["FORCE_BF16"] = "1" model_origin = DynamicControlModel() - # run fx_quant in neural_compressor and save the quantized GraphModule dataset = Datasets("pytorch")["dummy"]((100, 3, 224, 224)) dataloader = DataLoader("pytorch", dataset) set_workspace("./saved") + # fx mode usually has .module suffix due to tracing of the entire model fails, so use conv.* to leverage re.match + ptq_fx_op_name_list["conv.*"] = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}} conf = PostTrainingQuantConfig(op_name_dict=ptq_fx_op_name_list) q_model = quantization.fit(model_origin, conf, calib_dataloader=dataloader, calib_func=eval_func) - tune_cfg = q_model.q_config - tune_cfg["op"][("conv.module", "Conv2d")].clear() - tune_cfg["op"][("conv.module", "Conv2d")] = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}} - tune_cfg["bf16_ops_list"].append(("conv.module", "Conv2d")) - from neural_compressor.adaptor.torch_utils.bf16_convert import Convert - - q_model._model = Convert(q_model._model, tune_cfg) - self.assertEqual(q_model._model.conv.module.module.weight.dtype, torch.bfloat16) self.assertEqual(q_model._model.conv.module.module.bias.dtype, torch.bfloat16)