2424
2525
2626class Graph (object ):
27+
2728 def __init__ (self ):
2829 self ._nodes = []
2930 self ._node_id = {}
@@ -483,6 +484,7 @@ def dict_representer(dumper, data):
483484 logger .info ("Emit done..." )
484485
485486 def get_sparse_nodes_name (self , threshold = 0.7 ):
487+
486488 def get_zero_ratio (matrix , block ):
487489 sparse_ratio = - 1
488490 if matrix .ndim == 2 and len (block ) == 2 :
@@ -526,45 +528,44 @@ def get_zero_ratio(matrix, block):
526528 sparse_nodes_name .append (node .name )
527529 return sparse_nodes_name
528530
529-
530- def innerproduct_type_check (self , node ):
531- innerproduct_type = {
532- # general_node
533- "general" : "general" ,
534-
535- # InnerProduct Nodes
536- "QKV_innerproduct" : 'add_innerproduct_0' ,
537- 'output_dense_bias' : 'add_innerproduct_1' ,
538- 'intermediate_dense_mul' : 'mul_innerproduct_0' ,
539-
540- # Matmul Nodes
541- 'add_matmul' : 'add_matmul_0' ,
542- 'transpose_matmul' : 'transpose_matmul_0' ,
543- }
544- if node .name .startswith ('Add' ) and node .op_type == "InnerProduct" :
545- if 'append_op' in node .attr :
546- return innerproduct_type ['output_dense_bias' ]
547- else :
548- return innerproduct_type ['QKV_innerproduct' ]
549-
550- if node .name .startswith ('Mul' ) and node .op_type == "InnerProduct" :
551- return innerproduct_type ['intermediate_dense_mul' ]
552-
553- if node .name .startswith ('Add' ) and node .op_type == 'Matmul' :
554- return innerproduct_type ['add_matmul' ]
555-
556- if node .name .startswith ('Transpose' ) and node .op_type == 'Matmul' :
557- return innerproduct_type ['transpose_matmul' ]
558- else :
559- return innerproduct_type ['general' ]
560-
561531 def transpose_mode_int8 (self , node_name_list = None ):
562532 from ..ops import Tensor
563533 from .. import graph_utils as util
564534 import copy
565535 logger .info ("Start to transpose_mode_int8 ......" )
566536 reorder_dict = {}
567537
def innerproduct_type_check(node):
    """Classify a graph node for transpose-mode int8 rewriting.

    The category is derived from the node-name prefix plus its op_type:

    * InnerProduct nodes
        - ``Add*``  with an ``append_op`` attr  -> 'add_innerproduct_1' (output dense bias)
        - ``Add*``  without ``append_op``       -> 'add_innerproduct_0' (QKV innerproduct)
        - ``Mul*``                              -> 'mul_innerproduct_0' (intermediate dense)
    * Matmul nodes
        - ``Add*``                              -> 'add_matmul_0'
        - ``Transpose*``                        -> 'transpose_matmul_0'

    Anything else is 'general'.
    """
    name = node.name
    if node.op_type == "InnerProduct":
        if name.startswith('Add'):
            # 'append_op' marks the output-dense (bias-fused) variant.
            if 'append_op' in node.attr:
                return 'add_innerproduct_1'
            return 'add_innerproduct_0'
        if name.startswith('Mul'):
            return 'mul_innerproduct_0'
    elif node.op_type == 'Matmul':
        if name.startswith('Add'):
            return 'add_matmul_0'
        if name.startswith('Transpose'):
            return 'transpose_matmul_0'
    return 'general'
568+
568569 def create_new_attr (node ):
569570 if 'output_dtype' in node .attr :
570571 new_attr = OrderedDict ({
@@ -702,27 +703,39 @@ def modify_post_node_input_tensor(post_node, node, node_input_tensors_idx=0):
702703 logger .info ("The node_name_list is None. Start to get sparse nodes name..." )
703704 node_name_list = self .get_sparse_nodes_name ()
704705
705- pattern = '0001'
706- node_name_type_list = ''
707- for node_name in node_name_list :
708- node = self .get_node_by_name (node_name )
709- node_type = self .innerproduct_type_check (node )
710- if node_type == 'add_innerproduct_0' :
711- node_name_type_list += '0'
712- if node_type == 'add_innerproduct_1' :
713- node_name_type_list += '1'
714- if node_type == 'mul_innerproduct_0' :
715- node_name_type_list += '2'
def node_name_list_convert(node_name_list):
    """Encode the sparse-node list as a digit string for pattern matching.

    Each node name is looked up in the graph, classified with
    innerproduct_type_check, and mapped to one character:
    'add_innerproduct_0' -> '0', 'add_innerproduct_1' -> '1',
    'mul_innerproduct_0' -> '2'.  Any other category contributes
    nothing to the encoded string.
    """
    digit_of = {
        'add_innerproduct_0': '0',
        'add_innerproduct_1': '1',
        'mul_innerproduct_0': '2',
    }
    encoded = []
    for node_name in node_name_list:
        node_type = innerproduct_type_check(self.get_node_by_name(node_name))
        # Unlisted categories map to '' and therefore vanish in the join,
        # matching the original "append only on match" behavior.
        encoded.append(digit_of.get(node_type, ''))
    return ''.join(encoded)
718+
719+ node_name_type_list = node_name_list_convert (node_name_list )
720+
def patthen_match(node_name_type_list, pattern='0001'):
    """Locate every (possibly overlapping) occurrence of *pattern*.

    Args:
        node_name_type_list: digit string produced by node_name_list_convert,
            aligned index-for-index with the enclosing scope's node_name_list.
        pattern: the type sequence to search for; the default '0001' is
            three QKV innerproducts followed by the output-dense one.

    Returns:
        (matched_idx, QKV_node_name_list): the start indices of each match,
        and the flattened node names covered by all matches.
    """
    matched_idx = []
    idx = node_name_type_list.find(pattern)
    while idx != -1:
        matched_idx.append(idx)
        # Restart one past the current hit so overlapping matches are found.
        idx = node_name_type_list.find(pattern, idx + 1)

    # Slice width follows the pattern length instead of the previous
    # hard-coded 4, so a caller-supplied pattern of any length works.
    span = len(pattern)
    matched_groups = [node_name_list[i:i + span] for i in matched_idx]
    QKV_node_name_list = [name for group in matched_groups for name in group]
    return matched_idx, QKV_node_name_list
716733
717- matched_idx = []
718- idx = node_name_type_list .find (pattern )
719- while idx != - 1 :
720- matched_idx .append (idx )
721- idx = node_name_type_list [idx + 1 :].find (pattern , idx )
734+ matched_idx , QKV_node_name_list = patthen_match (node_name_type_list )
722735
723736 for node_name in node_name_list :
724737 node = self .get_node_by_name (node_name )
725- node_type = self . innerproduct_type_check (node )
738+ node_type = innerproduct_type_check (node )
726739 if node_type == 'general' :
727740 continue
728741
@@ -784,39 +797,57 @@ def expand_gelu_tanh(node):
784797 for node_name in reorder_dict :
785798 insert_idx = self .get_node_id (node_name )
786799 self .insert_nodes (insert_idx , [reorder_dict [node_name ]])
787- '''
788- Fusion 1: eliminate the two reorder nodes when a tensor passes through reorder_recover + reorder_post consecutively
789- '''
790- for node in self ._nodes :
791- if 'Reorder_Post' in node .name and node .op_type == 'Reorder' and 'recover_reorder' in node .output_tensors [
792- 0 ].name :
793- pre_node = self .get_node_by_name (node .input_tensors [0 ].source_op [0 ])
794- if 'Reorder_Recover' in pre_node .name and pre_node .op_type == 'Reorder' :
795- post_node = self .get_node_by_name (node .output_tensors [0 ].dest_op [0 ])
796- idx = 0
797- for _ in post_node .input_tensors :
798- if node .output_tensors [0 ].name == _ .name :
799- break
800- else :
801- idx = idx + 1
802- post_node .input_tensors [idx ] = pre_node .input_tensors [0 ]
803- self .remove_nodes ([pre_node .name , node .name ])
800+
def consecutive_reorder_fusion():
    '''
    Fusion 1:
    eliminate the two reorder nodes if a tensor passes through reorder_recover + reorder_post consecutively

    A Reorder_Recover immediately followed by a Reorder_Post is a no-op
    pair: the post node's consumer is rewired to read the recover node's
    input tensor directly, then both reorder nodes are dropped.

    NOTE(review): self.remove_nodes() is called while iterating
    self._nodes; if remove_nodes mutates that same list in place, the
    iteration may skip the element following a removal — confirm
    remove_nodes' behavior or iterate over a snapshot.
    '''
    for node in self._nodes:
        # Candidate: a Reorder_Post whose output is a recover_reorder tensor.
        if node.op_type == 'Reorder' and 'recover_reorder' in node.output_tensors[0].name:
            if 'Reorder_Post' in node.name:
                pre_node = self.get_node_by_name(node.input_tensors[0].source_op[0])
                # Only fuse when the producer is the matching Reorder_Recover.
                if 'Reorder_Recover' in pre_node.name and pre_node.op_type == 'Reorder':
                    post_node = self.get_node_by_name(node.output_tensors[0].dest_op[0])
                    # Find which input slot of the consumer carries our tensor.
                    idx = 0
                    for _ in post_node.input_tensors:
                        if node.output_tensors[0].name == _.name:
                            break
                        else:
                            idx = idx + 1
                    # Bypass both reorders: consumer reads the recover's input.
                    post_node.input_tensors[idx] = pre_node.input_tensors[0]
                    self.remove_nodes([pre_node.name, node.name])
820+
821+ consecutive_reorder_fusion ()
804822
805823 if matched_idx == []:
806824 logger .info ("transpose_mode_int8 done. No QKV fusion" )
807825 return
808- else :
809- tmp_node_name_list = []
810- for node_idx in matched_idx :
811- tmp_node_name_list .append (node_name_list [node_idx :node_idx + 4 ])
812- QKV_node_name_list = [i for item in tmp_node_name_list for i in item ]
813- '''
814- Fusion 2: Place reorder_post nodes before the quantize node, especially for QKV and output dense
815- '''
816- reorder_dict = {}
826+ '''
827+ Fusion 2:
828+ Place reorder_post nodes before the quantize node, especially for QKV and output dense
829+ '''
830+ reorder_dict = {}
831+
832+ def reorder_post_fusion ():
833+
def check_QKV_fusion(node):
    """Decide whether this reorder node may take part in fusion 2.

    Returns True when the node's first consumer is an intermediate-dense
    innerproduct ('mul_innerproduct_0'), or when every consumer of its
    output tensor belongs to the matched QKV node-name list; otherwise
    False.
    """
    first_consumer = self.get_node_by_name(node.output_tensors[0].dest_op[0])
    if innerproduct_type_check(first_consumer) == 'mul_innerproduct_0':
        return True
    # All consumers must be part of a matched QKV group.
    return all(dest_name in QKV_node_name_list
               for dest_name in node.output_tensors[0].dest_op)
845+
817846 for node in self ._nodes :
818- if node .op_type == 'Reorder' :
847+ if node .op_type == 'Reorder' and 'Reorder_Post' in node . name :
819848 pre_node = self .get_node_by_name (node .input_tensors [0 ].source_op [0 ])
849+ if check_QKV_fusion (node ) == False :
850+ continue
820851 if pre_node .op_type == 'Quantize' :
821852 # swap the pre_node and the current node
822853 for post_node_name in node .output_tensors [0 ].dest_op :
@@ -826,11 +857,9 @@ def expand_gelu_tanh(node):
826857 self .remove_nodes ([node .name ])
827858 reorder_node = reorder_node_insert (pre_node , 0 )
828859 layernorm_node = self .get_node_by_name (reorder_node .input_tensors [0 ].source_op [0 ])
829-
830860 # if the following node is reorder_post node, delete the Add_129_Reorder_Post_3
831861 if 'Reorder_Post' in layernorm_node .output_tensors [0 ].dest_op [0 ]:
832862 tmp = self .get_node_by_name (layernorm_node .output_tensors [0 ].dest_op [0 ])
833- # Add_129
834863 post_node = self .get_node_by_name (tmp .output_tensors [0 ].dest_op [0 ])
835864 self .remove_nodes ([tmp .name ])
836865 layernorm_node .output_tensors [0 ].dest_op .append (post_node .name )
@@ -841,12 +870,18 @@ def expand_gelu_tanh(node):
841870 insert_idx = self .get_node_id (node_name )
842871 self .insert_nodes (insert_idx , [reorder_dict [node_name ]])
843872
873+ reorder_post_fusion ()
874+
875+ def reorder_recover_fusion ():
876+ '''
877+ Fusion 3: place the reorder_recover nodes after reshape and matmul nodes
878+ '''
844879 for node in self ._nodes :
845- # Fusion 3: place the reorder_recover after reshape and matmul nodes
846880 # step1: delte all recover nodes of innerProduct nodes and modify reshape inputs
847- node_type = self . innerproduct_type_check (node )
848- if node_type == 'add_innerproduct_0' and node in QKV_node_name_list :
881+ node_type = innerproduct_type_check (node )
882+ if node_type == 'add_innerproduct_0' and node . name in QKV_node_name_list :
849883 post_node = self .get_node_by_name (node .output_tensors [0 ].dest_op [0 ])
884+
850885 if 'Reorder_Recover' in post_node .name :
851886 reshape_node = self .get_node_by_name (post_node .output_tensors [0 ].dest_op [0 ])
852887 if reshape_node .op_type == 'Reshape' :
@@ -863,7 +898,7 @@ def expand_gelu_tanh(node):
863898
864899 # step2 : modify add_matmul and transpose_matmul nodes
865900 post_reshape_node = self .get_node_by_name (reshape_node .output_tensors [0 ].dest_op [0 ])
866- if self . innerproduct_type_check (post_reshape_node ) == 'add_matmul_0' :
901+ if innerproduct_type_check (post_reshape_node ) == 'add_matmul_0' :
867902
868903 def add_matmul_modification (node ):
869904 if node .attr .get ("src0_perm" ) == '0,2,1,3' and node .attr .get ("src1_perm" ) == '0,2,3,1' :
@@ -875,7 +910,7 @@ def add_matmul_modification(node):
875910
876911 add_matmul_modification (post_reshape_node )
877912
878- if self . innerproduct_type_check (post_reshape_node ) == 'transpose_matmul_0' :
913+ if innerproduct_type_check (post_reshape_node ) == 'transpose_matmul_0' :
879914
880915 def transpose_matmul_modification (node ):
881916 if node .attr .get ("src0_perm" ) == '0,2,1,3' and node .attr .get ("src1_perm" ) == '0,2,3,1' :
@@ -896,4 +931,6 @@ def transpose_matmul_modification(node):
896931
897932 transpose_matmul_modification (post_reshape_node )
898933
934+ reorder_recover_fusion ()
935+
899936 logger .info ("transpose_mode_int8 done" )
0 commit comments