Skip to content

Commit a286751

Browse files
authored
Added support for opset7::Gather (opencv#126)
1 parent 3e516d9 commit a286751

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

modules/arm_plugin/src/arm_converter/arm_converter.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ Converter::Converter(const std::shared_ptr<const ngraph::Function> function, boo
114114
Register<opset::Concat>();
115115
Register<opset::Transpose>();
116116
Register<opset::StridedSlice>();
117+
Register<opset::Gather>();
117118
Register<ngraph::op::v1::Gather>();
118119
Register<opset::ROIPooling>();
119120
Register<opset::PSROIPooling>();

modules/arm_plugin/src/arm_converter/arm_converter_gather.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,25 @@
88
#include "arm_converter/arm_converter.hpp"
99

1010
namespace ArmPlugin {
11+
template <> Converter::Conversion::Ptr Converter::Convert(const opset::Gather& node) {
12+
auto make = [&] (auto refFunction) {
13+
return this->MakeConversion(refFunction,
14+
node.input(0),
15+
node.input(1),
16+
node.output(0),
17+
node.get_input_shape(0),
18+
node.get_input_shape(1),
19+
node.get_output_shape(0),
20+
static_cast<size_t>(node.get_axis()),
21+
static_cast<size_t>(node.get_batch_dims()));
22+
};
23+
24+
return CallSwitch(
25+
AP_WRAP(make, ngraph::runtime::reference::gather),
26+
node.input(0), allTypes,
27+
node.input(1), indexTypes);
28+
}
29+
1130
template <> Converter::Conversion::Ptr Converter::Convert(const ngraph::op::v1::Gather& node) {
1231
auto make = [&] (auto refFunction) {
1332
return this->MakeConversion(refFunction,

modules/arm_plugin/tests/functional/shared_tests_instances/single_layer_tests/gather.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,34 @@ const auto params_ref = testing::Combine(
118118
);
119119

120120
INSTANTIATE_TEST_CASE_P(smoke_Gather_refernce, GatherLayerTest, params_ref, GatherLayerTest::getTestCaseName);
121+
122+
const std::vector<std::vector<size_t>> indicesShapes5 = {
123+
std::vector<size_t>{10, 4},
124+
std::vector<size_t>{10, 20, 5},
125+
};
126+
127+
const std::vector< std::tuple<int, int> > axes_batches = {
128+
std::tuple<int, int>(0, 0),
129+
std::tuple<int, int>(1, 0),
130+
std::tuple<int, int>(2, 0),
131+
std::tuple<int, int>(3, 0),
132+
std::tuple<int, int>(-1, 0),
133+
std::tuple<int, int>(-2, 0),
134+
std::tuple<int, int>(1, 1),
135+
std::tuple<int, int>(-1, 1),
136+
};
137+
138+
const auto params_g7 = testing::Combine(
139+
testing::ValuesIn(inputShapes),
140+
testing::ValuesIn(indicesShapes5),
141+
testing::ValuesIn(axes_batches),
142+
testing::ValuesIn(netPrecisions),
143+
testing::Values(InferenceEngine::Precision::UNSPECIFIED),
144+
testing::Values(InferenceEngine::Precision::UNSPECIFIED),
145+
testing::Values(InferenceEngine::Layout::ANY),
146+
testing::Values(InferenceEngine::Layout::ANY),
147+
testing::Values(CommonTestUtils::DEVICE_CPU)
148+
);
149+
150+
INSTANTIATE_TEST_CASE_P(smoke_V7Gather4, Gather7LayerTest, params_g7, Gather7LayerTest::getTestCaseName);
121151
} // namespace

0 commit comments

Comments
 (0)