Skip to content

Commit 317e7a8

Browse files
committed
fix bug
Signed-off-by: Xin He <[email protected]>
1 parent 3173945 commit 317e7a8

File tree

2 files changed

+72
-48
lines changed

2 files changed

+72
-48
lines changed

neural_compressor/experimental/export/torch2onnx.py

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -30,54 +30,6 @@
3030
ortq = LazyImport('onnxruntime.quantization')
3131

3232

33-
def ONNX2Numpy_dtype(onnx_node_type):
34-
"""Get Numpy data type from onnx data type.
35-
36-
Args:
37-
onnx_node_type (str): data type description.
38-
39-
Returns:
40-
dtype: numpy data type
41-
"""
42-
# Only record sepcial data type
43-
ONNX2Numpy_dtype_mapping = {
44-
"tensor(float)": np.float32,
45-
"tensor(double)": np.float64,
46-
}
47-
if onnx_node_type in ONNX2Numpy_dtype_mapping:
48-
dtype = ONNX2Numpy_dtype_mapping[onnx_node_type]
49-
return dtype
50-
else:
51-
tmp = onnx_node_type.lstrip('tensor(').rstrip(')')
52-
dtype = eval(f'np.{tmp}')
53-
return dtype
54-
55-
56-
class DummyDataReader(ortq.CalibrationDataReader):
57-
"""Build dummy datareader for onnx static quantization."""
58-
59-
def __init__(self, fp32_onnx_path):
60-
"""Initialize data reader.
61-
62-
Args:
63-
fp32_onnx_path (str): path to onnx file
64-
"""
65-
session = ort.InferenceSession(fp32_onnx_path, None)
66-
input_tensors = session.get_inputs()
67-
input = {}
68-
for node in input_tensors:
69-
shape = []
70-
for dim in node.shape:
71-
shape.append(dim if isinstance(dim, int) else 1)
72-
dtype = ONNX2Numpy_dtype(node.type)
73-
input[node.name] = np.ones(shape).astype(dtype)
74-
self.data = [input]
75-
self.data = iter(self.data)
76-
def get_next(self):
77-
"""Generate next data."""
78-
return next(self.data, None)
79-
80-
8133
def update_weight_bias(
8234
int8_model,
8335
fp32_onnx_path,
@@ -469,6 +421,7 @@ def torch_to_int8_onnx(
469421
)
470422

471423
else:
424+
from .utils import DummyDataReader
472425
dummy_datareader = DummyDataReader(fp32_onnx_path)
473426
ortq.quantize_static(
474427
fp32_onnx_path,
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2021 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
import numpy as np
19+
from neural_compressor.utils.utility import LazyImport
20+
21+
ort = LazyImport('onnxruntime')
22+
ortq = LazyImport('onnxruntime.quantization')
23+
24+
25+
def ONNX2Numpy_dtype(onnx_node_type):
26+
"""Get Numpy data type from onnx data type.
27+
28+
Args:
29+
onnx_node_type (str): data type description.
30+
31+
Returns:
32+
dtype: numpy data type
33+
"""
34+
# Only record sepcial data type
35+
ONNX2Numpy_dtype_mapping = {
36+
"tensor(float)": np.float32,
37+
"tensor(double)": np.float64,
38+
}
39+
if onnx_node_type in ONNX2Numpy_dtype_mapping:
40+
dtype = ONNX2Numpy_dtype_mapping[onnx_node_type]
41+
return dtype
42+
else:
43+
tmp = onnx_node_type.lstrip('tensor(').rstrip(')')
44+
dtype = eval(f'np.{tmp}')
45+
return dtype
46+
47+
48+
class DummyDataReader(ortq.CalibrationDataReader):
49+
"""Build dummy datareader for onnx static quantization."""
50+
51+
def __init__(self, fp32_onnx_path):
52+
"""Initialize data reader.
53+
54+
Args:
55+
fp32_onnx_path (str): path to onnx file
56+
"""
57+
session = ort.InferenceSession(fp32_onnx_path, None)
58+
input_tensors = session.get_inputs()
59+
input = {}
60+
for node in input_tensors:
61+
shape = []
62+
for dim in node.shape:
63+
shape.append(dim if isinstance(dim, int) else 1)
64+
dtype = ONNX2Numpy_dtype(node.type)
65+
input[node.name] = np.ones(shape).astype(dtype)
66+
self.data = [input]
67+
self.data = iter(self.data)
68+
69+
def get_next(self):
70+
"""Generate next data."""
71+
return next(self.data, None)

0 commit comments

Comments
 (0)