@@ -501,44 +501,43 @@ def _reshape_scale_for_input(self, layer, scale):
501501
502502 return scale
503503
def _scale_layer_weight(self, layer_name, scale, alpha=0.5, input_minmax=None):  ##input channel
    """
    Scale the layer weights at input channel, depthwise conv output channel

    Depending on the instance flags, the layer is either wrapped in (or
    unwrapped from) an SQLinearWrapper, has its weight multiplied by the
    scale in place, or is left untouched.

    :param layer_name: The layer name
    :param scale: The scale to be multiplied
    :param alpha: alpha forwarded to SQLinearWrapper (only used when
        self.insert_mul is set)
    :param input_minmax: input_minmax forwarded to SQLinearWrapper (only used
        when self.insert_mul is set)
    :return: The scale; when the weight is updated in place this is the scale
        as returned by _reshape_scale_for_weight, otherwise the value passed in
    """
    # Single lookup: the module is fetched once and reused by every branch.
    layer = get_module(self.model, layer_name)
    if self.insert_mul:
        # Kept as a function-scope import, as in the original code.
        from .model_wrapper import SQLinearWrapper

        if isinstance(layer, SQLinearWrapper):
            # Already wrapped: put the wrapper's inner sq_linear back into the model.
            # NOTE(review): _recover_sq_linear presumably restores the inner
            # layer's original weights -- confirm in SQLinearWrapper.
            layer._recover_sq_linear()
            set_module(self.model, layer_name, layer.sq_linear)  ##recover
        else:
            # The wrapper is given the reciprocal of the weight scale.
            new_module = SQLinearWrapper(layer, 1.0 / scale, input_minmax, alpha)
            set_module(self.model, layer_name, new_module)
    elif self.allow_absorb:
        # Multiply the weight by the reshaped scale and re-register it as a Parameter.
        scale = self._reshape_scale_for_weight(layer, scale)
        layer.weight = torch.nn.Parameter(layer.weight * scale)
    return scale
517527
518- def _absorb_scales (self , layer_name , scale , alpha = 0.5 ): ##output channel
528+ def _absorb_scales (self , layer_name , scale ): ##output channel
519529 """
520530 Absorb the scale to the layer at output channel
521531 :param layer_name: The module name
522532 :param scale: The scale to be absorbed
523533 :param alpha_key: The alpha passed to SQLinearWrapper
524534 :return:
525535 """
526- layer = get_module (self .model , layer_name )
527- if self .insert_mul :
528- if layer .__class__ .__name__ == "SQLinearWrapper" :
529- layer ._recover_sq_linear ()
530- set_module (self .model , layer_name , layer .sq_linear ) ##recover
531- else :
532- from .model_wrapper import SQLinearWrapper
533- input_minmax = [self .input_mins [layer_name ], self .input_maxes [layer_name ]]
534- new_module = SQLinearWrapper (layer , scale , input_minmax , alpha )
535- set_module (self .model , layer_name , new_module )
536- return
537-
538- if not self .allow_absorb :
539- return ## change the code style due to too many if/else statements in the following
536+ if self .insert_mul or not self .allow_absorb :
537+ return # absorb is updated in SQLinearWrapper in def _scale_layer_weight
540538
541539 ##if self.allow absorb
540+ layer = get_module (self .model , layer_name )
542541 if layer .__class__ .__name__ == 'WrapperLayer' :
543542 layer = layer .orig_layer
544543 if isinstance (layer , torch .nn .BatchNorm2d ) or isinstance (layer , torch .nn .GroupNorm ) or \
@@ -650,7 +649,9 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
650649 :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict
651650 :return:
652651 """
653- absorb_scales_info , weight_scales_info = self ._cal_scales (absorb_to_layer , input_maxes , alpha , tuning )
652+ absorb_scales_info , weight_scales_info = self ._cal_scales (
653+ absorb_to_layer , input_maxes , alpha , tuning
654+ )
654655 if not absorb_scales_info or not weight_scales_info :
655656 return weight_scales_info , absorb_scales_info
656657 for index , key in enumerate (absorb_to_layer .keys ()):
@@ -659,10 +660,13 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
659660 elif isinstance (alpha , dict ):
660661 alpha_tmp = alpha [key ]
661662 absorb_scale = absorb_scales_info [key ]
662- self ._absorb_scales (key , absorb_scale , alpha_tmp )
663+ self ._absorb_scales (key , absorb_scale )
663664 layer_names = absorb_to_layer [key ]
664665 for layer_name in layer_names :
665- self ._scale_layer_weight (layer_name , weight_scales_info [layer_name ])
666+ input_minmax = [self .input_mins [layer_names [0 ]], self .input_maxes [layer_names [0 ]]]
667+ self ._scale_layer_weight (
668+ layer_name , weight_scales_info [layer_name ], alpha_tmp , input_minmax
669+ )
666670 return weight_scales_info , absorb_scales_info
667671
668672 def _check_need_calibration (self , alpha , percentile , op_types ,
@@ -1110,10 +1114,14 @@ def _get_example_input(self):
11101114 if self .dataloader == None and self .example_inputs == None :
11111115 return None
11121116 if self .example_inputs is None :
1113- ##assert self.dataloader, "Please provide dataloader or example_inputs"
1114- for idx , input in enumerate (self .dataloader ):
1115- self .example_inputs = input
1116- break
1117+ try :
1118+ for idx , (input , label ) in enumerate (self .dataloader ):
1119+ self .example_inputs = input
1120+ break
1121+ except :
1122+ for idx , input in enumerate (self .dataloader ):
1123+ self .example_inputs = input
1124+ break
11171125
11181126 return self .example_inputs
11191127
0 commit comments