Skip to content

Commit 58317fc

Browse files
authored
Support bf16 onnx model (#141)
* Support bf16 onnx model * fix the tensor init api * fix the trainer bf16 onnx gen problem * remove neural_engine depence * update the code * remove the debug * remove all INC dependence * update the code * short the line to 99
1 parent a41d8a9 commit 58317fc

File tree

16 files changed

+427
-75
lines changed

16 files changed

+427
-75
lines changed

nlp_toolkit/backends/neural_engine/compile/compile.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,11 @@
2323
4. Finally, convert them to .yaml file and .bin file for model configuration and inference.
2424
"""
2525

26-
import datetime
27-
import logging as log
28-
import os
29-
import sys
30-
import traceback
3126
from collections import OrderedDict
3227
from .loaders.loader import Loader
3328
from .extractors.extractor import Extractor
3429
from .sub_graph.subgraph_matcher import SubGraphMatcher
30+
from .graph_utils import get_model_fwk_name
3531

3632
COMPILES = OrderedDict({
3733
'loader': Loader,
@@ -53,5 +49,10 @@ def start_pipeline(model, config=None):
5349

5450

5551
def compile(model, config=None):
    """Compile the input model, or load an already compiled Neural Engine model.

    Args:
        model: Input forwarded to ``get_model_fwk_name``. When it is detected as
            'neural engine' it is used as a path prefix from which
            ``/conf.yaml`` and ``/model.bin`` are read.
        config: Optional configuration forwarded to ``start_pipeline``.

    Returns:
        The ``Graph`` initialised from the two files for a 'neural engine'
        input, otherwise whatever ``start_pipeline`` returns.
    """
    if get_model_fwk_name(model) == 'neural engine':
        from nlp_toolkit.backends.neural_engine.compile.graph import Graph
        # Keep the path and the graph under separate names: rebinding ``model``
        # to the new Graph before building the file names would evaluate
        # ``Graph + str`` and raise a TypeError.
        graph = Graph()
        graph.graph_init(model + '/conf.yaml', model + '/model.bin')
        return graph
    return start_pipeline(model, config=config)

nlp_toolkit/backends/neural_engine/compile/extractors/extractor.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
from .tf_extractor import TensorflowExtractor
1919
from .onnx_extractor import ONNXExtractor
20-
from neural_compressor.utils import logger
21-
20+
from .. import logger
21+
from ..graph_utils import get_model_fwk_name
2222

2323
EXTRACTORS = {
2424
'tensorflow': TensorflowExtractor,
@@ -33,9 +33,8 @@ class Extractor(object):
3333
"""
3434

3535
def __call__(self, model):
    """Dispatch the loaded model to the extractor registered for its framework.

    Args:
        model: Indexable pair; item 1 is the framework name used as the key
            into ``EXTRACTORS`` and item 0 is the object handed to the
            selected extractor.

    Returns:
        The value returned by the selected extractor.
    """
    fwk_name = model[1]
    run_extractor = EXTRACTORS[fwk_name]()
    extracted = run_extractor(model[0])
    logger.info('Extract {} model done...'.format(fwk_name))
    return extracted

nlp_toolkit/backends/neural_engine/compile/extractors/onnx_extractor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717

1818

19-
from neural_compressor.utils import logger
19+
from .. import logger
2020
from ..graph.graph import Graph
2121
from ..ops.op import OPERATORS
2222
from ..onnx_utils import graph_node_names_details
@@ -30,7 +30,7 @@ class ONNXExtractor(object):
3030
and output_tensors, these tensors record the source/dest op name. All of these nodes
3131
(in a list) will compose a graph, which is Graph class, as the return object.
3232
Args:
33-
model: neural_compressor TensorflowBaseModel
33+
model: ONNXModel
3434
Return:
3535
Graph: Graph class, the new graph object
3636
@@ -42,12 +42,12 @@ def __call__(self, model):
4242
logger.info('Start to extarct onnx model ops...')
4343
new_graph = Graph()
4444
new_graph.framework = 'onnxruntime'
45-
for graph_input in model.graph().input:
45+
for graph_input in model.graph.input:
4646
op_type = 'ONNXINPUT'
4747
new_node = OPERATORS[op_type]()
4848
new_node.extract('onnxruntime', graph_input, model, graph_nodes_dict)
4949
new_graph.insert_nodes(len(new_graph.nodes), [new_node])
50-
for node in model.nodes():
50+
for node in model.graph.node:
5151
op_type = node.op_type
5252
if op_type == 'Constant':
5353
continue
@@ -77,7 +77,7 @@ def __call__(self, model):
7777
for tensor in graph_node.input_tensors:
7878
if origin_tensor_name + ':0' == tensor.name:
7979
has_tensor = False
80-
if pre_node in model.initializer() and has_tensor:
80+
if pre_node in model.graph.initializer and has_tensor:
8181
from onnx.numpy_helper import to_array
8282
data = to_array(pre_node)
8383
shape = list(data.shape) if data.shape != () else [1]

nlp_toolkit/backends/neural_engine/compile/extractors/tf_extractor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717

1818

19-
from neural_compressor.utils import logger
19+
from .. import logger
2020
from ..graph.graph import Graph
2121
from ..ops.op import OPERATORS
2222
from ..tf_utils import graph_node_names_details
@@ -29,14 +29,13 @@ class TensorflowExtractor(object):
2929
and output_tensors, these tensors record the source/dest op name. All of these nodes
3030
(in a list) will compose a graph, which is Graph class, as the return object.
3131
Args:
32-
model: neural_compressor TensorflowBaseModel
32+
model: TensorflowModel
3333
Return:
3434
Graph: Graph class, the new graph object
3535
3636
"""
3737
@classmethod
3838
def __call__(self, model):
39-
4039
nodes = model.graph_def.node
4140
graph_nodes_dict = graph_node_names_details(nodes)
4241
logger.info('Start to extarct tensorflow model ops...')

nlp_toolkit/backends/neural_engine/compile/graph/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import re
1919
from collections import OrderedDict
20-
from neural_compressor.utils import logger
20+
from .. import logger
2121
import numpy as np
2222
import yaml
2323
import os

nlp_toolkit/backends/neural_engine/compile/graph_utils.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717

18-
from neural_compressor.utils import logger
18+
from . import logger
1919
import copy
2020
import re
21+
import os
2122
import numpy as np
2223
from collections import namedtuple, OrderedDict
2324
from schema import Schema, And, Or
25+
import importlib
2426

2527

2628
DTYPES_DICT = {"float16": "fp16",
@@ -941,3 +943,101 @@ def pattern_mapping_conf_validation(conf_dict):
941943
}, ignore_extra_keys=True)
942944

943945
return dict_schema.validate(conf_dict)
946+
947+
class LazyImport(object):
    """Lazy import python module till use

    Args:
        module_name (string): The name of module imported later. When the
            instance is called, the last dotted component is looked up as a
            callable in the module named by the preceding components.
    """
    def __init__(self, module_name):
        self.module_name = module_name
        self.module = None

    def __getattr__(self, name):
        """Import the module on first use and return its attribute ``name``.

        If the module has no such attribute, ``<module_name>.<name>`` is
        imported as a submodule instead.

        Raises:
            ImportError: If the module itself cannot be imported.
            AttributeError: If ``name`` is neither an attribute nor an
                importable submodule.
        """
        # __getattr__ only runs when normal lookup failed. If our own fields are
        # missing (instance created without __init__, e.g. by copy/pickle),
        # resolving them through the module would recurse without bound.
        if name in ('module', 'module_name'):
            raise AttributeError(name)
        if self.module is None:
            self.module = importlib.import_module(self.module_name)
        try:
            return getattr(self.module, name)
        except AttributeError:
            pass
        try:
            return importlib.import_module(self.module_name + '.' + name)
        except ImportError as err:
            # Report a missing name as AttributeError so hasattr()/getattr()
            # with a default behave as expected on the lazy proxy.
            raise AttributeError(
                "module '{}' has no attribute '{}'".format(self.module_name, name)) from err

    def __call__(self, *args, **kwargs):
        """Import the parent module and call the function named by the last component."""
        # rpartition splits on the last dot only, so a component that repeats
        # earlier in the dotted path cannot truncate the module name.
        module_name, _, function_name = self.module_name.rpartition('.')
        if not module_name:
            # Undotted name: look the callable up in the module of the same name.
            module_name = function_name
        self.module = importlib.import_module(module_name)
        function = getattr(self.module, function_name)
        return function(*args, **kwargs)
972+
973+
def get_model_fwk_name(model):
    """Detect the input model belongs to which framework

    Args:
        model: Either a path string (model file or Neural Engine directory)
            or an in-memory model object.

    Returns:
        string: 'onnxruntime', 'neural engine' or 'tensorflow'.

    Raises:
        AssertionError: If ``model`` is a path string that does not exist
            (with or without a '.pb' suffix), or if no framework matches.
    """
    onnx = LazyImport('onnx')
    tf = LazyImport('tensorflow')

    def _is_onnxruntime(model):
        try:
            if isinstance(model, str):
                graph = onnx.load(model)
                # Explicit check instead of assert so it survives ``python -O``.
                if len(graph.graph.node) == 0:
                    return 'NA'
            else:
                # Require ``graph.node`` (which the ONNX extractor iterates)
                # rather than any ``graph`` attribute, so objects that merely
                # expose ``.graph`` are not taken for ONNX models.
                getattr(model.graph, 'node')
        except Exception:
            # Any load/parse/attribute failure just means "not ONNX".
            return 'NA'
        return 'onnxruntime'

    def _is_tensorflow(model):
        try:
            if isinstance(model, str):
                graph_def = tf.compat.v1.GraphDef()
                with open(model, 'rb') as f:
                    graph_def.ParseFromString(f.read())
            else:
                getattr(model, 'graph_def')
        except Exception:
            # Any read/parse/attribute failure just means "not tensorflow".
            return 'NA'
        return 'tensorflow'

    def _is_neural_engine(model):
        # Only path-like inputs can be a directory; os.path.isdir raises
        # TypeError when handed an arbitrary model object.
        if not isinstance(model, (str, bytes, os.PathLike)):
            return 'NA'
        if not model or not os.path.isdir(model):
            return 'NA'
        file_list = os.listdir(model)
        if len(file_list) != 2:
            return 'NA'
        for file_name in file_list:
            ext = os.path.splitext(file_name)[1]
            if ext not in (".yaml", ".bin", b".yaml", b".bin"):
                logger.error("Please Input yaml and bin for neural engine.")
                return 'NA'
        return 'neural engine'

    if isinstance(model, str):
        absmodel = os.path.abspath(os.path.expanduser(model))
        # Raised explicitly (not via ``assert``) so the check is not stripped
        # under ``python -O``; the exception type callers see is unchanged.
        if not (os.path.exists(absmodel) or os.path.exists(absmodel + '.pb')):
            raise AssertionError('invalid input path, the file does not exist!')

    fwk_name = 'NA'
    # NOTE(review): the handlers receive the raw ``model``, not the expanded
    # ``absmodel`` validated above - kept as in the original.
    for handler in (_is_onnxruntime, _is_neural_engine, _is_tensorflow):
        fwk_name = handler(model)
        if fwk_name != 'NA':
            break
    if fwk_name == 'NA':
        raise AssertionError('Framework is not detected correctly from model format.')

    return fwk_name

nlp_toolkit/backends/neural_engine/compile/loaders/loader.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,25 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717

18-
from neural_compressor.model.model import MODELS, get_model_fwk_name, get_model_type
19-
from neural_compressor.utils.utility import LazyImport
18+
from ..graph_utils import LazyImport, get_model_fwk_name
19+
2020
onnx = LazyImport('onnx')
21-
21+
tf = LazyImport('tensorflow')
22+
2223
class Loader(object):
    """Turn a model path into an in-memory model and report its framework."""
    def __call__(self, model):
        """Load ``model`` when it is a path string.

        Args:
            model: Path string or in-memory model object; non-string inputs
                are passed through untouched.

        Returns:
            tuple: ``(model, framework)`` where ``framework`` is the name
            reported by ``get_model_fwk_name``.
        """
        fwk = get_model_fwk_name(model)
        from_path = isinstance(model, str)
        if from_path and fwk == 'tensorflow':
            # Parse the serialized GraphDef and wrap the imported graph in a session.
            tf_graph = tf.Graph()
            parsed_def = tf.compat.v1.GraphDef()
            with open(model, 'rb') as pb_file:
                parsed_def.ParseFromString(pb_file.read())
            with tf_graph.as_default():
                tf.import_graph_def(parsed_def, name='')
            session_config = tf.compat.v1.ConfigProto()
            model = tf.compat.v1.Session(graph=tf_graph, config=session_config)
        elif from_path and fwk == 'onnxruntime':
            model = onnx.load(model)
        return model, fwk

0 commit comments

Comments
 (0)