
Commit f64833d

yuwenzho authored and mengniwang95 committed
Reduce memory consumption in ONNXRT adaptor (#1266)
* reduce memory consumption

Signed-off-by: yuwenz <[email protected]>
1 parent 5ba9efe commit f64833d

File tree: 4 files changed, 82 additions(+), 19 deletions(-)

neural_compressor/adaptor/onnxrt.py

Lines changed: 7 additions & 9 deletions
@@ -712,13 +712,7 @@ def _detect_domain(self, model):
         # 2. according to input
         # typically, NLP models have multiple inputs,
         # and the dimension of each input is usually 2 (batch_size, max_seq_len)
-        if not model.is_large_model:
-            sess = ort.InferenceSession(model.model.SerializeToString(), providers=["CPUExecutionProvider"])
-        elif model.model_path is not None:  # pragma: no cover
-            sess = ort.InferenceSession(model.model_path, providers=["CPUExecutionProvider"])
-        else:  # pragma: no cover
-            assert False, "Please use model path instead of onnx model object to quantize."
-        input_shape_lens = [len(input.shape) for input in sess.get_inputs()]
+        input_shape_lens = [len(inp.type.tensor_type.shape.dim) for inp in model.model.graph.input]
         if len(input_shape_lens) > 1 and all(shape_len == 2 for shape_len in input_shape_lens):
             is_nlp = True
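The new heuristic reads input ranks directly from the graph proto, so no InferenceSession has to be built (and no >2GB serialization attempted) just to inspect input shapes. A minimal standalone sketch of the same idea, assuming a placeholder model file "model.onnx":

import onnx

# Load only the graph structure; tensor payloads are not needed to read shapes.
model = onnx.load("model.onnx", load_external_data=False)
input_shape_lens = [len(inp.type.tensor_type.shape.dim) for inp in model.graph.input]
# NLP models typically expose several rank-2 inputs shaped (batch_size, max_seq_len).
is_nlp = len(input_shape_lens) > 1 and all(rank == 2 for rank in input_shape_lens)
print("looks like an NLP model:", is_nlp)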

@@ -778,11 +772,15 @@ def _pre_optimize(self, model, level=1):
 
         sess_options.register_custom_ops_library(get_library_path())
         if not model.is_large_model:
-            ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"])
+            sess = ort.InferenceSession(
+                model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
+            )
         elif model.model_path is not None:  # pragma: no cover
-            ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"])
+            model.model = onnx.ModelProto()  # clean memory for large model
+            sess = ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"])
         else:  # pragma: no cover
             logger.warning("Please use model path instead of onnx model object to quantize")
+        del sess
 
         tmp_model = onnx.load(sess_options.optimized_model_filepath, load_external_data=False)
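For the large-model branch, the adaptor now drops the in-memory ModelProto (model.model = onnx.ModelProto()) before ONNX Runtime reloads the model from its path, and deletes the session once the optimized graph has been written to disk. A hedged, simplified sketch of that pattern; optimize_from_path and its arguments are placeholder names, not the adaptor's API:

import onnx
import onnxruntime as ort

def optimize_from_path(model_path, optimized_path):
    """Optimize a >2GB model from disk without holding two copies in memory."""
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    sess_options.optimized_model_filepath = optimized_path
    # Building the session from the file path avoids SerializeToString() on a
    # proto that may exceed the 2GB protobuf limit.
    sess = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
    del sess  # release the session before loading the optimized graph
    # Reload the optimized graph without pulling external tensor data into memory.
    return onnx.load(optimized_path, load_external_data=False)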

neural_compressor/adaptor/ox_utils/util.py

Lines changed: 2 additions & 0 deletions
@@ -85,6 +85,8 @@
     "DmlExecutionProvider": "onnxrt_dml_ep",
 }
 
+MAXIMUM_PROTOBUF = 2147483648
+
 
 def dtype_to_name(dtype_mapping, dtype):
     """Map data type and its string representation."""

neural_compressor/model/onnx_model.py

Lines changed: 25 additions & 10 deletions
@@ -18,8 +18,10 @@
 
 import logging
 import os
+import sys
 from pathlib import Path
 
+from neural_compressor.adaptor.ox_utils.util import MAXIMUM_PROTOBUF
 from neural_compressor.model.base_model import BaseModel
 from neural_compressor.utils.utility import LazyImport
 
@@ -41,16 +43,9 @@ def __init__(self, model, **kwargs):
         """
         self._model = model if not isinstance(model, str) else onnx.load(model)
         self._model_path = None if not isinstance(model, str) else model
-        self._is_large_model = False
-        try:
-            ort.InferenceSession(self._model.SerializeToString(), providers=["CPUExecutionProvider"])
-        except Exception as e:  # pragma: no cover
-            if self._model_path is not None:
-                ort.InferenceSession(self._model_path, providers=["CPUExecutionProvider"])
-                self._is_large_model = True
-            else:
-                logger.warning("Please use model path instead of onnx model object to quantize")
-
+        self._is_large_model = self.check_large_model()
+        if self._is_large_model and self._model_path is None:
+            logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize")
         self._config = None
         if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()):
             from transformers import PretrainedConfig
@@ -66,6 +61,26 @@ def __init__(self, model, **kwargs):
         self._get_graph_info()
         self._q_config = None
 
+    def check_large_model(self):
+        """Check model > 2GB."""
+        init_size = 0
+        for init in self._model.graph.initializer:
+            # if initializer has external data location, return True
+            if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL:
+                return True
+            # if raise error of initializer size > 2GB, return True
+            try:
+                init_bytes = init.SerializeToString()
+                init_size += sys.getsizeof(init_bytes)
+            except Exception as e:
+                if "exceeds maximum protobuf size of 2GB" in str(e):
+                    return True
+                else:  # pragma: no cover
+                    raise e
+            if init_size > MAXIMUM_PROTOBUF:
+                return True
+        return False
+
     @property
     def is_large_model(self):
         """Check the onnx model is over 2GB."""

test/model/test_onnx_model.py

Lines changed: 48 additions & 0 deletions
@@ -203,6 +203,7 @@ def setUp(self):
     def tearDownClass(self):
         shutil.rmtree("./gptj", ignore_errors=True)
         shutil.rmtree("./hf_test", ignore_errors=True)
+        os.remove("model.onnx")
 
     def test_hf_model(self):
         from optimum.onnxruntime import ORTModelForCausalLM
@@ -407,6 +408,53 @@ def test_remove_unused_nodes(self):
         self.model.remove_unused_nodes()
         self.assertEqual(len(self.model.nodes()), 6)
 
+    def test_check_large_model(self):
+        import onnx
+        import torch
+        import torch.nn as nn
+
+        from neural_compressor.model.onnx_model import ONNXModel
+
+        class Net(nn.Module):
+            def __init__(self, in_features, out_features):
+                super(Net, self).__init__()
+                self.fc = nn.Linear(in_features, out_features)
+
+            def forward(self, x):
+                x = self.fc(x)
+                return x
+
+        # model > 2GB
+        model = Net(512, 1024 * 1024)
+        input = torch.randn(512, requires_grad=True)
+        with torch.no_grad():
+            torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
+        model = onnx.load("model.onnx")
+        model = ONNXModel(model)  # pass ModelProto
+        self.assertTrue(model.check_large_model())
+
+        model = ONNXModel("model.onnx")  # pass string
+        self.assertTrue(model.check_large_model())
+
+        model = onnx.load("model.onnx", load_external_data=False)  # not load init
+        model = ONNXModel(model)
+        self.assertTrue(model.check_large_model())
+
+        # model < 2GB
+        model = Net(10, 10 * 10)
+        input = torch.randn(10, requires_grad=True)
+        with torch.no_grad():
+            torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
+        model = onnx.load("model.onnx")
+        model = ONNXModel(model)  # pass ModelProto
+        self.assertFalse(model.check_large_model())
+
+        model = ONNXModel("model.onnx")  # pass string
+        self.assertFalse(model.check_large_model())
+
+        model = ONNXModel("model.onnx", load_external_data_for_model=False)  # not load init
+        self.assertFalse(model.check_large_model())
+
 
 if __name__ == "__main__":
     unittest.main()
