
Commit d0059c4

Authored by yiliu30
Extend the strategy capability for adding the new data type (#555)
* refined the tuning space
* fixed the merge with user cfg
* parse tuning space
* refined the tuning space
* clean code
* refine the logic
* fixed the pylint error
* fixed the typo
* fix typo
* fixed the merge
* fixed the auto quant
* fixed quant_mode error
* revert some change
* clean code
* add ut for int4
* fixed the parse order

Signed-off-by: yiliu30 <[email protected]>
1 parent 750dff7 commit d0059c4

File tree

12 files changed: +996, -368 lines changed


neural_compressor/strategy/basic.py

Lines changed: 18 additions & 15 deletions
@@ -25,7 +25,7 @@

 from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler
 from .utils.tuning_structs import OpTuningConfig
-from .utils.tuning_space import TUNING_ITEMS_LST
+from .utils.constant import TUNING_ITEMS_LST

 @strategy_registry
 class BasicTuneStrategy(TuneStrategy):
@@ -45,13 +45,13 @@ def next_tune_cfg(self):
         tuning_space = self.tuning_space
         calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options
         for calib_sampling_size in calib_sampling_size_lst:
-            # Initialize the tuning config for each op according to the quantization approach
+            # Initialize the tuning config for each op according to the quantization approach.
             op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg = self.initial_tuning_cfg()
             # Optype-wise tuning tuning items: the algorithm/scheme/granularity of activation(weight)
             early_stop_tuning = False
             stage1_cnt = 0
-            quant_ops = quant_mode_wise_items['static'] if 'static' in quant_mode_wise_items else []
-            quant_ops += quant_mode_wise_items['dynamic'] if 'dynamic' in quant_mode_wise_items else []
+            quant_ops = quant_mode_wise_items.get('static', [])
+            quant_ops += quant_mode_wise_items.get('dynamic', [])
             stage1_max = 1e9 # TODO set a more appropriate value
             op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [],
                                                              op_item_dtype_dict, initial_op_tuning_cfg)
@@ -120,22 +120,25 @@ def _initial_dynamic_cfg_based_on_static_cfg(self, op_static_cfg:OpTuningConfig)
         op_state = op_static_cfg.get_state()
         op_name = op_static_cfg.op_name
         op_type = op_static_cfg.op_type
+        op_name_type = (op_name, op_type)
         op_quant_mode = 'dynamic'
         tuning_space = self.tuning_space
         dynamic_state = {}
         for att in ['weight', 'activation']:
-            if att not in op_state:
-                continue
-            for item_name, item_val in op_state[att].items():
-                att_item = (att, item_name)
-                if att_item not in TUNING_ITEMS_LST:
-                    continue
-                if tuning_space.query_item_option((op_name, op_type), op_quant_mode, att_item, item_val):
-                    dynamic_state[att_item] = item_val
+            if att not in op_state: continue
+            # Add dtype
+            full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, op_quant_mode)
+            dynamic_state[att + '_dtype'] = self.tuning_space.ops_data_type[op_name_type][full_path[att]]
+            for method_name, method_val in op_state[att].items():
+                att_and_method_name = (att, method_name)
+                if att_and_method_name not in TUNING_ITEMS_LST: continue
+                if tuning_space.query_item_option(op_name_type, full_path[att], att_and_method_name, method_val):
+                    dynamic_state[att_and_method_name] = method_val
                 else:
-                    quant_mode_item = tuning_space.query_quant_mode_item((op_name, op_type), op_quant_mode)
-                    tuning_item = quant_mode_item.get_option_by_name(att_item)
-                    dynamic_state[att_item] = tuning_item.options[0] if tuning_item else None
+                    quant_mode_item = tuning_space.get_item_by_path((op_name_type, *full_path[att]))
+                    if quant_mode_item and quant_mode_item.get_option_by_name(att_and_method_name):
+                        tuning_item = quant_mode_item.get_option_by_name(att_and_method_name)
+                        dynamic_state[att_and_method_name] = tuning_item.options[0] if tuning_item else None
         return OpTuningConfig(op_name, op_type, op_quant_mode, tuning_space, kwargs=dynamic_state)
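Editor's note: the quant_mode_wise_items.get(...) change above is behavior-preserving. A minimal standalone sketch (toy op names and a plain dict instead of the real tuning-space output) showing that the old and new forms build the same quant_ops list:

# Toy stand-in for the dict returned by initial_tuning_cfg(); names are illustrative only.
quant_mode_wise_items = {'static': ['op_a', 'op_b'], 'bf16': ['op_c']}   # no 'dynamic' key

# Old form: conditional expression guarded by a membership test.
quant_ops_old = quant_mode_wise_items['static'] if 'static' in quant_mode_wise_items else []
quant_ops_old += quant_mode_wise_items['dynamic'] if 'dynamic' in quant_mode_wise_items else []

# New form: dict.get() with a default; same result with less noise.
quant_ops_new = quant_mode_wise_items.get('static', [])
quant_ops_new += quant_mode_wise_items.get('dynamic', [])

assert quant_ops_old == quant_ops_new == ['op_a', 'op_b']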

neural_compressor/strategy/conservative.py

Lines changed: 4 additions & 4 deletions
@@ -75,7 +75,7 @@ def next_tune_cfg(self):
             tmp_tune_cfg = deepcopy(tune_cfg)
             for item, quant_mode in items_lst:
                 op_info = item.name
-                op_config = tuning_space.set_deafult_config(op_info, quant_mode)
+                op_config = tuning_space.get_default_config(op_info, quant_mode)
                 tmp_tune_cfg[op_info] = op_config
             yield tmp_tune_cfg
             if self.acc_meet_flag:
@@ -87,7 +87,7 @@ def next_tune_cfg(self):
                 logger.info(f"*** Try to convert {op_type} op into {dtype} one by one.")
                 for item, quant_mode in items_lst:
                     op_info = item.name
-                    op_config = tuning_space.set_deafult_config(op_info, quant_mode)
+                    op_config = tuning_space.get_default_config(op_info, quant_mode)
                     tmp_tune_cfg[op_info] = op_config
                     yield tmp_tune_cfg
                     if self.acc_meet_flag:
@@ -358,9 +358,9 @@ def _initialize_tune_cfg(self):
         for op_info in tmp_non_fp32_ops:
             non_fp32_ops_dtype[op_info] = quant_mode
         for op_info in fp32_ops:
-            initial_tuning_cfg[op_info] = tuning_space.set_deafult_config(op_info, "fp32")
+            initial_tuning_cfg[op_info] = tuning_space.get_default_config(op_info, "fp32")
         for op_info, quant_mode in non_fp32_ops_dtype.items():
-            initial_tuning_cfg[op_info] = tuning_space.set_deafult_config(op_info, quant_mode)
+            initial_tuning_cfg[op_info] = tuning_space.get_default_config(op_info, quant_mode)
         return initial_tuning_cfg

     def _quant_items_pool(self, op_type_priority: List[str]) -> OrderedDict[

neural_compressor/strategy/hawq_v2.py

Lines changed: 3 additions & 24 deletions
@@ -24,7 +24,7 @@

 from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler
 from .utils.tuning_structs import OpTuningConfig
-from .utils.tuning_space import TUNING_ITEMS_LST
+from .utils.constant import TUNING_ITEMS_LST
 from ..utils import logger

 @strategy_registry
@@ -51,8 +51,8 @@ def next_tune_cfg(self):
             # Optype-wise tuning tuning items: the algorithm/scheme/granularity of activation(weight)
             early_stop_tuning = True
             stage1_cnt = 0
-            quant_ops = quant_mode_wise_items['static'] if 'static' in quant_mode_wise_items else []
-            quant_ops += quant_mode_wise_items['dynamic'] if 'dynamic' in quant_mode_wise_items else []
+            quant_ops = quant_mode_wise_items.get('static', [])
+            quant_ops += quant_mode_wise_items.get('dynamic', [])
             stage1_max = 1 # TODO set a more appropriate value
             op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [],
                                                              op_item_dtype_dict, initial_op_tuning_cfg)
@@ -110,24 +110,3 @@ def next_tune_cfg(self):
                 op_tuning_cfg['calib_sampling_size'] = calib_size
                 yield op_tuning_cfg

-    def _initial_dynamic_cfg_based_on_static_cfg(self, op_static_cfg: OpTuningConfig):
-        op_state = op_static_cfg.get_state()
-        op_name = op_static_cfg.op_name
-        op_type = op_static_cfg.op_type
-        op_quant_mode = 'dynamic'
-        tuning_space = self.tuning_space
-        dynamic_state = {}
-        for att in ['weight', 'activation']:
-            if att not in op_state:
-                continue
-            for item_name, item_val in op_state[att].items():
-                att_item = (att, item_name)
-                if att_item not in TUNING_ITEMS_LST:
-                    continue
-                if tuning_space.query_item_option((op_name, op_type), op_quant_mode, att_item, item_val):
-                    dynamic_state[att_item] = item_val
-                else:
-                    quant_mode_item = tuning_space.query_quant_mode_item((op_name, op_type), op_quant_mode)
-                    tuning_item = quant_mode_item.get_option_by_name(att_item)
-                    dynamic_state[att_item] = tuning_item.options[0] if tuning_item else None
-        return OpTuningConfig(op_name, op_type, op_quant_mode, tuning_space, kwargs=dynamic_state)

neural_compressor/strategy/strategy.py

Lines changed: 16 additions & 11 deletions
@@ -291,7 +291,7 @@ def _remove_redundant_qmodel(self):
         self.best_qmodel = None

     def _can_create_eval_func_from_cfg(self):
-        """Determines whether an eval function can be created from cfg.
+        """Determine whether an eval function can be created from cfg.

         Returns:
             Returns True if the eval func can be created from config, False otherwise.
@@ -432,20 +432,24 @@ def initial_tuning_cfg(self):
             quant_mode_wise_items (OrderedDict): key is quant_mode/precision; value is item list.
             initial_op_tuning_cfg (OrderedDict): key is (op_name, op_type); value is the initialized tuning config.
         """
+        from .utils.constant import auto_query_order, static_query_order, dynamic_query_order
+        from .utils.tuning_space import initial_tuning_cfg_with_quant_mode
         if self.cfg.quantization.approach == 'post_training_auto_quant':
-            query_order = ['static', 'dynamic', 'bf16', 'fp32']
+            query_order = auto_query_order
         elif self.cfg.quantization.approach == 'post_training_dynamic_quant':
-            query_order = ['dynamic', 'bf16', 'fp32']
+            query_order = dynamic_query_order
        elif self.cfg.quantization.approach == 'post_training_static_quant':
-            query_order = ['static', 'bf16', 'fp32']
+            query_order = static_query_order
        elif self.cfg.quantization.approach == 'quant_aware_training':
-            query_order = ['static', 'dynamic', 'bf16', 'fp32']
+            logger.info("!!! Currently, the qat tuning is not supported by strategy.")
+            query_order = auto_query_order

-        quant_mode_wise_items = OrderedDict()
+        quant_mode_wise_items = OrderedDict() # mode, op_item_lst
         pre_items = set()
+        # Collect op items supported the specified mode.
         for quant_mode in query_order:
             items = self.tuning_space.query_items_by_quant_mode(quant_mode)
-            filtered_items = [item for item in items if item not in pre_items]
+            filtered_items = list(filter(lambda item: item not in pre_items, items))
             pre_items = pre_items.union(set(items))
             quant_mode_wise_items[quant_mode] = filtered_items

@@ -456,11 +460,12 @@ def initial_op_quant_mode(items_lst, target_quant_mode, op_item_dtype_dict):
         op_item_dtype_dict = OrderedDict()
         for quant_mode, quant_mode_items in quant_mode_wise_items.items():
             initial_op_quant_mode(quant_mode_items, quant_mode, op_item_dtype_dict)
-
+
         initial_op_tuning_cfg = {}
-        for op_name_dtype, quant_mode in op_item_dtype_dict.items():
-            initial_op_tuning_cfg[op_name_dtype] = OpTuningConfig(op_name_dtype[0], op_name_dtype[1],
-                                                                  quant_mode, self.tuning_space)
+        for op_name_type, quant_mode in op_item_dtype_dict.items():
+            initial_op_tuning_cfg[op_name_type] = initial_tuning_cfg_with_quant_mode(op_name_type,
+                                                                                     quant_mode,
+                                                                                     self.tuning_space)
         return op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg

     def show_baseline_info(self):
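Editor's note: the query_order loop above assigns each op to the first quant mode or precision in the precedence list that supports it; later entries only pick up ops not already claimed. Below is a minimal standalone sketch of that precedence logic, with invented op names and a plain dict standing in for the real tuning space:

from collections import OrderedDict

# Invented example data: which ops each quant mode / precision could handle.
supported = {
    'static':  ['conv1', 'linear1'],
    'dynamic': ['linear1', 'lstm1'],
    'bf16':    ['conv1', 'linear1', 'lstm1', 'softmax1'],
    'fp32':    ['conv1', 'linear1', 'lstm1', 'softmax1'],
}
query_order = ['static', 'dynamic', 'bf16', 'fp32']   # auto_query_order without 'fp16', for brevity

quant_mode_wise_items = OrderedDict()
pre_items = set()
for quant_mode in query_order:
    items = supported[quant_mode]                      # stands in for query_items_by_quant_mode()
    filtered_items = list(filter(lambda item: item not in pre_items, items))
    pre_items = pre_items.union(set(items))
    quant_mode_wise_items[quant_mode] = filtered_items

# conv1/linear1 land in 'static', lstm1 in 'dynamic', softmax1 in 'bf16'; nothing is left for 'fp32'.
print(dict(quant_mode_wise_items))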
neural_compressor/strategy/utils/constant.py (new file)

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Strategy constant."""
+
+PRECISION_SET = {'bf16', 'fp16' , 'fp32',}
+QUANT_MODE_SET = {'static', 'dynamic'}
+QUNAT_BIT_SET = {'int8', 'uint8', 'int4', 'uint4'}
+
+TUNING_ITEMS_LST = [('activation','scheme'), ('activation','algorithm'), ('activation','granularity'),
+                    ('weight','scheme'), ('weight','algorithm'), ('weight','granularity'), 'sampling_size']
+
+PRECISION_SET_V2_0 = {'fp32', 'bf16'}
+
+auto_query_order = ['static', 'dynamic', 'bf16', 'fp16', 'fp32']
+static_query_order = ['static', 'bf16', 'fp16', 'fp32']
+dynamic_query_order = ['dynamic', 'bf16', 'fp16', 'fp32']
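Editor's note: these constants centralize the names the strategies key on; the query orders now include 'fp16' next to 'bf16'/'fp32', and QUNAT_BIT_SET covers the new 4-bit types. As an illustration only (not library code; the real constants live in the new strategy utils constant module shown above), a tiny classifier built on simplified copies of these sets:

# Simplified copies of the constants above; for illustration only.
PRECISION_SET = {'bf16', 'fp16', 'fp32'}
QUANT_MODE_SET = {'static', 'dynamic'}
QUNAT_BIT_SET = {'int8', 'uint8', 'int4', 'uint4'}

def classify_dtype(dtype):
    """Hypothetical helper: map a data-type / mode string to the bucket a strategy would treat it as."""
    if dtype in QUANT_MODE_SET:
        return 'quant mode'
    if dtype in QUNAT_BIT_SET:
        return 'quantized bit width'
    if dtype in PRECISION_SET:
        return 'precision'
    return 'unknown'

assert classify_dtype('int4') == 'quantized bit width'   # the new data type this PR extends tuning support for
assert classify_dtype('fp16') == 'precision'
assert classify_dtype('dynamic') == 'quant mode'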

0 commit comments
