
Commit 0dc6a92

Introduce quant_level into mixed precision (#950)
* Introduce quant_level into mixed precision

Signed-off-by: yiliu30 <[email protected]>
1 parent 00e5cb5 commit 0dc6a92

File tree

4 files changed: +195 -39 lines changed

neural_compressor/config.py
neural_compressor/strategy/auto_mixed_precision.py
neural_compressor/strategy/utils/tuning_sampler.py
test/mixed_precision/test_mixed_precision.py

neural_compressor/config.py

Lines changed: 14 additions & 0 deletions
@@ -1691,6 +1691,8 @@ class MixedPrecisionConfig(object):
         model_name (str, optional): The name of the model. Default value is empty.
         inputs (list, optional): Inputs of model, default is [].
         outputs (list, optional): Outputs of model, default is [].
+        quant_level (str or int, optional): Supports "auto", 0 and 1; 0 is conservative (fallback
+            op-type-wise), 1 falls back op-wise, and "auto" (default) combines 0 and 1.
         tuning_criterion (TuningCriterion object, optional): Accuracy tuning settings,
             it won't work if there is no accuracy tuning process.
         accuracy_criterion (AccuracyCriterion object, optional): Accuracy constraint settings,
@@ -1739,6 +1741,7 @@ def __init__(self,
                  model_name="",
                  inputs=[],
                  outputs=[],
+                 quant_level="auto",
                  tuning_criterion=tuning_criterion,
                  accuracy_criterion=accuracy_criterion,
                  excluded_precisions=[],
@@ -1750,6 +1753,7 @@
         self.outputs = outputs
         self.backend = backend
         self.device = device
+        self.quant_level = quant_level
         self.excluded_precisions = excluded_precisions
         self.accuracy_criterion = accuracy_criterion
         self.tuning_criterion = tuning_criterion
@@ -1788,6 +1792,16 @@ def model_name(self, model_name):
         if _check_value("model_name", model_name, str):
             self._model_name = model_name

+    @property
+    def quant_level(self):
+        """Get the quantization level."""
+        return self._quant_level
+
+    @quant_level.setter
+    def quant_level(self, quant_level):
+        """Set the quantization level."""
+        self._quant_level = quant_level
+
     @property
     def accuracy_criterion(self):
         """Get the accuracy criterion."""
neural_compressor/strategy/auto_mixed_precision.py

Lines changed: 87 additions & 24 deletions
@@ -18,11 +18,11 @@
 """The auto-mixed precision strategy."""

 import copy
-import numpy as np
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from itertools import groupby
 from .strategy import strategy_registry, TuneStrategy
 from ..utils import logger
-from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler
+from .utils.tuning_sampler import FallbackTuningSampler
 from .utils.tuning_structs import OpTuningConfig
 from neural_compressor.adaptor.torch_utils.mixed_precision import ipex_mixed_precision

@@ -50,6 +50,7 @@ def _initialize_config(self, conf):
         config.domain = getattr(config, 'domain', None)
         config.reduce_range = getattr(config, 'reduce_range', None)
         config.example_inputs = getattr(config, 'example_inputs', None)
+        config.quant_level = getattr(config, "quant_level", "auto")
         return config

     def next_tune_cfg(self):
@@ -79,54 +80,116 @@ def next_tune_cfg(self):
         if not target_dtypes:
             target_dtypes = ['bf16']
         # step1. target_dtype AMAP, collect the ops that support target_dtype
-        bf16_items_name = []
+        lower_precision_items_name = []
         op_tuning_cfg = {}
         for idx, target_dtype in enumerate(target_dtypes):
-            bf16_items = tuning_space.query_items_by_quant_mode(target_dtype)
-            if len(bf16_items) == 0 and \
-                not (idx == len(target_dtypes) - 1 and len(bf16_items_name) == 0):
+            lower_precision_items = tuning_space.query_items_by_quant_mode(target_dtype)
+            if len(lower_precision_items) == 0 and \
+                not (idx == len(target_dtypes) - 1 and len(lower_precision_items_name) == 0):
                 continue
-            bf16_items_name = [item.name for item in bf16_items]
+            lower_precision_items_name = [item.name for item in lower_precision_items]
             op_tuning_cfg = deepcopy(initial_op_tuning_cfg)
-            for op_name_type in bf16_items_name:
+            for op_name_type in lower_precision_items_name:
                 op_tuning_cfg[op_name_type] = \
                     OpTuningConfig(op_name_type[0], op_name_type[1], target_dtype, tuning_space)
             calib_sampling_size = 1
             op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
             yield op_tuning_cfg

-        # step2. fallback
-        target_dtype = 'fp32'
-        fallback_items_name_lst = bf16_items_name[::-1]
+        # step 2, fall back ops to fp32
+        # quant_level:
+        #   auto: op-type-wise -> op-wise
+        #   0: op-type-wise
+        #   1: op-wise
+
+        # if quant level is auto or 0, do op-type-wise fallback
+        target_dtype = "fp32"
+        fallback_items_name_lst = lower_precision_items_name[::-1]
         if fallback_items_name_lst:
-            logger.info(f"Start to fallback op to {target_dtype} one by one.")
-            self._fallback_started()
-            op_dtypes = OrderedDict(zip(fallback_items_name_lst, [target_dtype] * len(fallback_items_name_lst)))
+            logger.info("[Strategy] start fallback op into fp32.")
             initial_op_tuning_cfg = deepcopy(op_tuning_cfg)
+        if self.config.quant_level in ["auto", 0]:
+            logger.info(f"[Strategy] fallback op into fp32 in op type wise, \
+                as quant level is {self.config.quant_level}")
+            for op_tuning_cfg in self.fallback_in_op_type_wise(tuning_space, fallback_items_name_lst, \
+                                                               deepcopy(initial_op_tuning_cfg), target_dtype):
+                yield op_tuning_cfg
+
+        # if quant level is auto or 1, do op-wise fallback
+        if self.config.quant_level in ["auto", 1]:
+            logger.info(f"[Strategy] fallback op into fp32 in op wise, \
+                as quant level is {self.config.quant_level}")
+            for op_tuning_cfg in self.fallback_in_op_wise(tuning_space, fallback_items_name_lst, \
+                                                          deepcopy(initial_op_tuning_cfg), target_dtype):
+                yield op_tuning_cfg
+
+    def fallback_in_op_type_wise(self, tuning_space, fallback_items_name_lst, initial_op_tuning_cfg, target_dtype):
+        """Fall back ops to the target dtype, one op type at a time.
+
+        Args:
+            tuning_space: tuning space
+            fallback_items_name_lst: the list of items to fall back
+            initial_op_tuning_cfg: initial tuning config
+            target_dtype: target data type, such as fp32
+
+        Yields:
+            tuning config
+        """
+        fallback_items_name_lst.sort(key=lambda x: x[1])
+        op_type_groups = groupby(fallback_items_name_lst, key=lambda x: x[1])
+        # key: ((op1_name, op_type1), (op2_name, op_type1), (op3_name, op_type1), ...)
+        # value: target dtype
+        ops_dtypes = OrderedDict()
+        for op_type, op_lst in op_type_groups:
+            ops_dtypes[tuple(op_lst)] = target_dtype
         fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[],
-                                                     initial_op_tuning_cfg=initial_op_tuning_cfg,
-                                                     op_dtypes=op_dtypes, accumulate=False)
+                                                 initial_op_tuning_cfg=initial_op_tuning_cfg,
+                                                 op_dtypes=ops_dtypes, accumulate=False)
         op_fallback_acc_impact = OrderedDict()
         for op_index, op_tuning_cfg in enumerate(fallback_sampler):
-            op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
+            op_tuning_cfg['calib_sampling_size'] = -1
+            yield op_tuning_cfg
+            acc, _ = self.last_tune_result
+            op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc
+
+    def fallback_in_op_wise(self, tuning_space, fallback_items_name_lst, initial_op_tuning_cfg, target_dtype):
+        """Fall back ops to the target dtype, one op at a time.
+
+        Args:
+            tuning_space: tuning space
+            fallback_items_name_lst: the list of items to fall back
+            initial_op_tuning_cfg: initial tuning config
+            target_dtype: target data type, such as fp32
+
+        Yields:
+            tuning config
+        """
+        op_dtypes = OrderedDict(zip(fallback_items_name_lst, [target_dtype] * len(fallback_items_name_lst)))
+        fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[],
+                                                 initial_op_tuning_cfg=initial_op_tuning_cfg,
+                                                 op_dtypes=op_dtypes, accumulate=False)
+        op_fallback_acc_impact = OrderedDict()
+        for op_index, op_tuning_cfg in enumerate(fallback_sampler):
+            op_tuning_cfg['calib_sampling_size'] = -1
             yield op_tuning_cfg
             acc, _ = self.last_tune_result
             op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc

         # do accumulated fallback according to the order in the previous stage
         if len(op_fallback_acc_impact) > 0:
-            ordered_ops = sorted(op_fallback_acc_impact.keys(), key=lambda key: op_fallback_acc_impact[key],
-                                 reverse=self.higher_is_better)
+            ordered_ops = sorted(op_fallback_acc_impact.keys(), key=lambda key: op_fallback_acc_impact[key], \
+                                 reverse=self.higher_is_better)
             op_dtypes = OrderedDict(zip(ordered_ops, [target_dtype] * len(fallback_items_name_lst)))
             logger.info("Start to accumulate fallback to {target_dtype}.")
-            initial_op_tuning_cfg = deepcopy(op_tuning_cfg)
+            initial_op_tuning_cfg = copy.deepcopy(op_tuning_cfg)
             fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[],
-                                                         initial_op_tuning_cfg=initial_op_tuning_cfg,
-                                                         op_dtypes=op_dtypes, accumulate=True)
+                                                     initial_op_tuning_cfg=initial_op_tuning_cfg,
+                                                     op_dtypes=op_dtypes, accumulate=True)
             for op_tuning_cfg in fallback_sampler:
-                op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
+                op_tuning_cfg['calib_sampling_size'] = -1
                 yield op_tuning_cfg

+
     def traverse(self):
         """Traverse the tuning space according to auto-mixed precision strategy."""
         if self.config.backend == "ipex":
neural_compressor/strategy/utils/tuning_sampler.py

Lines changed: 16 additions & 14 deletions
@@ -20,7 +20,7 @@
 from itertools import product
 import copy
 from collections import deque, OrderedDict, defaultdict
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Union, Tuple
 from .tuning_space import TuningSpace, pattern_to_internal, pattern_to_path, quant_mode_from_pattern
 from .tuning_structs import OpTuningConfig
 from ...utils import logger
@@ -382,8 +382,8 @@ class FallbackTuningSampler(TuningSampler):
     def __init__(self,
                  tuning_space: TuningSpace,
                  tuning_order_lst: List[TuningOrder],
-                 initial_op_tuning_cfg: Dict[tuple, Any],
-                 op_dtypes: Dict[str, str],
+                 initial_op_tuning_cfg: Dict[Tuple, Any],
+                 op_dtypes: Dict[Union[Tuple, Tuple[Tuple]], str],
                  accumulate: bool,
                  skip_first: bool = True
                 ):
@@ -414,21 +414,23 @@ def __iter__(self):
             # Only support fallback to lower precision.
             if not self.accumulate:
                 new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg)
-            full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, target_dtype)
-            self.op_complete_path[op_name_type] = copy.deepcopy(full_path)
-            config_args = {}
-            self._set_dtype(op_name_type, config_args)
-            internal_pattern = pattern_to_internal(target_dtype)
-            quant_mode = quant_mode_from_pattern(internal_pattern)
-            new_op_config = OpTuningConfig(op_name_type[0], op_name_type[1],
-                                           quant_mode, self.tuning_space,
-                                           kwargs=config_args)
+            op_name_type_lst = [op_name_type] if len(op_name_type) != 1 and \
+                isinstance(op_name_type[1], str) else op_name_type
+            for op_name_type in op_name_type_lst:
+                full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, target_dtype)
+                self.op_complete_path[op_name_type] = copy.deepcopy(full_path)
+                config_args = {}
+                self._set_dtype(op_name_type, config_args)
+                internal_pattern = pattern_to_internal(target_dtype)
+                quant_mode = quant_mode_from_pattern(internal_pattern)
+                new_op_config = OpTuningConfig(op_name_type[0], op_name_type[1], quant_mode, \
+                                               self.tuning_space, kwargs=config_args)

-            new_tune_cfg.update({op_name_type: new_op_config})
+                new_tune_cfg.update({op_name_type: new_op_config})
             if self.accumulate and skip_first:  # skip the first one
                 skip_first = False
                 continue
-            logger.info(f"fallback {op_name_type} to {target_dtype}")
+            logger.info(f"fallback {op_name_type_lst} to {target_dtype}")
             yield new_tune_cfg  # need to skip the first one

 class LowerBitsSampler(TuningSampler):
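With the widened `op_dtypes` typing, a key handed to `FallbackTuningSampler` can now be either a single `(op_name, op_type)` pair or a tuple of such pairs (one op-type group). The new branch in `__iter__` normalizes both shapes before looping; a small standalone sketch of that check, with hypothetical keys:

```python
def normalize_key(op_name_type):
    """Return a list of (op_name, op_type) pairs for either key shape."""
    # A single ("op_name", "op_type") pair has a string in position 1;
    # a group key is a tuple of pairs, so position 1 is itself a tuple.
    if len(op_name_type) != 1 and isinstance(op_name_type[1], str):
        return [op_name_type]        # single op: wrap into a one-element list
    return list(op_name_type)        # op-type group: iterate its pairs directly

print(normalize_key(("conv1", "Conv2D")))
# [('conv1', 'Conv2D')]
print(normalize_key((("conv1", "Conv2D"), ("conv2", "Conv2D"))))
# [('conv1', 'Conv2D'), ('conv2', 'Conv2D')]
```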

test/mixed_precision/test_mixed_precision.py

Lines changed: 78 additions & 1 deletion
@@ -328,7 +328,7 @@ def test_mixed_precision_with_eval_func(self):
         def eval(model):
             return 0.5

-        result = [0., 0.1, 0.102, 0.1006, 0.1005, 0.1004, 0.1002]
+        result = [0., 0.1, 0.102, 0.1003, 0.1005, 0.1004, 0.1002]

         def eval2(model):
             del result[0]
@@ -371,6 +371,83 @@ def eval2(model):
         output_model = fit(self.tf_model, conf, eval)
         self.assertTrue(any([i.op == 'Cast' for i in output_model.graph_def.node]))

+
+    def test_mixed_precision_with_quant_level_1(self):
+
+        result = [0., 0.1, 0.102]
+        def eval_func(model):
+            del result[0]
+            return result[0]
+
+        conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level="auto")
+
+        output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func)
+        self.assertTrue(any([i.op == 'Cast' for i in output_model.graph_def.node]))
+        self.assertEqual(conf.inputs, 'input')
+        self.assertEqual(conf.outputs, 'final')
+
+    def test_mixed_precision_with_quant_level_2(self):
+
+        result = [0., 1, 0.9, 1.1]
+        # accuracy goal is met only after falling back all convs
+        def eval_func(model):
+            del result[0]
+            return result[0]
+
+        conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level="auto")
+
+        output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func)
+        # no Cast left in the output model
+        self.assertFalse(any([i.op == 'Cast' for i in output_model.graph_def.node]))
+
+    def test_mixed_precision_with_quant_level_3(self):
+
+        result = [0., 1, 0.9, 0.9, 1.1]
+        # accuracy goal is met after falling back one conv
+        def eval_func(model):
+            del result[0]
+            return result[0]
+
+        conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level="auto")
+
+        output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func)
+        # count the Cast nodes left in the output model
+        count_cast = 0
+        for node in output_model.graph_def.node:
+            if node.op == "Cast":
+                count_cast += 1
+        self.assertEqual(count_cast, 4)
+
+    def test_mixed_precision_with_quant_level_4(self):
+
+        result = [0., 1, 0.9, 0.9, 1.1]
+        # accuracy goal is met after falling back the second conv
+        def eval_func(model):
+            del result[0]
+            return result[0]
+
+        conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level=1)
+
+        output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func)
+        # count the Cast nodes left in the output model
+        count_cast = 0
+        for node in output_model.graph_def.node:
+            if node.op == "Cast":
+                count_cast += 1
+        self.assertEqual(count_cast, 4)
+
+    def test_mixed_precision_with_quant_level_5(self):
+        result = [0., 1, 0.9, 0.9, 0.9]
+        # accuracy goal is never met
+        def eval_func(model):
+            del result[0]
+            return result[0]
+
+        conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level=0)
+
+        output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func)
+        self.assertIsNone(output_model)
+
     @unittest.skipIf(PT_VERSION.release < Version("1.11.0").release,
                      "Please use PyTorch 1.11 or higher version for mixed precision.")
     def test_mixed_precision_with_eval_func_pt(self):
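These tests drive the tuner with a scripted accuracy sequence: each call to `eval_func` pops the head of the `result` list and returns the new head, so the list scripts the value reported on each successive evaluation (baseline first, then one per tuning trial). A minimal sketch of the same stub pattern on its own:

```python
# Scripted accuracies: after the leading placeholder 0., the first call reports the
# baseline (1.0) and each later call reports the accuracy of one tuning trial.
result = [0., 1.0, 0.9, 1.1]

def eval_func(model):
    del result[0]   # consume one scripted value per evaluation
    return result[0]

print(eval_func(None), eval_func(None), eval_func(None))  # -> 1.0 0.9 1.1
```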
