File tree Expand file tree Collapse file tree 2 files changed +3198
-6
lines changed
src/onnx_ir/passes/common Expand file tree Collapse file tree 2 files changed +3198
-6
lines changed Original file line number Diff line number Diff line change 15
15
16
16
import onnx_ir as ir
17
17
from onnx_ir .passes .common import _c_api_utils
18
+ from .symbolic_shape_infer import SymbolicShapeInference
18
19
19
20
logger = logging .getLogger (__name__ )
20
21
@@ -72,12 +73,15 @@ def __init__(
72
73
73
74
def call (self , model : ir .Model ) -> ir .passes .PassResult :
74
75
def partial_infer_shapes (proto : onnx .ModelProto ) -> onnx .ModelProto :
75
- return onnx .shape_inference .infer_shapes (
76
- proto ,
77
- check_type = self .check_type ,
78
- strict_mode = self .strict_mode ,
79
- data_prop = self .data_prop ,
80
- )
76
+ try :
77
+ return SymbolicShapeInference .infer_shapes (proto , auto_merge = True )
78
+ except Exception :
79
+ return onnx .shape_inference .infer_shapes (
80
+ proto ,
81
+ check_type = self .check_type ,
82
+ strict_mode = self .strict_mode ,
83
+ data_prop = self .data_prop ,
84
+ )
81
85
82
86
try :
83
87
inferred_model_proto = _c_api_utils .call_onnx_api (partial_infer_shapes , model )
You can’t perform that action at this time.
0 commit comments