
Commit d7f99da

add more uts
Signed-off-by: yiliu30 <[email protected]>
1 parent de59f73 commit d7f99da

File tree

2 files changed (+80, -32)


neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py

Lines changed: 32 additions & 18 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Rewrite the FP32 operators to FP16 or BF16 operators."""
 
+from collections import defaultdict
 from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Dict, List, Tuple
@@ -50,25 +51,31 @@ class PatternPair:
 
 # key: torch func
 # value: the tuple of args
-FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, Tuple[torch.Tensor, ...]]
+FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, List[Tuple[torch.Tensor, ...]]]
 
 
 # Align with xiq, as it relies on xiq's set_module_xx capability
 FN_ARGS_MAPPING: FuncArgsMappingType = {
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
-    torch.nn.functional.conv2d: (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
-    torch.nn.functional.conv2d: (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
-    torch.matmul: (torch.randn(0, 0), torch.randn(0, 0)),  # matmul
-    torch.matmul: (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # matmul
-    torch.matmul: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),  # matmul
+    # Note: order matters
+    torch.nn.functional.linear: [
+        (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
+        (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
+    ],
+    torch.nn.functional.conv2d: [
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
+    ],
+    torch.matmul: [
+        (torch.randn(0, 0), torch.randn(0, 0)),  # 2-D matmul
+        (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # 3-D matmul
+        (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),  # 4-D matmul
+    ],
 }
 
 # module cls <-> function name
 NN_MODULES_TO_NN_FN = {
     torch.nn.Linear: torch.nn.functional.linear,
     torch.nn.Conv2d: torch.nn.functional.conv2d,
-    torch.nn.MaxPool2d: torch.nn.functional.max_pool2d,
 }
 
 # Use the mapping from xiq
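
The change from plain tuples to lists of tuples fixes a real bug: in a Python dict literal, a repeated key silently keeps only its last value, so the old FN_ARGS_MAPPING registered just one arg-tuple per function. A minimal standalone sketch of the pitfall, using torch.matmul as in the mapping above:

    import torch

    # Repeated dict keys: only the last value survives, so the old mapping
    # silently dropped all but one pattern per function.
    old_style = {
        torch.matmul: (torch.randn(0, 0), torch.randn(0, 0)),        # 2-D variant
        torch.matmul: (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # overwrites the 2-D variant
    }
    assert len(old_style) == 1

    # The new list-of-tuples layout keeps every variant under a single key.
    new_style = {
        torch.matmul: [
            (torch.randn(0, 0), torch.randn(0, 0)),
            (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),
        ],
    }
    assert len(new_style[torch.matmul]) == 2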
@@ -78,7 +85,10 @@ class PatternPair:
 
 
 PatternRegistryType: TypeAlias = Dict[TorchFuncType, PatternPair]
-HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {torch.float16: {}, torch.bfloat16: {}}
+HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {
+    torch.float16: defaultdict(list),
+    torch.bfloat16: defaultdict(list),
+}
 
 # FP16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.float16]
 # BF16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.bfloat16]
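
Switching the per-dtype registries from plain dicts to defaultdict(list) is what lets _register_pattern_pair below append without seeding each function key first. A short sketch of that behavior, with placeholder string keys:

    from collections import defaultdict

    registry = defaultdict(list)
    registry["linear"].append("pattern_a")  # missing key materializes as an empty list
    registry["linear"].append("pattern_b")
    assert registry["linear"] == ["pattern_a", "pattern_b"]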
@@ -108,10 +118,11 @@ def replace_fn_wrapper(fn_args, fn):
 
 
 def _register_pattern_pair(dtype: torch.dtype) -> None:
-    for fn, fn_args in FN_ARGS_MAPPING.items():
-        logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
-        pattern_pair = pattern_factory(fn, fn_args)
-        HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair
+    for fn, fn_args_lst in FN_ARGS_MAPPING.items():
+        for fn_args in fn_args_lst:
+            logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
+            pattern_pair = pattern_factory(fn, fn_args)
+            HALF_PRECISION_PATTERN_REGISTRY[dtype][fn].append(pattern_pair)
     utils.logger.debug(
         f"Registered {len(HALF_PRECISION_PATTERN_REGISTRY[dtype])} search and replace patterns for {dtype}."
     )
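
Note that with list values, len(HALF_PRECISION_PATTERN_REGISTRY[dtype]) in the debug message above counts distinct functions rather than individual pattern pairs; if the pair total were wanted instead, it would be the sum of the list lengths. A hypothetical illustration:

    # 3 functions vs. 7 registered pattern pairs.
    registry = {"linear": ["p1", "p2"], "conv2d": ["p3", "p4"], "matmul": ["p5", "p6", "p7"]}
    assert len(registry) == 3
    assert sum(len(pairs) for pairs in registry.values()) == 7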
@@ -194,9 +205,10 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):
 
 def transformation(gm: torch.fx.GraphModule, node_candidate_list: List[str], target_dtype: torch.dtype = torch.float16):
     """Convert the nodes in `node_candidate_list` to `target_dtype` if possible."""
-    for pattern_pair in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
-        apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
-    utils.logger.info("Half precision conversion is done:")
+    for pattern_pair_lst in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
+        for pattern_pair in pattern_pair_lst:
+            apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
+    utils.logger.info(f"Half precision conversion ({target_dtype}) completed.")
     if utils.level_name == "DEBUG":  # pragma: no cover
         gm.print_readable(True)
 
@@ -249,5 +261,7 @@ def get_half_precision_node_set(gm, config):
     for node in possible_node_set:
         if node.target in SUPPORTED_OPERATORS:
             half_precision_node_set.add(node)
-    utils.logger.info(f"Found {len(half_precision_node_set)} nodes to convert to half precision.")
+    utils.logger.info(
+        f"Found {len(half_precision_node_set)} nodes to convert to half precision: {half_precision_node_set}"
+    )
     return half_precision_node_set

test/3x/torch/quantization/test_pt2e_quant.py

Lines changed: 48 additions & 14 deletions
@@ -29,7 +29,6 @@ def _is_ipex_imported():
     monkeypatch.setattr("neural_compressor.torch.quantization.algorithm_entry.is_ipex_imported", _is_ipex_imported)
     monkeypatch.setattr("neural_compressor.torch.export.pt2e_export.is_ipex_imported", _is_ipex_imported)
 
-
 class TestPT2EQuantization:
     def teardown_class(self):
         shutil.rmtree("saved_results", ignore_errors=True)
@@ -53,15 +52,15 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         return bar, example_inputs
 
     @staticmethod
-    def build_model_include_conv_and_linear():
+    def build_model_include_conv_and_linear(bias=True):
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, bias=True):
                 super(Model, self).__init__()
-                self.conv1 = torch.nn.Conv2d(3, 6, 5)
+                self.conv1 = torch.nn.Conv2d(3, 6, 5, bias=bias)
                 self.pool = torch.nn.MaxPool2d(2, 2)
-                self.conv2 = torch.nn.Conv2d(6, 16, 5)
-                self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
-                self.fc2 = torch.nn.Linear(120, 84)
+                self.conv2 = torch.nn.Conv2d(6, 16, 5, bias=bias)
+                self.fc1 = torch.nn.Linear(16 * 5 * 5, 120, bias=bias)
+                self.fc2 = torch.nn.Linear(120, 84, bias=bias)
 
             def forward(self, x):
                 x = self.conv1(x)
@@ -74,7 +73,7 @@ def forward(self, x):
 
                 return x
 
-        model = Model()
+        model = Model(bias)
         example_inputs = (torch.randn(1, 3, 32, 32),)
         return model, example_inputs
 
@@ -286,19 +285,54 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
 
     @pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
     @pytest.mark.parametrize("half_precision_dtype", ["fp16", "bf16"])
-    def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, force_not_import_ipex):
+    @pytest.mark.parametrize("op_name", ["conv1", "fc1"])
+    @pytest.mark.parametrize("bias", [True, False])
+    def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, op_name, bias, force_not_import_ipex):
+        # Test auto-tune with mixed int8 and 16-bit dtypes.
+        # Only checks that the patterns match, not the accuracy.
         # config1: int8 for all
-        # config2: half precision for linear
+        # config2: half precision for linear/conv
         from neural_compressor.torch.quantization.config import INT8StaticQuantConfig
         from neural_compressor.torch.quantization.autotune import autotune, TuningConfig
+
         config1 = INT8StaticQuantConfig()
-        config2 = INT8StaticQuantConfig().set_local("fc1", StaticQuantConfig(w_dtype=half_precision_dtype, act_dtype=half_precision_dtype))
+        config2 = INT8StaticQuantConfig().set_local(
+            op_name, StaticQuantConfig(w_dtype=half_precision_dtype, act_dtype=half_precision_dtype)
+        )
         tune_config = TuningConfig(config_set=[config1, config2], tolerable_loss=-0.1)
+        eval_result = [1, 1, 2]  # scores for baseline, config1, config2
+
         def fake_eval_fn(model):
-            return 1.0
+            res = eval_result.pop(0)
+            return res
+
         def run_fn(model):
             for i in range(2):
                 model(*example_inputs)
-        model, example_inputs = self.build_model_include_conv_and_linear()
+
+        model, example_inputs = self.build_model_include_conv_and_linear(bias)
         model = export(model, example_inputs=example_inputs)
-        qmodel = autotune(model=model, tune_config=tune_config, eval_fn=fake_eval_fn,run_fn=run_fn, example_inputs=example_inputs)
+        qmodel = autotune(
+            model=model, tune_config=tune_config, eval_fn=fake_eval_fn, run_fn=run_fn, example_inputs=example_inputs
+        )
+
+        # check the half-precision nodes
+        expected_node_occurrence = {
+            # 4 `aten.to` for the target op if bias else 3
+            torch.ops.aten.to.dtype: (3 + int(bias))
+        }
+        expected_node_occurrence = {
+            torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
+        }
+        node_in_graph = self.get_node_in_graph(qmodel)
+        for node, cnt in expected_node_occurrence.items():
+            assert (
+                node_in_graph.get(node, 0) == cnt
+            ), f"Node {node} should occur {cnt} times, but occurs {node_in_graph.get(node, 0)} times"
+        # inference
+        from torch._inductor import config
+
+        config.freezing = True
+        opt_model = torch.compile(qmodel)
+        out = opt_model(*example_inputs)
+        assert out is not None
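
The occurrence check relies on the suite's get_node_in_graph helper, which must tally call_function targets in the exported FX graph. A hypothetical stand-in (count_call_function_targets is not part of the test suite) could look like:

    from collections import Counter

    import torch


    def count_call_function_targets(gm: torch.fx.GraphModule) -> Counter:
        """Tally how often each call_function target occurs in an FX graph."""
        return Counter(node.target for node in gm.graph.nodes if node.op == "call_function")

Read this way, the 3 + int(bias) expectation presumably corresponds to casts of the input, the weight, and the result of the rewritten op, plus one more cast for the bias when it is present.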
