2525
2626from  .utils .tuning_sampler  import  OpTypeWiseTuningSampler , FallbackTuningSampler , ModelWiseTuningSampler 
2727from  .utils .tuning_structs  import  OpTuningConfig 
28- from  .utils .tuning_space  import  TUNING_ITEMS_LST 
28+ from  .utils .constant  import  TUNING_ITEMS_LST 
2929
3030@strategy_registry  
3131class  BasicTuneStrategy (TuneStrategy ):
@@ -45,13 +45,13 @@ def next_tune_cfg(self):
4545        tuning_space  =  self .tuning_space 
4646        calib_sampling_size_lst  =  tuning_space .root_item .get_option_by_name ('calib_sampling_size' ).options 
4747        for  calib_sampling_size  in  calib_sampling_size_lst :
48-             # Initialize the tuning config for each op according to the quantization approach   
48+             # Initialize the tuning config for each op according to the quantization approach.  
4949            op_item_dtype_dict , quant_mode_wise_items , initial_op_tuning_cfg  =  self .initial_tuning_cfg ()
5050            # Optype-wise tuning tuning items: the algorithm/scheme/granularity of activation(weight) 
5151            early_stop_tuning  =  False 
5252            stage1_cnt  =  0 
53-             quant_ops  =  quant_mode_wise_items [ 'static' ]  if   'static'   in   quant_mode_wise_items   else  [] 
54-             quant_ops  +=  quant_mode_wise_items [ 'dynamic' ]  if   'dynamic'   in   quant_mode_wise_items   else  [] 
53+             quant_ops  =  quant_mode_wise_items . get ( 'static' , []) 
54+             quant_ops  +=  quant_mode_wise_items . get ( 'dynamic' , []) 
5555            stage1_max  =  1e9   # TODO set a more appropriate value 
5656            op_wise_tuning_sampler  =  OpTypeWiseTuningSampler (tuning_space , [], [], 
5757                                                             op_item_dtype_dict , initial_op_tuning_cfg )
@@ -120,22 +120,25 @@ def _initial_dynamic_cfg_based_on_static_cfg(self, op_static_cfg:OpTuningConfig)
120120        op_state  =  op_static_cfg .get_state ()
121121        op_name  =  op_static_cfg .op_name 
122122        op_type  =  op_static_cfg .op_type 
123+         op_name_type  =  (op_name , op_type )
123124        op_quant_mode  =  'dynamic' 
124125        tuning_space  =  self .tuning_space 
125126        dynamic_state  =  {}
126127        for  att  in  ['weight' , 'activation' ]:
127-             if  att  not  in op_state :
128-                 continue 
129-             for  item_name , item_val  in  op_state [att ].items ():
130-                 att_item  =  (att , item_name )
131-                 if  att_item  not  in TUNING_ITEMS_LST :
132-                     continue 
133-                 if  tuning_space .query_item_option ((op_name , op_type ), op_quant_mode , att_item , item_val ):
134-                     dynamic_state [att_item ] =  item_val 
128+             if  att  not  in op_state : continue 
129+             # Add dtype 
130+             full_path  =  self .tuning_space .get_op_default_path_by_pattern (op_name_type , op_quant_mode )
131+             dynamic_state [att  +  '_dtype' ] =  self .tuning_space .ops_data_type [op_name_type ][full_path [att ]]
132+             for  method_name , method_val  in  op_state [att ].items ():
133+                 att_and_method_name  =  (att , method_name )
134+                 if  att_and_method_name  not  in TUNING_ITEMS_LST : continue 
135+                 if  tuning_space .query_item_option (op_name_type , full_path [att ], att_and_method_name , method_val ):
136+                     dynamic_state [att_and_method_name ] =  method_val 
135137                else :
136-                     quant_mode_item  =  tuning_space .query_quant_mode_item ((op_name , op_type ), op_quant_mode )
137-                     tuning_item  =  quant_mode_item .get_option_by_name (att_item )
138-                     dynamic_state [att_item ] =  tuning_item .options [0 ] if  tuning_item  else  None 
138+                     quant_mode_item  =  tuning_space .get_item_by_path ((op_name_type , * full_path [att ]))
139+                     if  quant_mode_item  and  quant_mode_item .get_option_by_name (att_and_method_name ):
140+                         tuning_item  =  quant_mode_item .get_option_by_name (att_and_method_name )
141+                         dynamic_state [att_and_method_name ] =  tuning_item .options [0 ] if  tuning_item  else  None 
139142        return  OpTuningConfig (op_name , op_type , op_quant_mode , tuning_space , kwargs = dynamic_state )
140143
141144
0 commit comments