Skip to content

Commit bd48c7f

Browse files
committed
Support symbolic shape inference
Signed-off-by: inisis <[email protected]>
1 parent 24096c6 commit bd48c7f

File tree

2 files changed

+3198
-6
lines changed

2 files changed

+3198
-6
lines changed

src/onnx_ir/passes/common/shape_inference.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import onnx_ir as ir
1717
from onnx_ir.passes.common import _c_api_utils
18+
from .symbolic_shape_infer import SymbolicShapeInference
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -72,12 +73,15 @@ def __init__(
7273

7374
def call(self, model: ir.Model) -> ir.passes.PassResult:
7475
def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
75-
return onnx.shape_inference.infer_shapes(
76-
proto,
77-
check_type=self.check_type,
78-
strict_mode=self.strict_mode,
79-
data_prop=self.data_prop,
80-
)
76+
try:
77+
return SymbolicShapeInference.infer_shapes(proto, auto_merge=True)
78+
except Exception:
79+
return onnx.shape_inference.infer_shapes(
80+
proto,
81+
check_type=self.check_type,
82+
strict_mode=self.strict_mode,
83+
data_prop=self.data_prop,
84+
)
8185

8286
try:
8387
inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)

0 commit comments

Comments
 (0)