Skip to content

Commit 6b53fe8

Browse files
authored
Merge pull request #23746 from Abdurrahheem:ash/graph_simplifier
Assertion Fix in Split Layer #23746 ### Pull Request Readiness Checklist This PR fixes issue mentioned in [#23663](#23663) Merge with opencv/opencv_extra#1067 See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
1 parent d3e7968 commit 6b53fe8

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

modules/dnn/src/onnx/onnx_importer.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,14 +1363,22 @@ void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeP
13631363
{
13641364
DictValue splits = layerParams.get("split");
13651365
const int numSplits = splits.size();
1366-
CV_Assert(numSplits > 1);
13671366

1368-
std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
1369-
for (int i = 1; i < splits.size() - 1; ++i)
1367+
if (numSplits == 1)
13701368
{
1371-
slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i);
1369+
layerParams.set("num_split", 1);
1370+
}
1371+
else
1372+
{
1373+
CV_Assert(numSplits >= 1);
1374+
1375+
std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
1376+
for (int i = 1; i < splits.size() - 1; ++i)
1377+
{
1378+
slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i);
1379+
}
1380+
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
13721381
}
1373-
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
13741382
}
13751383
else if (node_proto.input_size() == 2) // opset >= 13, the split will be stored at the second input, instead of the attribute.
13761384
{

modules/dnn/test/test_onnx_importer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,7 @@ TEST_P(Test_ONNX_layers, Split)
11461146
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
11471147
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
11481148
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
1149+
testONNXModels("split_0");
11491150
testONNXModels("split_1");
11501151
testONNXModels("split_2");
11511152
testONNXModels("split_3");

0 commit comments

Comments
 (0)