 import json
 import os
 import re
+from collections import OrderedDict
 from typing import Dict, List, Union

 import torch
 def cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name):  # pragma: no cover
     assert cfgs is not None, "No configure for IPEX int8 model..."
     op_infos = copy.deepcopy(op_infos_from_cfgs)
-    cfgs = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
+    cfgs, user_cfg = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
     with open(ipex_config_path, "w") as write_f:
         json.dump(cfgs, write_f, indent=4)
+    return user_cfg


 def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name):  # pragma: no cover
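With this commit, cfg_to_qconfig no longer only writes ipex_config_path; it also hands the remapped user_cfg back to its caller so op statistics can be dumped later. A minimal sketch of how a caller might thread the two changed functions together (the call site below is illustrative, not part of this diff):

# Hypothetical call site: cfg_to_qconfig now returns the user-facing config,
# which dump_model_op_stats (changed below) consumes directly.
user_cfg = cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
dump_model_op_stats(user_cfg)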
@@ -83,6 +85,15 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     Returns:
         cfgs (dict): updated configs.
     """
+    tmp_user_cfg = OrderedDict()
+    for op in user_cfg:  # map ipex op_name to pt op_name
+        for i, op_name in enumerate(op):
+            for ops, _ in op_infos_from_cfgs.items():
+                if "fqn" in op_infos_from_cfgs[ops].keys() and op_infos_from_cfgs[ops]["fqn"] == op_name:
+                    ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])
+                    tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
+                    break
+    user_cfg = tmp_user_cfg
     for op_name in user_cfg:
         inc_op_cfg = user_cfg[op_name]
         for i, name in enumerate(op_name[0]):
@@ -142,7 +153,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
                         else:
                             pass
             cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg
-    return cfgs
+    return cfgs, user_cfg


 def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False):  # pragma: no cover
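The block added at the top of check_cfg_and_qconfig rebuilds user_cfg so that each entry is keyed by the IPEX op-info key plus the unified op type instead of the framework fqn. A toy illustration of the resulting key shape, using made-up stand-ins for op_infos_from_cfgs and unify_op_type_mapping_ipex (the real structures come from the IPEX cfgs JSON):

# All data below is invented for illustration; only the key shapes matter.
op_infos_from_cfgs = {("root", "q_op_infos", "0"): {"fqn": "conv1", "op_type": "<class 'torch.nn.modules.conv.Conv2d'>"}}
unify_op_type_mapping_ipex = {"<class 'torch.nn.modules.conv.Conv2d'>": "Conv2d"}
user_cfg = {("conv1", "Conv2d"): {"weight": {"dtype": "int8"}, "activation": {"dtype": "uint8"}}}
# After the remapping loop, the same value sits under the op-info key:
# {((("root", "q_op_infos", "0"),), "Conv2d"): {"weight": {...}, "activation": {...}}}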
@@ -212,6 +223,7 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
         cfgs (dict): dict of configuration
     """
     quantizable_ops = []
+    op_name_info = []
     # group ops by position for transform-based model
     detector = TransformerBasedModelBlockPatternDetector(model)
     detect_result = detector.detect_block()
@@ -277,17 +289,30 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
                 if ipex_op_type in unify_op_type_mapping_ipex:
                     quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type]))
                     map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
+                    if "class" in ipex_op_type:  # "<class 'torch.nn.modules.activation.ReLU'>"
+                        op_type = ipex_op_type.split("'")[1]
+                        op_name_info.append((module_fqn, eval(op_type)))
+                    elif "method" in ipex_op_type:  # "<method 'add' of 'torch._C._TensorBase' objects>"
+                        method = ipex_op_type.split("'")[1]
+                        op_type = getattr(
+                            torch._C._TensorBase if ipex_ver.release < Version("2.2") else torch._C.TensorBase, method
+                        )
+                        op_name_info.append((module_fqn, op_type))
+                    else:
+                        op_name_info.append((module_fqn, op_type))
                 else:
                     re_flag = False
                     for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
                         if re.match(pattern, ipex_op_type):
                             re_flag = True
                             quantizable_ops.append((tuple(name), unify_op_type))
                             map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn
+                            op_name_info.append((module_fqn, ipex_op_type))
                             break
                     if not re_flag:
                         quantizable_ops.append((tuple(name), ipex_op_type))
                         map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
+                        op_name_info.append((module_fqn, ipex_op_type))
             else:
                 op_type = ""
                 for op_name in name:
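The two string formats handled above come from IPEX's op-type reprs; splitting on the single quote recovers either the dotted class path or the method name, which the new code turns back into a Python object with eval/getattr. A standalone sketch of what that split yields (illustrative strings only):

# Class form: index 1 of the split is the dotted class path that eval() resolves.
"<class 'torch.nn.modules.activation.ReLU'>".split("'")[1]   # 'torch.nn.modules.activation.ReLU'
# Method form: index 1 is the method name looked up on the tensor base class.
"<method 'add' of 'torch._C._TensorBase' objects>".split("'")[1]   # 'add'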
@@ -302,14 +327,15 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
                 _op_cfg_id = name[0][2]
                 module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"]
                 map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn
+                op_name_info.append((module_fqn, op_type))

     logger.debug("Map op name to fqn: ")
     logger.debug(map_op_name_to_fqn)
     logger.info("Attention Blocks : ")
     logger.info(attention_block)
     logger.info("FFN Blocks : ")
     logger.info(ffn_blocks)
-    return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name
+    return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, op_name_info


 def simple_inference(q_model, example_inputs, iterations=1):
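get_quantizable_ops_recursively now returns five values instead of four, so every call site has to unpack the extra op_name_info list of (module_fqn, op_type) pairs. An illustrative caller update (not part of this diff):

# Hypothetical caller: add op_name_info to the unpacking of the return tuple.
quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, op_name_info = get_quantizable_ops_recursively(
    model, example_inputs
)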
@@ -323,16 +349,16 @@ def simple_inference(q_model, example_inputs, iterations=1):
             q_model(example_inputs)


-def dump_model_op_stats(tune_cfg):
+def dump_model_op_stats(user_cfg):
     """This is a function to dump quantizable ops of model to user.

     Args:
-        tune_cfg (dict): quantization config
+        user_cfg (dict): quantization config
     Returns:
         None
     """
     res = dict()
-    for k, v in tune_cfg["op"].items():
+    for k, v in user_cfg.items():
         op_type_list = k[-1].split("><")
         op_type = ""
         for op in op_type_list: