4343 "<class 'torch.nn.modules.conv.Conv2d'>" : "Conv2d" ,
4444 "<class 'torch.nn.modules.conv.Conv3d'>" : "Conv3d" ,
4545 "<class 'torch.nn.modules.activation.ReLU'>" : "ReLU" ,
46+ "<class 'torch.nn.modules.sparse.EmbeddingBag'>" : "EmbeddingBag" ,
4647 "<method 'add' of 'torch._C._TensorBase' objects>" : "add" , # for IPEX < 2.2
4748 "<method 'add' of 'torch._C.TensorBase' objects>" : "add" , # for IPEX >= 2.2
4849 "<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>" : "AdaptiveAvgPool2d" ,
4950 "Linear_Relu" : "Linear" ,
51+ "Linear_add" : "Linear" ,
5052 "<class 'torch.nn.modules.linear.Linear'>" : "Linear" ,
5153 "<class 'torch.nn.modules.pooling.MaxPool2d'>" : "MaxPool2d" ,
52- "re" : {"<built-in method matmul of type object at" : "matmul" },
54+ "re" : {
55+ "<built-in method matmul of type object at" : "matmul" ,
56+ "<built-in method add of type object at" : "add" ,
57+ "<built-in method bmm of type object at" : "bmm" ,
58+ },
5359}
5460
5561BLOCK_PATTERNS = [
@@ -85,6 +91,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
8591 Returns:
8692 cfgs (dict): updated configs.
8793 """
94+ ori_user_cfg = copy .deepcopy (user_cfg )
8895 tmp_user_cfg = OrderedDict ()
8996 for op in user_cfg : # map ipex op_name to pt op_name
9097 for i , op_name in enumerate (op ):
@@ -94,9 +101,9 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
94101 ori_op = (tuple (ops ), unify_op_type_mapping_ipex [op_infos_from_cfgs [ops ]["op_type" ]])
95102 tmp_user_cfg [((ori_op [0 ],), ori_op [1 ])] = user_cfg [op ]
96103 break
97- user_cfg = tmp_user_cfg
98- for op_name in user_cfg :
99- inc_op_cfg = user_cfg [op_name ]
104+
105+ for op_name in tmp_user_cfg :
106+ inc_op_cfg = tmp_user_cfg [op_name ]
100107 for i , name in enumerate (op_name [0 ]):
101108 # to int8
102109 ipex_op_cfg = op_infos_from_cfgs [name ]
@@ -154,7 +161,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
154161 else :
155162 pass
156163 cfgs [name [0 ]][name [1 ]][name [2 ]] = ipex_op_cfg
157- return cfgs , user_cfg
164+ return cfgs , ori_user_cfg
158165
159166
160167def generate_activation_observer (scheme , algorithm , smooth_quant = False , smooth_quant_enable = False ): # pragma: no cover
@@ -333,8 +340,8 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
333340 elif "method" in ipex_op_type : # "<method 'add' of 'torch._C._TensorBase' objects>"
334341 method = ipex_op_type .split ("'" )[1 ]
335342 op_name_info .append ((module_fqn , method ))
336- elif "Convolution " in ipex_op_type : # "Convolution_Relu"
337- op_name_info .append ((module_fqn , "Conv2d" ))
343+ elif "_ " in ipex_op_type : # "Convolution_Relu", "Linear_Relu "
344+ op_name_info .append ((module_fqn , ipex_op_type . split ( "_" )[ 0 ] ))
338345 else :
339346 re_flag = False
340347 for pattern , unify_op_type in unify_op_type_mapping_ipex ["re" ].items ():
@@ -394,32 +401,7 @@ def dump_model_op_stats(user_cfg):
394401 """
395402 res = dict ()
396403 for k , v in user_cfg .items ():
397- op_type_list = k [- 1 ].split ("><" )
398- op_type = ""
399- for op in op_type_list :
400- if "class" in op :
401- op_type = (
402- op [op .rfind ("." ) + 1 : op .rfind ("'" )]
403- if op_type == ""
404- else op_type + "&" + op [op .rfind ("." ) + 1 : op .rfind ("'" )]
405- )
406- elif "method" in op :
407- start = op .find ("'" ) + 1
408- if start > 1 :
409- op_type = (
410- op [start : op .find ("'" , start )]
411- if op_type == ""
412- else op_type + "&" + op [start : op .find ("'" , start )]
413- )
414- else :
415- start = op .find ("method" ) + 7
416- op_type = (
417- op [start : op .find (" " , start )]
418- if op_type == ""
419- else op_type + "&" + op [start : op .find (" " , start )]
420- )
421- else :
422- op_type = op if op_type == "" else op_type + "&" + op
404+ op_type = k [1 ]
423405 if op_type not in res .keys ():
424406 res [op_type ] = {"INT8" : 0 , "BF16" : 0 , "FP32" : 0 }
425407 if v ["weight" ]["dtype" ] == "int8" :
0 commit comments