
Commit 173c188

increase sq auto alpha running speed (#1399)
Signed-off-by: Guo, Heng <[email protected]>
1 parent fcbac41 commit 173c188

File tree: 1 file changed (+51, −9 lines)


neural_compressor/adaptor/torch_utils/smooth_quant.py

Lines changed: 51 additions & 9 deletions
@@ -32,6 +32,21 @@
 logger = logging.getLogger()
 from collections import UserDict, defaultdict
 
+from tqdm import tqdm
+
+
+def enough_memo_store_scale(device, need_space):
+    if device == "cuda":  # pragma: no cover
+        current_gpu_index = torch.cuda.current_device()
+        total_memory = torch.cuda.get_device_properties(current_gpu_index).total_memory
+        used_memory = torch.cuda.memory_allocated(current_gpu_index)
+        free_space = total_memory - used_memory
+    else:
+        import psutil
+
+        free_space = psutil.virtual_memory().free
+    return free_space >= need_space
+
 
 def move_input_to_device(input, device=torch.device("cpu")):
     if isinstance(input, dict) or isinstance(input, UserDict):
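
For orientation only (not part of the commit): a minimal sketch of how the new helper could be called. The import path matches the file being changed; the layer count and channel size are made-up numbers.

    from neural_compressor.adaptor.torch_utils.smooth_quant import enough_memo_store_scale

    # Rough budget: one fp32 scale value (4 bytes) per input channel, per layer.
    num_layers, channels = 32, 4096           # illustrative sizes, not from the commit
    need_space = 4 * channels * num_layers    # bytes
    print(enough_memo_store_scale("cpu", need_space))  # True if free RAM covers the cache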
@@ -333,6 +348,9 @@ def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, tra
         self.weight_clip = True
         self.default_alpha = 0.5
 
+        self._save_scale = False
+        self.weight_scale_dict = {}
+
     def _get_device(self):
         """Get the model device
         :return:Model device."""
@@ -562,12 +580,7 @@ def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
         weight_scales_info = {}
         absorb_scales_info = {}
         for index, key in enumerate(absorb_to_layer.keys()):
-            if isinstance(alpha, float):
-                alpha_tmp = alpha
-            elif isinstance(alpha, dict):
-                alpha_tmp = alpha[key]
-            else:
-                alpha_tmp = alpha
+            alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha
             if alpha_tmp < 0:
                 scale = torch.ones((1), device=self.device)
             else:
@@ -591,13 +604,24 @@ def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
                 self.max_value_info[key]["absorbed_layer"] = layer_names
                 continue
 
-            scale = cal_scale(input_max, weights, alpha_tmp)
+            if self._save_scale:
+                if key in self.weight_scale_dict and alpha_tmp in self.weight_scale_dict[key]:
+                    scale = self.weight_scale_dict[key][alpha_tmp]
+                else:
+                    scale = cal_scale(input_max, weights, alpha_tmp)
+            else:
+                scale = cal_scale(input_max, weights, alpha_tmp)
+
             absorb_scales_info[key] = 1.0 / scale
             absorb_scales_info[key][scale == 0] = 0
             layer_names = absorb_to_layer[key]
             for layer_name in layer_names:
                 ##self._scale_layer_weight(layer_name, scale)
                 weight_scales_info[layer_name] = scale
+                if self._save_scale:
+                    if layer_name not in self.weight_scale_dict:
+                        self.weight_scale_dict[layer_name] = {}
+                    self.weight_scale_dict[layer_name][alpha_tmp] = scale
         return absorb_scales_info, weight_scales_info
 
     def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
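
The hunk above is the core of the speed-up: `_cal_scales` now memoizes the per-layer scale for each alpha it has already evaluated. Below is a minimal standalone sketch of that dict-of-dicts caching pattern; `compute_scale` and `get_scale` are illustrative stand-ins, not the library's API, and the formula only approximates what `cal_scale` does.

    import torch

    weight_scale_dict = {}  # layer name -> {alpha: cached scale tensor}

    def compute_scale(input_max, weight, alpha):
        # Stand-in for cal_scale(): input_max**alpha / weight_max**(1 - alpha) per input channel.
        weight_max = weight.abs().max(dim=0).values.clamp(min=1e-5)
        return input_max.pow(alpha) / weight_max.pow(1 - alpha)

    def get_scale(layer_name, input_max, weight, alpha, save_scale=True):
        # Reuse the scale already computed for this (layer, alpha) pair when caching is enabled.
        cached = weight_scale_dict.get(layer_name, {})
        if save_scale and alpha in cached:
            return cached[alpha]
        scale = compute_scale(input_max, weight, alpha)
        if save_scale:
            weight_scale_dict.setdefault(layer_name, {})[alpha] = scale
        return scale

    # Repeated calls with the same alpha hit the cache instead of recomputing.
    w, x_max = torch.randn(16, 8), torch.rand(8)
    s1 = get_scale("fc1", x_max, w, 0.5)
    s2 = get_scale("fc1", x_max, w, 0.5)
    assert s1 is s2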
@@ -869,8 +893,9 @@ def _auto_tune_alpha(
             logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.")
             self._qdq_model_unwrapper_for_auto()
             return best_alphas
+        bar = tqdm(self.dataloader, total=calib_sample_num, desc="auto tune alpha")
         try:
-            for input, label in self.dataloader:
+            for input, label in bar:
                 loss_alphas = {}
                 best_alphas_per_module = best_alphas
                 if isinstance(best_alphas, dict):
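
For reference (illustrative code, not from the commit): passing an explicit total to tqdm bounds the progress bar at the calibration sample count even when the underlying dataloader is much longer, which matches how the auto-tuning loop stops early.

    from tqdm import tqdm

    calib_sample_num = 32
    dataloader = range(1000)  # stand-in for a real calibration dataloader
    for step, batch in enumerate(tqdm(dataloader, total=calib_sample_num, desc="auto tune alpha")):
        if step + 1 >= calib_sample_num:
            break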
@@ -899,10 +924,12 @@ def _auto_tune_alpha(
                     self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                 )
                 self._update_scales_for_auto(absorb_input_scales, weight_scales)
+                # no need to reset weight_scale_dict: the original layer weights are used, so cached scales stay valid
+                # self.weight_scale_dict = {}
                 if total_cnt >= calib_sample_num:
                     break
         except:
-            for input in self.dataloader:
+            for input in bar:
                 loss_alphas = {}
                 best_alphas_per_module = best_alphas
                 if isinstance(best_alphas, dict):
@@ -932,6 +959,7 @@ def _auto_tune_alpha(
                     self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                 )
                 self._update_scales_for_auto(absorb_input_scales, weight_scales)
+                # self.weight_scale_dict = {}
                 if total_cnt >= calib_sample_num:
                     break
 
@@ -1036,6 +1064,18 @@ def transform(
         for d in diff_modules:
             del self.absorb_to_layer[d]
 
+        scale_memo_use = 0
+        for key in self.absorb_to_layer:
+            layer_name = self.absorb_to_layer[key][0]
+            input_max = input_maxes_abs[layer_name]
+            scale_memo_use += 4 * input_max.shape[0] * len(self.absorb_to_layer[key])
+        if alpha == "auto":
+            alpha_space = (auto_alpha_args["alpha_max"] - auto_alpha_args["alpha_min"]) / auto_alpha_args[
+                "alpha_step"
+            ] + 1
+            scale_memo_use *= alpha_space
+        self._save_scale = enough_memo_store_scale(self.device, scale_memo_use)
+
         if alpha == "auto":
             self.alpha_per_layer = self._auto_tune_alpha(
                 input_maxes_abs, calib_sample_num=32, **auto_alpha_args
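
A worked example of the estimate above (all sizes are illustrative, not taken from the commit): with 32 absorb groups of one 4096-channel layer each, and an auto sweep from alpha_min=0.0 to alpha_max=1.0 in steps of 0.1, the cache needs roughly 5.5 MiB, so caching will normally be enabled.

    # Hypothetical numbers, mirroring the estimate in the hunk above.
    channels, groups = 4096, 32
    scale_memo_use = 4 * channels * groups       # fp32 bytes for one scale vector per group
    alpha_space = (1.0 - 0.0) / 0.1 + 1          # alpha_max, alpha_min, alpha_step -> 11 candidates
    scale_memo_use *= alpha_space
    print(f"{scale_memo_use / 2**20:.1f} MiB")   # ~5.5 MiB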
@@ -1047,6 +1087,8 @@ def transform(
         if example_inputs is not None:
             out_pre_sq = model_forward_per_sample(self.model, example_inputs, self.device)
 
+        if folding:
+            self._save_scale = False
         if self.record_max_info:
             # max_info is recorded in self.max_value_info
             self._adjust_parameters(self.absorb_to_layer, input_maxes_abs, alpha)
