Skip to content

Commit a3e4789

Browse files
mixed sparsity and dense graph cleancode (#234)
* modify graph.py * remove inner product #vnni * cleancode * fix the pytest * fix pytest Co-authored-by: Bo Dong <[email protected]>
1 parent f56eeb6 commit a3e4789

10 files changed

+121
-1845
lines changed

nlp_toolkit/backends/neural_engine/compile/graph/graph.py

Lines changed: 120 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525

2626
class Graph(object):
27+
2728
def __init__(self):
2829
self._nodes = []
2930
self._node_id = {}
@@ -483,6 +484,7 @@ def dict_representer(dumper, data):
483484
logger.info("Emit done...")
484485

485486
def get_sparse_nodes_name(self, threshold=0.7):
487+
486488
def get_zero_ratio(matrix, block):
487489
sparse_ratio = -1
488490
if matrix.ndim == 2 and len(block) == 2:
@@ -526,45 +528,44 @@ def get_zero_ratio(matrix, block):
526528
sparse_nodes_name.append(node.name)
527529
return sparse_nodes_name
528530

529-
530-
def innerproduct_type_check(self, node):
531-
innerproduct_type = {
532-
# general_node
533-
"general": "general",
534-
535-
# InnerProduct Nodes
536-
"QKV_innerproduct": 'add_innerproduct_0',
537-
'output_dense_bias': 'add_innerproduct_1',
538-
'intermediate_dense_mul': 'mul_innerproduct_0',
539-
540-
# Matmul Nodes
541-
'add_matmul': 'add_matmul_0',
542-
'transpose_matmul': 'transpose_matmul_0',
543-
}
544-
if node.name.startswith('Add') and node.op_type == "InnerProduct":
545-
if 'append_op' in node.attr:
546-
return innerproduct_type['output_dense_bias']
547-
else:
548-
return innerproduct_type['QKV_innerproduct']
549-
550-
if node.name.startswith('Mul') and node.op_type == "InnerProduct":
551-
return innerproduct_type['intermediate_dense_mul']
552-
553-
if node.name.startswith('Add') and node.op_type == 'Matmul':
554-
return innerproduct_type['add_matmul']
555-
556-
if node.name.startswith('Transpose') and node.op_type == 'Matmul':
557-
return innerproduct_type['transpose_matmul']
558-
else:
559-
return innerproduct_type['general']
560-
561531
def transpose_mode_int8(self, node_name_list=None):
562532
from ..ops import Tensor
563533
from .. import graph_utils as util
564534
import copy
565535
logger.info("Start to transpose_mode_int8 ......")
566536
reorder_dict = {}
567537

538+
def innerproduct_type_check(node):
539+
innerproduct_type = {
540+
# general_node
541+
"general": "general",
542+
543+
# InnerProduct Nodes
544+
"QKV_innerproduct": 'add_innerproduct_0',
545+
'output_dense_bias': 'add_innerproduct_1',
546+
'intermediate_dense_mul': 'mul_innerproduct_0',
547+
548+
# Matmul Nodes
549+
'add_matmul': 'add_matmul_0',
550+
'transpose_matmul': 'transpose_matmul_0',
551+
}
552+
if node.name.startswith('Add') and node.op_type == "InnerProduct":
553+
if 'append_op' in node.attr:
554+
return innerproduct_type['output_dense_bias']
555+
else:
556+
return innerproduct_type['QKV_innerproduct']
557+
558+
if node.name.startswith('Mul') and node.op_type == "InnerProduct":
559+
return innerproduct_type['intermediate_dense_mul']
560+
561+
if node.name.startswith('Add') and node.op_type == 'Matmul':
562+
return innerproduct_type['add_matmul']
563+
564+
if node.name.startswith('Transpose') and node.op_type == 'Matmul':
565+
return innerproduct_type['transpose_matmul']
566+
else:
567+
return innerproduct_type['general']
568+
568569
def create_new_attr(node):
569570
if 'output_dtype' in node.attr:
570571
new_attr = OrderedDict({
@@ -702,27 +703,39 @@ def modify_post_node_input_tensor(post_node, node, node_input_tensors_idx=0):
702703
logger.info("The node_name_list is None. Start to get sparse nodes name...")
703704
node_name_list = self.get_sparse_nodes_name()
704705

705-
pattern = '0001'
706-
node_name_type_list = ''
707-
for node_name in node_name_list:
708-
node = self.get_node_by_name(node_name)
709-
node_type = self.innerproduct_type_check(node)
710-
if node_type == 'add_innerproduct_0':
711-
node_name_type_list += '0'
712-
if node_type == 'add_innerproduct_1':
713-
node_name_type_list += '1'
714-
if node_type == 'mul_innerproduct_0':
715-
node_name_type_list += '2'
706+
def node_name_list_convert(node_name_list):
707+
node_name_type_list = ''
708+
for node_name in node_name_list:
709+
node = self.get_node_by_name(node_name)
710+
node_type = innerproduct_type_check(node)
711+
if node_type == 'add_innerproduct_0':
712+
node_name_type_list += '0'
713+
if node_type == 'add_innerproduct_1':
714+
node_name_type_list += '1'
715+
if node_type == 'mul_innerproduct_0':
716+
node_name_type_list += '2'
717+
return node_name_type_list
718+
719+
node_name_type_list = node_name_list_convert(node_name_list)
720+
721+
def patthen_match(node_name_type_list, pattern='0001'):
722+
matched_idx = []
723+
idx = node_name_type_list.find(pattern)
724+
while idx != -1:
725+
matched_idx.append(idx)
726+
idx = node_name_type_list.find(pattern, idx + 1)
727+
728+
tmp_node_name_list = []
729+
for node_idx in matched_idx:
730+
tmp_node_name_list.append(node_name_list[node_idx:node_idx + 4])
731+
QKV_node_name_list = [i for item in tmp_node_name_list for i in item]
732+
return matched_idx, QKV_node_name_list
716733

717-
matched_idx = []
718-
idx = node_name_type_list.find(pattern)
719-
while idx != -1:
720-
matched_idx.append(idx)
721-
idx = node_name_type_list[idx + 1:].find(pattern, idx)
734+
matched_idx, QKV_node_name_list = patthen_match(node_name_type_list)
722735

723736
for node_name in node_name_list:
724737
node = self.get_node_by_name(node_name)
725-
node_type = self.innerproduct_type_check(node)
738+
node_type = innerproduct_type_check(node)
726739
if node_type == 'general':
727740
continue
728741

@@ -784,39 +797,57 @@ def expand_gelu_tanh(node):
784797
for node_name in reorder_dict:
785798
insert_idx = self.get_node_id(node_name)
786799
self.insert_nodes(insert_idx, [reorder_dict[node_name]])
787-
'''
788-
Fusion 1: eliminate the two reorder nodes when a tensor passes through reorder_recover + reorder_post consecutively
789-
'''
790-
for node in self._nodes:
791-
if 'Reorder_Post' in node.name and node.op_type == 'Reorder' and 'recover_reorder' in node.output_tensors[
792-
0].name:
793-
pre_node = self.get_node_by_name(node.input_tensors[0].source_op[0])
794-
if 'Reorder_Recover' in pre_node.name and pre_node.op_type == 'Reorder':
795-
post_node = self.get_node_by_name(node.output_tensors[0].dest_op[0])
796-
idx = 0
797-
for _ in post_node.input_tensors:
798-
if node.output_tensors[0].name == _.name:
799-
break
800-
else:
801-
idx = idx + 1
802-
post_node.input_tensors[idx] = pre_node.input_tensors[0]
803-
self.remove_nodes([pre_node.name, node.name])
800+
801+
def consecutive_reorder_fusion():
802+
'''
803+
Fusion 1:
804+
eliminate the two reorder nodes if a tensor passes through reorder_recover + reorder_post consecutively
805+
'''
806+
for node in self._nodes:
807+
if node.op_type == 'Reorder' and 'recover_reorder' in node.output_tensors[0].name:
808+
if 'Reorder_Post' in node.name:
809+
pre_node = self.get_node_by_name(node.input_tensors[0].source_op[0])
810+
if 'Reorder_Recover' in pre_node.name and pre_node.op_type == 'Reorder':
811+
post_node = self.get_node_by_name(node.output_tensors[0].dest_op[0])
812+
idx = 0
813+
for _ in post_node.input_tensors:
814+
if node.output_tensors[0].name == _.name:
815+
break
816+
else:
817+
idx = idx + 1
818+
post_node.input_tensors[idx] = pre_node.input_tensors[0]
819+
self.remove_nodes([pre_node.name, node.name])
820+
821+
consecutive_reorder_fusion()
804822

805823
if matched_idx == []:
806824
logger.info("transpose_mode_int8 done. No QKV fusion")
807825
return
808-
else:
809-
tmp_node_name_list = []
810-
for node_idx in matched_idx:
811-
tmp_node_name_list.append(node_name_list[node_idx:node_idx + 4])
812-
QKV_node_name_list = [i for item in tmp_node_name_list for i in item]
813-
'''
814-
Fusion 2: Place reorder_post nodes before the quantize node, especially for QKV and output dense
815-
'''
816-
reorder_dict = {}
826+
'''
827+
Fusion 2:
828+
Place reorder_post nodes before the quantize node, especially for QKV and output dense
829+
'''
830+
reorder_dict = {}
831+
832+
def reorder_post_fusion():
833+
834+
def check_QKV_fusion(node):
835+
post_node = self.get_node_by_name(node.output_tensors[0].dest_op[0])
836+
node_type = innerproduct_type_check(post_node)
837+
if node_type == 'mul_innerproduct_0':
838+
return True
839+
for post_node_name in node.output_tensors[0].dest_op:
840+
if post_node_name in QKV_node_name_list:
841+
continue
842+
else:
843+
return False
844+
return True
845+
817846
for node in self._nodes:
818-
if node.op_type == 'Reorder':
847+
if node.op_type == 'Reorder' and 'Reorder_Post' in node.name:
819848
pre_node = self.get_node_by_name(node.input_tensors[0].source_op[0])
849+
if check_QKV_fusion(node) == False:
850+
continue
820851
if pre_node.op_type == 'Quantize':
821852
# swap the pre_node and the current node
822853
for post_node_name in node.output_tensors[0].dest_op:
@@ -826,11 +857,9 @@ def expand_gelu_tanh(node):
826857
self.remove_nodes([node.name])
827858
reorder_node = reorder_node_insert(pre_node, 0)
828859
layernorm_node = self.get_node_by_name(reorder_node.input_tensors[0].source_op[0])
829-
830860
# if the following node is a reorder_post node, delete the Add_129_Reorder_Post_3
831861
if 'Reorder_Post' in layernorm_node.output_tensors[0].dest_op[0]:
832862
tmp = self.get_node_by_name(layernorm_node.output_tensors[0].dest_op[0])
833-
# Add_129
834863
post_node = self.get_node_by_name(tmp.output_tensors[0].dest_op[0])
835864
self.remove_nodes([tmp.name])
836865
layernorm_node.output_tensors[0].dest_op.append(post_node.name)
@@ -841,12 +870,18 @@ def expand_gelu_tanh(node):
841870
insert_idx = self.get_node_id(node_name)
842871
self.insert_nodes(insert_idx, [reorder_dict[node_name]])
843872

873+
reorder_post_fusion()
874+
875+
def reorder_recover_fusion():
876+
'''
877+
Fusion 3: place the reorder_recover nodes after reshape and matmul nodes
878+
'''
844879
for node in self._nodes:
845-
# Fusion 3: place the reorder_recover after reshape and matmul nodes
846880
# step1: delete all recover nodes of innerProduct nodes and modify reshape inputs
847-
node_type = self.innerproduct_type_check(node)
848-
if node_type == 'add_innerproduct_0' and node in QKV_node_name_list:
881+
node_type = innerproduct_type_check(node)
882+
if node_type == 'add_innerproduct_0' and node.name in QKV_node_name_list:
849883
post_node = self.get_node_by_name(node.output_tensors[0].dest_op[0])
884+
850885
if 'Reorder_Recover' in post_node.name:
851886
reshape_node = self.get_node_by_name(post_node.output_tensors[0].dest_op[0])
852887
if reshape_node.op_type == 'Reshape':
@@ -863,7 +898,7 @@ def expand_gelu_tanh(node):
863898

864899
# step2 : modify add_matmul and transpose_matmul nodes
865900
post_reshape_node = self.get_node_by_name(reshape_node.output_tensors[0].dest_op[0])
866-
if self.innerproduct_type_check(post_reshape_node) == 'add_matmul_0':
901+
if innerproduct_type_check(post_reshape_node) == 'add_matmul_0':
867902

868903
def add_matmul_modification(node):
869904
if node.attr.get("src0_perm") == '0,2,1,3' and node.attr.get("src1_perm") == '0,2,3,1':
@@ -875,7 +910,7 @@ def add_matmul_modification(node):
875910

876911
add_matmul_modification(post_reshape_node)
877912

878-
if self.innerproduct_type_check(post_reshape_node) == 'transpose_matmul_0':
913+
if innerproduct_type_check(post_reshape_node) == 'transpose_matmul_0':
879914

880915
def transpose_matmul_modification(node):
881916
if node.attr.get("src0_perm") == '0,2,1,3' and node.attr.get("src1_perm") == '0,2,3,1':
@@ -896,4 +931,6 @@ def transpose_matmul_modification(node):
896931

897932
transpose_matmul_modification(post_reshape_node)
898933

934+
reorder_recover_fusion()
935+
899936
logger.info("transpose_mode_int8 done")

nlp_toolkit/backends/neural_engine/executor/src/operators/inner_product.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,6 @@ void InnerProductOperator::ReshapeSparseLib(const vector<Tensor*>& input, const
546546

547547
#if __AVX512F__
548548
void InnerProductOperator::ForwardSparseLib(const vector<Tensor*>& input, const vector<Tensor*>& output) {
549-
#if __AVX512VNNI__
550549
// reorder 2d to 3d
551550
// 2D dense: [256, 768] x [768 328] -> [256, 328]
552551
// 2D sparselib: [328, 768] x [768, 256] -> [328, 256]
@@ -588,6 +587,7 @@ void InnerProductOperator::ForwardSparseLib(const vector<Tensor*>& input, const
588587
std::vector<const void*> runtime_data = {src0_->data(), src1_->data(), has_bias_ ? bias_->data() : nullptr, dst_data,
589588
rescales_.data()};
590589
spmm_kern_.execute(runtime_data);
590+
591591
// reorder dst activation (optional)
592592
if (dispatch_from_ == "InnerProduct" && !dispatch_config_.empty() && dispatch_config_[0] == "SparseLib") {
593593
// reorder to 3D then reshape
@@ -598,7 +598,6 @@ void InnerProductOperator::ForwardSparseLib(const vector<Tensor*>& input, const
598598
if (dispatch_from_ == "InnerProduct" && !dispatch_config_.empty()
599599
&& dispatch_config_[0] == "SparseLib" && life_count > 1) post_->reorder(src1_3d_shape);
600600
}
601-
#endif
602601
}
603602
#endif
604603

0 commit comments

Comments
 (0)