Skip to content

Commit 1e9c12b

Browse files
authored
fixed the key comparison in the Bayesian strategy (#1484)
1 parent 43943d8 commit 1e9c12b

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

neural_compressor/strategy/bayesian.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sklearn.gaussian_process import GaussianProcessRegressor
2424

2525
from collections import OrderedDict
26+
from copy import deepcopy
2627

2728
from ..utils import logger
2829
from .strategy import strategy_registry, TuneStrategy
@@ -104,7 +105,10 @@ def params_to_tune_configs(self, params):
104105
op_tuning_cfg[op_name_type] = configs[0]
105106
else:
106107
op_tuning_cfg[op_name_type] = configs[min(len(configs) - 1, int(params[op_name_type[0]]))]
107-
calib_sampling_size = calib_sampling_size_lst[min(len(configs) - 1, int(params['calib_sampling_size']))]
108+
if len(calib_sampling_size_lst) > 1:
109+
calib_sampling_size = calib_sampling_size_lst[min(len(configs) - 1, int(params['calib_sampling_size']))]
110+
else:
111+
calib_sampling_size = calib_sampling_size_lst[0]
108112
op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
109113
return op_tuning_cfg
110114

@@ -115,7 +119,6 @@ def next_tune_cfg(self):
115119
"""
116120
params = None
117121
pbounds = {}
118-
from copy import deepcopy
119122
tuning_space = self.tuning_space
120123
calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options
121124
op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg = self.initial_tuning_cfg()
@@ -126,7 +129,8 @@ def next_tune_cfg(self):
126129
for op_name_type, configs in self.op_configs.items():
127130
if len(configs) > 1:
128131
pbounds[op_name_type[0]] = (0, len(configs))
129-
pbounds['calib_sampling_size'] = (0, len(calib_sampling_size_lst))
132+
if len(calib_sampling_size_lst) > 1:
133+
pbounds['calib_sampling_size'] = (0, len(calib_sampling_size_lst))
130134
if len(pbounds) == 0:
131135
yield self.params_to_tune_configs(params)
132136
return
@@ -225,10 +229,11 @@ def __init__(self, pbounds, random_seed=9527):
225229
"""
226230
self.random_seed = random_seed
227231
# Get the name of the parameters
228-
self._keys = sorted(pbounds)
232+
names = list(pbounds.keys())
233+
self._keys = deepcopy(names)
229234
# Create an array with parameters bounds
230235
self._bounds = np.array(
231-
[item[1] for item in sorted(pbounds.items(), key=lambda x: x[0])],
236+
[pbounds[name] for name in names],
232237
dtype=np.float
233238
)
234239

@@ -275,7 +280,7 @@ def params_to_array(self, params):
275280
assert set(params) == set(self.keys)
276281
except AssertionError:
277282
raise ValueError(
278-
"Parameters' keys ({}) do ".format(sorted(params)) +
283+
"Parameters' keys ({}) do ".format(list(params.keys())) +
279284
"not match the expected set of keys ({}).".format(self.keys)
280285
)
281286
return np.asarray([params[key] for key in self.keys])

0 commit comments

Comments (0)