 logger = logging.getLogger()
 from collections import UserDict, defaultdict

+from tqdm import tqdm
+
+
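+# Helper added by this patch: checks whether `device` has at least `need_space`
+# bytes of free memory (GPU memory via torch.cuda, host RAM via psutil), so that
+# per-alpha weight scales are only cached when they actually fit in memory.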
+def enough_memo_store_scale(device, need_space):
+    if device == "cuda":  # pragma: no cover
+        current_gpu_index = torch.cuda.current_device()
+        total_memory = torch.cuda.get_device_properties(current_gpu_index).total_memory
+        used_memory = torch.cuda.memory_allocated(current_gpu_index)
+        free_space = total_memory - used_memory
+    else:
+        import psutil
+
+        free_space = psutil.virtual_memory().free
+    return free_space >= need_space
+


 def move_input_to_device(input, device=torch.device("cpu")):
     if isinstance(input, dict) or isinstance(input, UserDict):
@@ -333,6 +348,9 @@ def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, tra
         self.weight_clip = True
         self.default_alpha = 0.5

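+        # Optional cache of per-alpha weight scales, keyed as {layer_name: {alpha: scale}};
+        # it is only enabled later, when enough device memory is available to hold it.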
+        self._save_scale = False
+        self.weight_scale_dict = {}
+
     def _get_device(self):
         """Get the model device
         :return:Model device."""
@@ -562,12 +580,7 @@ def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
         weight_scales_info = {}
         absorb_scales_info = {}
         for index, key in enumerate(absorb_to_layer.keys()):
-            if isinstance(alpha, float):
-                alpha_tmp = alpha
-            elif isinstance(alpha, dict):
-                alpha_tmp = alpha[key]
-            else:
-                alpha_tmp = alpha
+            alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha
             if alpha_tmp < 0:
                 scale = torch.ones((1), device=self.device)
             else:
@@ -591,13 +604,24 @@ def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
                     self.max_value_info[key]["absorbed_layer"] = layer_names
                     continue

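+                # Reuse a previously computed scale for this alpha when caching is enabled;
+                # otherwise fall back to computing it with cal_scale.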
-                scale = cal_scale(input_max, weights, alpha_tmp)
+                if self._save_scale:
+                    if key in self.weight_scale_dict and alpha_tmp in self.weight_scale_dict[key]:
+                        scale = self.weight_scale_dict[key][alpha_tmp]
+                    else:
+                        scale = cal_scale(input_max, weights, alpha_tmp)
+                else:
+                    scale = cal_scale(input_max, weights, alpha_tmp)
+
             absorb_scales_info[key] = 1.0 / scale
             absorb_scales_info[key][scale == 0] = 0
             layer_names = absorb_to_layer[key]
             for layer_name in layer_names:
                 ##self._scale_layer_weight(layer_name, scale)
                 weight_scales_info[layer_name] = scale
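+                # Store the freshly computed scale per layer name and alpha so later calls can reuse it.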
+                if self._save_scale:
+                    if layer_name not in self.weight_scale_dict:
+                        self.weight_scale_dict[layer_name] = {}
+                    self.weight_scale_dict[layer_name][alpha_tmp] = scale
         return absorb_scales_info, weight_scales_info

     def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
@@ -869,8 +893,9 @@ def _auto_tune_alpha(
             logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.")
             self._qdq_model_unwrapper_for_auto()
             return best_alphas
+        bar = tqdm(self.dataloader, total=calib_sample_num, desc="auto tune alpha")
         try:
-            for input, label in self.dataloader:
+            for input, label in bar:
                 loss_alphas = {}
                 best_alphas_per_module = best_alphas
                 if isinstance(best_alphas, dict):
@@ -899,10 +924,12 @@ def _auto_tune_alpha(
                         self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                     )
                     self._update_scales_for_auto(absorb_input_scales, weight_scales)
+                    # No need to reset weight_scale_dict here: the original layer weights are used,
+                    # so the cached scales remain valid.
+                    # self.weight_scale_dict = {}
                 if total_cnt >= calib_sample_num:
                     break
         except:
-            for input in self.dataloader:
+            for input in bar:
                 loss_alphas = {}
                 best_alphas_per_module = best_alphas
                 if isinstance(best_alphas, dict):
@@ -932,6 +959,7 @@ def _auto_tune_alpha(
                         self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                     )
                     self._update_scales_for_auto(absorb_input_scales, weight_scales)
+                    # self.weight_scale_dict = {}
                 if total_cnt >= calib_sample_num:
                     break

@@ -1036,6 +1064,18 @@ def transform(
         for d in diff_modules:
             del self.absorb_to_layer[d]

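+        # Rough estimate of the memory needed to cache the scales: 4 bytes per float32
+        # element of each recorded input max, per absorbed layer, and (for alpha="auto")
+        # per candidate alpha value in the search space.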
+        scale_memo_use = 0
+        for key in self.absorb_to_layer:
+            layer_name = self.absorb_to_layer[key][0]
+            input_max = input_maxes_abs[layer_name]
+            scale_memo_use += 4 * input_max.shape[0] * len(self.absorb_to_layer[key])
+        if alpha == "auto":
+            alpha_space = (auto_alpha_args["alpha_max"] - auto_alpha_args["alpha_min"]) / auto_alpha_args[
+                "alpha_step"
+            ] + 1
+            scale_memo_use *= alpha_space
+        self._save_scale = enough_memo_store_scale(self.device, scale_memo_use)
+
         if alpha == "auto":
             self.alpha_per_layer = self._auto_tune_alpha(
                 input_maxes_abs, calib_sample_num=32, **auto_alpha_args
@@ -1047,6 +1087,8 @@ def transform(
         if example_inputs is not None:
             out_pre_sq = model_forward_per_sample(self.model, example_inputs, self.device)

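+        # The folding path does not reuse cached scales, so switch the cache off there.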
+        if folding:
+            self._save_scale = False
         if self.record_max_info:
             # max_info is recorded in self.max_value_info
             self._adjust_parameters(self.absorb_to_layer, input_maxes_abs, alpha)