Skip to content

Commit f70394c

Browse files
authored
add patterns tuning and graph tuning (#226)
1 parent a3e4789 commit f70394c

File tree

5 files changed

+364
-29
lines changed

5 files changed

+364
-29
lines changed

nlp_toolkit/backends/neural_engine/compile/graph/graph.py

Lines changed: 169 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
import numpy as np
2222
import yaml
2323
import os
24-
24+
import copy
25+
import time
2526

2627
class Graph(object):
2728

@@ -483,7 +484,94 @@ def dict_representer(dumper, data):
483484

484485
logger.info("Emit done...")
485486

486-
def get_sparse_nodes_name(self, threshold=0.7):
487+
def graph_dispatch(self, tune = True, inputs_shape = []):
488+
sparse_nodes_name = self.get_sparse_nodes_name()
489+
if tune:
490+
logger.info("Tuning graph start ...")
491+
self._tune_onednn_graph(inputs_shape)
492+
self._tune_sparse_graph(inputs_shape, sparse_nodes_name)
493+
logger.info("Tuning graph end ...")
494+
else:
495+
# if not tune, map to sparse graph directly
496+
self.transpose_mode_int8(sparse_nodes_name)
497+
498+
def _tune_onednn_graph(self, inputs_shape = []):
499+
onednn_graph_nodes_map = self._get_onednn_graph_nodes()
500+
if onednn_graph_nodes_map == {"InnerProduct": [], "Softmax": []}:
501+
pass
502+
else:
503+
onednn_graph_nodes_name_list = self._generate_onednn_graph_nodes_name_list(onednn_graph_nodes_map)
504+
golden_onednn_graph_nodes_name = []
505+
min_latency = float("inf")
506+
for onednn_graph_nodes_name in onednn_graph_nodes_name_list:
507+
curr_latency = float("inf")
508+
try:
509+
curr_model = copy.deepcopy(self)
510+
curr_model._generate_onednn_graph_nodes(onednn_graph_nodes_name)
511+
curr_result, curr_latency = curr_model._get_latency(inputs_shape)
512+
except:
513+
logger.warning("Graph can not be inferenced, please check the graph!")
514+
# update min latency and transpose nodes name
515+
if curr_latency < min_latency:
516+
min_latency = curr_latency
517+
golden_onednn_graph_nodes_name = onednn_graph_nodes_name
518+
self._generate_onednn_graph_nodes(golden_onednn_graph_nodes_name)
519+
520+
def _get_onednn_graph_nodes(self):
521+
# onednn graph only support fp32 inner_product and softmax
522+
onednn_graph_nodes_map = {"InnerProduct": [], "Softmax": []}
523+
for node in self.nodes:
524+
if node.op_type == "InnerProduct":
525+
weight = node.input_tensors[1]
526+
if type(weight.data) == np.ndarray and \
527+
weight.data.dtype == "float32":
528+
onednn_graph_nodes_map["InnerProduct"].append(node.name)
529+
elif node.op_type == "Softmax":
530+
if node.attr.get("output_dtype", "float32") == "float32":
531+
onednn_graph_nodes_map["Softmax"].append(node.name)
532+
return onednn_graph_nodes_map
533+
534+
def _generate_onednn_graph_nodes_name_list(self, onednn_graph_nodes_map):
535+
# strategy:
536+
# 1.softmax: all nodes map to onednn graph or not
537+
# 2.innerproduct: tune accorording weight shape
538+
ip_nodes_name_list = self._generate_transpose_nodes_name_list(onednn_graph_nodes_map["InnerProduct"])
539+
onednn_graph_nodes_name_list = []
540+
for ip_nodes_name in ip_nodes_name_list:
541+
onednn_graph_nodes_name_list.append(ip_nodes_name)
542+
onednn_graph_nodes_name_list.append(ip_nodes_name + onednn_graph_nodes_map["Softmax"])
543+
return onednn_graph_nodes_name_list
544+
545+
def _generate_onednn_graph_nodes(self, onednn_graph_nodes_name):
546+
for node in self.nodes:
547+
if node.name in onednn_graph_nodes_name:
548+
if node.op_type == "InnerProduct":
549+
node.op_type = "InnerProductGraph"
550+
elif node.op_type == "Softmax":
551+
node.op_type = "SoftmaxGraph"
552+
553+
def _tune_sparse_graph(self, inputs_shape = [], sparse_nodes_name = []):
554+
if sparse_nodes_name == []:
555+
pass
556+
else:
557+
trans_nodes_name_list = self._generate_transpose_nodes_name_list(sparse_nodes_name)
558+
golden_trans_nodes_name = []
559+
min_latency = float("inf")
560+
for trans_nodes_name in trans_nodes_name_list:
561+
curr_latency = float("inf")
562+
try:
563+
curr_model = copy.deepcopy(self)
564+
curr_model.transpose_mode_int8(trans_nodes_name)
565+
curr_result, curr_latency = curr_model._get_latency(inputs_shape)
566+
except:
567+
logger.warning("Graph can not be inferenced, please check the graph!")
568+
# update min latency and transpose nodes name
569+
if curr_latency < min_latency:
570+
min_latency = curr_latency
571+
golden_trans_nodes_name = trans_nodes_name
572+
self.transpose_mode_int8(golden_trans_nodes_name)
573+
574+
def get_sparse_nodes_name(self, threshold = 0.7):
487575

488576
def get_zero_ratio(matrix, block):
489577
sparse_ratio = -1
@@ -496,9 +584,9 @@ def get_zero_ratio(matrix, block):
496584
is_zero_block = True
497585
for br in range(block[0]):
498586
for bc in range(block[1]):
499-
if matrix[mr * block[0] + br][mc * block[1] + bc] != 0:
500-
is_zero_block = False
501-
break
587+
if matrix[mr*block[0]+br][mc*block[1]+bc] != 0:
588+
is_zero_block = False
589+
break
502590
if not is_zero_block:
503591
break
504592
if is_zero_block == True:
@@ -511,7 +599,7 @@ def get_zero_ratio(matrix, block):
511599
if node.op_type == "InnerProduct":
512600
# sparse kernel limitation:
513601
# 1. int8
514-
# 2. sparse_ratio > 0.5(1*4)
602+
# 2. sparse_ratio > 0.7(1*4)
515603
# 3. output channel of weight_shape = 4x
516604
# 4. post op != tanh
517605
if 'append_op' not in node.attr \
@@ -521,12 +609,84 @@ def get_zero_ratio(matrix, block):
521609
if type(weight.data) == np.ndarray and \
522610
(weight.data.dtype == 'int8' \
523611
or weight.data.dtype == 'uint8') \
524-
and weight.data.shape[1] % 4 == 0:
525-
612+
and weight.data.shape[1] % 4 == 0: # 1*4 sparse block
526613
zero_ratio = get_zero_ratio(weight.data, [1, 4])
527614
if zero_ratio >= threshold:
528615
sparse_nodes_name.append(node.name)
529-
return sparse_nodes_name
616+
617+
return sparse_nodes_name
618+
619+
def _generate_transpose_nodes_name_list(self, sparse_nodes_name):
620+
transpose_nodes_list = []
621+
if sparse_nodes_name == []:
622+
return transpose_nodes_list
623+
# switch the nodes which has the same weight shape and pose op
624+
weight_shape_map = {}
625+
for node in self.nodes:
626+
if node.name in sparse_nodes_name:
627+
weight = node.input_tensors[1]
628+
weight_shape = tuple(weight.shape) # list to tuple for dict key
629+
if weight_shape in weight_shape_map.keys():
630+
weight_shape_map[weight_shape].append(node.name)
631+
else:
632+
weight_shape_map[weight_shape] = [node.name]
633+
634+
# binary reflected gray code to generate the all combinations fo the n elements
635+
def brgd(n):
636+
if n==1:
637+
return ["0","1"]
638+
L1 = brgd(n-1)
639+
L2 = copy.deepcopy(L1)
640+
L2.reverse()
641+
L1 = ["0" + l for l in L1]
642+
L2 = ["1" + l for l in L2]
643+
L = L1 + L2
644+
return L
645+
646+
transpose_mask_list = brgd(len(weight_shape_map))
647+
for transpose_mask in transpose_mask_list:
648+
transpose_nodes = []
649+
for idx, weight_shape in enumerate(weight_shape_map):
650+
if transpose_mask[idx]=="1":
651+
transpose_nodes += weight_shape_map[weight_shape]
652+
transpose_nodes_list.append(transpose_nodes)
653+
654+
return transpose_nodes_list
655+
656+
657+
def _generate_inputs(self, inputs_shape = []):
658+
dtype_map = {"float32": np.float32,
659+
"int8": np.int8,
660+
"int32": np.int32,
661+
"int64": np.int64,
662+
"uint8": np.uint8,
663+
}
664+
inputs = []
665+
id = 0
666+
for node in self.nodes:
667+
if node.op_type == "Input":
668+
for tensor in node.output_tensors:
669+
if not isinstance(tensor.data, np.ndarray):
670+
if inputs_shape == []:
671+
shape = [16 for s in tensor.shape if s == -1]
672+
else:
673+
shape = inputs_shape[id]
674+
dtype = dtype_map[tensor.dtype]
675+
input = np.random.uniform(low=0, high=10, size=shape).astype(dtype)
676+
inputs.append(input)
677+
id += 1
678+
return inputs
679+
680+
def _get_latency(self, inputs_shape = [], iterations = 10, warm_up = 5):
681+
inputs = self._generate_inputs(inputs_shape)
682+
iter_latency = []
683+
for _ in range(iterations):
684+
start_time = time.time()
685+
result = self.inference(inputs)
686+
end_time = time.time()
687+
iter_latency.append(end_time - start_time)
688+
latency = np.array(iter_latency[warm_up:]).mean()
689+
return result, latency
530690

531691
def transpose_mode_int8(self, node_name_list=None):
532692
from ..ops import Tensor

nlp_toolkit/backends/neural_engine/compile/sub_graph/pattern.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@
6060
'OutputData',
6161
]
6262

63+
# for superbert, superbert patterns are huge patterns based on supported patterns
64+
superbert_patterns = []
65+
6366
PATTERNS = {}
6467

6568

nlp_toolkit/backends/neural_engine/compile/sub_graph/subgraph_matcher.py

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717

18-
from .pattern import supported_patterns, PATTERNS
18+
import time
19+
import copy
20+
import numpy as np
21+
from tqdm import tqdm
22+
from .pattern import supported_patterns, superbert_patterns, PATTERNS
1923
from .. import logger
2024

2125
EXECUTOR_TYPE = {
@@ -47,25 +51,65 @@
4751
"_MklLayerNorm": "LayerNorm",
4852
}
4953

50-
5154
class SubGraphMatcher(object):
52-
def __call__(self, model):
53-
patterns_switch = {
54-
'LayerNorm': True,
55-
'TransposeBatchMatMul': True,
56-
'MatMulWithBiasGelu': True,
57-
'MatMulWithBiasAdd': True,
58-
'MatMulWithBiasTanh': True,
59-
}
60-
logger.info('Start to implement Sub-Graph matching and replacing...')
61-
for pattern in supported_patterns:
62-
if pattern in PATTERNS:
63-
if pattern in patterns_switch.keys() and not patterns_switch[pattern]:
64-
continue
65-
else:
66-
p_fusion = PATTERNS[pattern]()
67-
model = p_fusion(model)
55+
def __call__(self, model, tune = False):
56+
logger.info('Start to implement Sub-Graph matching and replacing...')
57+
if tune:
58+
self._tune_patterns(model)
59+
else:
60+
self._fuse_patterns(model)
61+
logger.info('Sub-Graph match and replace done...')
62+
return model
6863

64+
def _fuse_patterns(self, model, supported_patterns=supported_patterns, pattern_mask=None):
65+
pattern_mask = [True for _ in range(len(supported_patterns))] \
66+
if pattern_mask == None else pattern_mask
67+
for pattern_id, pattern in enumerate(supported_patterns):
68+
if pattern in PATTERNS and pattern_mask[pattern_id]:
69+
p_fusion = PATTERNS[pattern]()
70+
model = p_fusion(model)
71+
self._remove_identity(model)
72+
73+
def _tune_patterns(self, model, iterations = 10, warm_up = 5):
74+
# pattern tuning strategy(for superbert):
75+
# 1. only one pattern off/on each time (pruning)
76+
# 2. check accuracy with framework
77+
# 3. and only save min latency config
78+
logger.info('Start tuning pattern...')
79+
all_patterns = supported_patterns + superbert_patterns
80+
pattern_mask = [True for i in range(len(all_patterns))]
81+
min_latency = float("inf")
82+
# skip tuning input node fusion and output node fusion
83+
for idx in tqdm(range(len(supported_patterns), len(all_patterns))):
84+
# pattern on
85+
on_latency = float("inf")
86+
try:
87+
on_model = copy.deepcopy(model)
88+
self._fuse_patterns(on_model, all_patterns, pattern_mask)
89+
on_result, on_latency = on_model._get_latency([], iterations, warm_up)
90+
except:
91+
logger.warning("Graph can not be inferenced, please check the graph!")
92+
# pattern off
93+
off_latency = float("inf")
94+
try:
95+
off_pattern_mask = copy.deepcopy(pattern_mask)
96+
off_pattern_mask[idx] = False
97+
off_model = copy.deepcopy(model)
98+
self._fuse_patterns(off_model, all_patterns, off_pattern_mask)
99+
off_result, off_latency = off_model._get_latency([], iterations, warm_up)
100+
except:
101+
logger.warning("Graph can not be inferenced, please check the graph!")
102+
# update min latency and pattern mask
103+
if off_latency < on_latency and off_latency < min_latency:
104+
min_latency = off_latency
105+
pattern_mask = off_pattern_mask
106+
107+
# generate model according pattern mask
108+
self._fuse_patterns(model, all_patterns, pattern_mask)
109+
logger.info('End tuning pattern...')
110+
return model
111+
112+
def _remove_identity(self, model):
69113
rm_node_names = []
70114
rm_op_type = ['Identity']
71115
for i in range(len(model.nodes)):
@@ -77,6 +121,4 @@ def __call__(self, model):
77121
op_type = EXECUTOR_TYPE[node.op_type]
78122
model.nodes[i].op_type = op_type
79123
model.remove_nodes(rm_node_names)
80-
logger.info('Sub-Graph match and replace done...')
81124

82-
return model
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2022 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
import os
19+
import unittest
20+
import numpy as np
21+
from nlp_toolkit.backends.neural_engine.compile import compile
22+
23+
class TestGraphDispatch(unittest.TestCase):
24+
@classmethod
25+
def setUpClass(self):
26+
pass
27+
28+
@classmethod
29+
def tearDownClass(self):
30+
pass
31+
32+
def test_graph_dispatch(self):
33+
# set input data
34+
shape = [1, 128]
35+
input_0 = np.random.uniform(low=0, high=128, size=shape).astype('int32')
36+
input_1 = np.random.uniform(low=0, high=1, size=shape).astype('int32')
37+
input_2 = np.random.uniform(low=0, high=1, size=shape).astype('int32')
38+
39+
# validate int8 sparse graph tuning
40+
int8_model_path = "/home/tensorflow/inc_ut/engine/bert_mini_int8_original_IR"
41+
self.assertTrue(os.path.exists(int8_model_path),
42+
'INT8 IR model is not found, please set your own model path!')
43+
int8_model = compile(int8_model_path)
44+
int8_output_dict = int8_model.inference([input_0, input_1, input_2])
45+
int8_output = list(int8_output_dict.values())[0]
46+
# sparse graph tuning
47+
int8_model.graph_dispatch(inputs_shape = [shape, shape, shape])
48+
int8_dispatch_output_dict = int8_model.inference([input_0, input_1, input_2])
49+
int8_dispatch_output = list(int8_dispatch_output_dict.values())[0]
50+
# compare outputs
51+
self.assertTrue((int8_output == int8_dispatch_output).all())
52+
53+
# validate onednn graph tuning
54+
fp32_model_path = "/home/tensorflow/inc_ut/engine/bert_mini_sst2_1x4_fp32.onnx"
55+
self.assertTrue(os.path.exists(fp32_model_path),
56+
'FP32 ONNX model is not found, please set your own model path!')
57+
fp32_model = compile(fp32_model_path)
58+
fp32_output_dict = fp32_model.inference([input_0, input_1, input_2])
59+
fp32_output = list(fp32_output_dict.values())[0]
60+
# onednn graph tuning
61+
fp32_model.graph_dispatch(inputs_shape = [shape, shape, shape])
62+
fp32_dispatch_output_dict = fp32_model.inference([input_0, input_1, input_2])
63+
fp32_dispatch_output = list(fp32_dispatch_output_dict.values())[0]
64+
# compare outputs
65+
self.assertTrue((fp32_output == fp32_dispatch_output).all())
66+
67+
if __name__ == "__main__":
68+
unittest.main()

0 commit comments

Comments
 (0)