Commit f3b38b2

Enhance the strategy to avoid repeatedly initializing adaptor (#832)
Signed-off-by: Cheng, Zixuan <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: Lv, Liang1 <[email protected]>
Signed-off-by: Cheng, Penghui <[email protected]>
1 parent fb8e503 commit f3b38b2

File tree: 8 files changed, +218 -79 lines

neural_compressor/adaptor/torch_utils/hawq_metric.py

Lines changed: 9 additions & 9 deletions

@@ -52,9 +52,9 @@ def remove(self):
 class HessianTrace:
     """HessianTrace Class.
 
-    Please refer to Yao, Zhewei, et al. "Pyhessian: Neural networks through the lens of the hessian."
+    Please refer to Yao, Zhewei, et al. "Pyhessian: Neural networks through the lens of the hessian."
     2020 IEEE international conference on big data (Big data). IEEE, 2020.
-    Dong, Zhen, et al. "Hawq-v2: Hessian aware trace-weighted quantization of neural networks."
+    Dong, Zhen, et al. "Hawq-v2: Hessian aware trace-weighted quantization of neural networks."
     Advances in neural information processing systems 33 (2020): 18518-18529.
     https://github.com/openvinotoolkit/nncf/blob/develop/nncf/torch/quantization/hessian_trace.py
     """

@@ -173,7 +173,7 @@ def act_grad_hook(model, grad_input, grad_output):
     def _get_enable_act_grad_hook(self, name):
         def enable_act_grad_hook(model, inputs, outputs):
             input = inputs[0]
-            if input.requires_grad is False:
+            if input.requires_grad is False: #
                 input.requires_grad = True
             self.layer_acts[name] = input
 

@@ -251,13 +251,13 @@ def _sample_rademacher(self, params):
             r.masked_fill_(r == 0, -1)
             samples.append(r)
         return samples
-
+
     def _sample_rademacher_like_params(self):
         def sample(parameter):
             r = torch.randint_like(parameter, high=2, device=self.device)
             return r.masked_fill_(r == 0, -1)
         return [sample(p) for p in self.params]
-
+
     def _sample_normal_like_params(self):
         return [torch.randn(p.size(), device=self.device) for p in self.params]
 

@@ -391,7 +391,7 @@ def _insert_hook(self, model, target_module_list):
         for layer, module in model.named_modules():
             for target_module in target_module_list:
                 # print("layer:",layer)
-                # print("target_model:",target_module)
+                # print("target_model:",target_module)
                 if layer == target_module:
                     logging.debug("Collect: %s" % (module))
                     # print("Collect: %s" % (module))

@@ -408,7 +408,7 @@ def _insert_hook_quantize(self, model, target_module_list):
                 # print("layer:",layer)
                 length = len("_model.")
                 new_key = layer[length:]
-                # print("target_model:",target_module)
+                # print("target_model:",target_module)
                 if new_key == target_module:
                     logging.debug("Collect: %s" % (module))
                     # print("Collect: %s" % (module))

@@ -521,7 +521,7 @@ def compare_weights(
     float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
 ) -> Dict[str, Dict[str, torch.Tensor]]:
     r"""Compare the weights of the float module with its corresponding quantized module.
-
+
     Returns a dict with key corresponding to module names and each entry being
     a dictionary with two keys 'float' and 'quantized', containing the float and
     quantized weights. This dict can be used to compare and compute the quantization

@@ -608,7 +608,7 @@ def hawq_top(fp32_model, q_model, dataloader, criterion, enable_act):
         op_qnt_tensor = weight_quant_loss[key]['quantized'].dequantize()
         diff_l2 = (torch.norm(op_float_tensor - op_qnt_tensor, p=2) ** 2)
         pertur_lst[key] = diff_l2
-
+
     if enable_act:
         act_to_traces = traces['activation']
         for trace_i, pertur_i, act_i in zip(op_to_traces.keys(), pertur_lst.keys(), act_to_traces.keys()):
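
The ±1 sampling in _sample_rademacher_like_params feeds HAWQ's Hessian-trace computation. As a rough, self-contained illustration only (not the repository's implementation; rademacher_like and hutchinson_trace are made-up names), Hutchinson's estimator approximates the trace of the loss Hessian as the average of v·Hv over random ±1 probe vectors v:

import torch

def rademacher_like(p):
    # Entries drawn uniformly from {-1, +1}, mirroring the sampling helper above.
    r = torch.randint_like(p, high=2)
    return r.masked_fill_(r == 0, -1)

def hutchinson_trace(loss, params, n_samples=8):
    # tr(H) ~= mean over probes v of v^T H v, with H the Hessian of `loss` w.r.t. `params`.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    estimate = 0.0
    for _ in range(n_samples):
        vs = [rademacher_like(p) for p in params]
        # Hessian-vector product: differentiate <grads, v> w.r.t. the parameters again.
        hvs = torch.autograd.grad(grads, params, grad_outputs=vs, retain_graph=True)
        estimate = estimate + sum((v * hv).sum() for v, hv in zip(vs, hvs))
    return estimate / n_samples

# Example usage (toy model):
#   model = torch.nn.Linear(4, 1)
#   loss = model(torch.randn(8, 4)).pow(2).mean()
#   print(hutchinson_trace(loss, list(model.parameters())))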

neural_compressor/contrib/strategy/sigopt.py

Lines changed: 1 addition & 1 deletion

@@ -194,7 +194,7 @@ def traverse(self):
 
         This is SigOpt version of traverse -- with additional constraints setting to HPO.
         """
-        self._eval_baseline()
+        self._prepare_tuning()
 
         baseline_msg = '[Accuracy: {:.4f}'.format(self.baseline[0]) + \
             ''.join([', {}: {:.4f}'.format(x,y) for x,y in zip( \
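
For orientation, the sketch below shows the idempotent shape such a preparation step can take; _prepare_tuning() lives in the shared strategy base class and is not part of this diff, so the body here is an assumption, not the project's code. The intent named by the commit title is that adaptor setup and the baseline evaluation happen once, even when several strategies call the method in turn.

class _StrategySketch:
    # Illustrative only -- not neural_compressor's TuneStrategy.
    def __init__(self):
        self.adaptor = None      # expensive framework adaptor, built lazily
        self.baseline = None     # FP32 accuracy baseline, evaluated lazily

    def _prepare_tuning(self):
        # Safe to call at the top of every traverse(): work already done is skipped.
        if self.adaptor is None:
            self.adaptor = self._build_adaptor()
        if self.baseline is None:
            self.baseline = self._eval_baseline()

    def _build_adaptor(self):
        return object()          # placeholder for the real adaptor initialization

    def _eval_baseline(self):
        return 0.0               # placeholder for a real FP32 evaluation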

neural_compressor/contrib/strategy/tpe.py

Lines changed: 1 addition & 0 deletions

@@ -191,6 +191,7 @@ def _configure_hpopt_search_space_and_params(self, search_space):
     def traverse(self):
         """Tpe traverse logic."""
         logger.info("Start to run tpe strategy.")
+        self._prepare_tuning()
         # prepare log file
         trials_file = os.path.join(os.path.dirname(self.history_path), 'tpe_trials.csv')
         best_result_file = os.path.join(os.path.dirname(self.history_path), 'tpe_best_result.csv')

neural_compressor/strategy/auto.py

Lines changed: 3 additions & 6 deletions

@@ -79,13 +79,10 @@ def sequential_traverse(self):
                 eval_dataloader=self.eval_dataloader,
                 eval_metric=self.eval_metric,
                 resume=self._resume,
-                q_hooks=self.q_hooks)
+                q_hooks=self.q_hooks,
+                pre_strategy = pre_strategy
+                )
 
-            if pre_strategy:
-                #TODO add tuning history from the previous stage to current stage.
-                strategy.baseline = deepcopy(pre_strategy.baseline)
-                strategy.trials_count = pre_strategy.trials_count
-                strategy.objectives.baseline = deepcopy(pre_strategy.baseline)
             pre_strategy = strategy
             strategy.traverse()
             self.best_qmodel = strategy.best_qmodel
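
The removed block above used to copy the previous stage's state into the freshly built strategy by hand. A hedged sketch of where that logic plausibly moves, assuming the strategy constructor now accepts pre_strategy (the class name below is invented; the real TuneStrategy.__init__ may differ):

from copy import deepcopy

class _SketchStrategy:
    def __init__(self, pre_strategy=None):
        self.baseline = None
        self.trials_count = 0
        if pre_strategy is not None:
            # Inherit the previous stage's baseline and trial count instead of
            # re-evaluating them, so the adaptor/baseline work is not repeated.
            self.baseline = deepcopy(pre_strategy.baseline)
            self.trials_count = pre_strategy.trials_count

With the copying inside the constructor, sequential_traverse only has to pass pre_strategy along, as the hunk above shows.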

neural_compressor/strategy/auto_mixed_precision.py

Lines changed: 1 addition & 2 deletions

@@ -128,8 +128,7 @@ def next_tune_cfg(self):
 
     def traverse(self):
         """Traverse the tuning space according to auto-mixed precision strategy."""
-        # get fp32 model baseline
-        self._eval_baseline()
+        self._prepare_tuning()
 
         for op_tuning_cfg in self.next_tune_cfg():
             # add tune_cfg here as quantize use tune_cfg
