Skip to content

Commit 00b434f

Browse files
committed
Remove some code duplication in bindings
1 parent 6d28ec0 commit 00b434f

File tree

1 file changed

+51
-75
lines changed

1 file changed

+51
-75
lines changed

python_bindings/bindings.cpp

Lines changed: 51 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
4242
while (true) {
4343
size_t id = current.fetch_add(1);
4444

45-
if ((id >= end)) {
45+
if (id >= end) {
4646
break;
4747
}
4848

@@ -79,6 +79,45 @@ inline void assert_true(bool expr, const std::string & msg) {
7979
}
8080

8181

82+
/// Derives the (rows, features) layout of an input buffer.
///
/// A 2-d buffer yields rows = shape[0] and features = shape[1]; a 1-d buffer
/// is treated as a single row whose length is the feature count.
///
/// @param buffer    buffer description whose `ndim` and `shape` are inspected
/// @param rows      out: number of rows
/// @param features  out: number of elements per row
/// @throws std::runtime_error if the buffer is neither 1-d nor 2-d
inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) {
    switch (buffer.ndim) {
        case 2:
            *rows = buffer.shape[0];
            *features = buffer.shape[1];
            return;
        case 1:
            // A flat array is interpreted as exactly one row.
            *rows = 1;
            *features = buffer.shape[0];
            return;
        default:
            throw std::runtime_error("data must be a 1d/2d array");
    }
}
92+
93+
94+
/// Converts an optional Python object holding ids into a vector of ids.
///
/// Accepted shapes are a 1-d array with exactly `rows` entries, or a 0-d
/// (scalar) array when `rows` is 1.
///
/// @param ids_  Python object holding the ids, or None
/// @param rows  number of data rows the ids must correspond to
/// @return the ids in input order; empty when `ids_` is None
/// @throws std::runtime_error if the shape of the ids does not match `rows`
inline std::vector<size_t> get_input_ids_and_check_shapes(const py::object& ids_, size_t rows) {
    std::vector<size_t> ids;
    if (ids_.is_none()) {
        return ids;
    }

    py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
    auto ids_numpy = items.request();

    // check shapes
    const bool one_id_per_row = (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows);
    const bool single_scalar_id = (ids_numpy.ndim == 0 && rows == 1);
    if (!one_id_per_row && !single_scalar_id) {
        throw std::runtime_error("wrong dimensionality of the labels");
    }

    // extract data
    if (one_id_per_row) {
        const size_t* first = items.data();
        ids.assign(first, first + ids_numpy.shape[0]);
    } else {
        ids.push_back(*items.data());
    }

    return ids;
}
119+
120+
82121
template<typename dist_t, typename data_t = float>
83122
class Index {
84123
public:
@@ -146,7 +185,7 @@ class Index {
146185
void set_ef(size_t ef) {
147186
default_ef = ef;
148187
if (appr_alg)
149-
appr_alg->ef_ = ef;
188+
appr_alg->ef_ = ef;
150189
}
151190

152191

@@ -188,15 +227,7 @@ class Index {
188227
num_threads = num_threads_default;
189228

190229
size_t rows, features;
191-
192-
if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
193-
if (buffer.ndim == 2) {
194-
rows = buffer.shape[0];
195-
features = buffer.shape[1];
196-
} else {
197-
rows = 1;
198-
features = buffer.shape[0];
199-
}
230+
get_input_array_shapes(buffer, &rows, &features);
200231

201232
if (features != dim)
202233
throw std::runtime_error("wrong dimensionality of the vectors");
@@ -206,23 +237,7 @@ class Index {
206237
num_threads = 1;
207238
}
208239

209-
std::vector<size_t> ids;
210-
211-
if (!ids_.is_none()) {
212-
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
213-
auto ids_numpy = items.request();
214-
if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) {
215-
std::vector<size_t> ids1(ids_numpy.shape[0]);
216-
for (size_t i = 0; i < ids1.size(); i++) {
217-
ids1[i] = items.data()[i];
218-
}
219-
ids.swap(ids1);
220-
} else if (ids_numpy.ndim == 0 && rows == 1) {
221-
ids.push_back(*items.data());
222-
} else {
223-
throw std::runtime_error("wrong dimensionality of the labels");
224-
}
225-
}
240+
std::vector<size_t> ids = get_input_ids_and_check_shapes(ids_, rows);
226241

227242
{
228243
int start = 0;
@@ -561,15 +576,7 @@ class Index {
561576

562577
{
563578
py::gil_scoped_release l;
564-
565-
if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
566-
if (buffer.ndim == 2) {
567-
rows = buffer.shape[0];
568-
features = buffer.shape[1];
569-
} else {
570-
rows = 1;
571-
features = buffer.shape[0];
572-
}
579+
get_input_array_shapes(buffer, &rows, &features);
573580

574581
// avoid using threads when the number of searches is small:
575582
if (rows <= num_threads * 4) {
@@ -725,36 +732,12 @@ class BFIndex {
725732
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
726733
auto buffer = items.request();
727734
size_t rows, features;
728-
729-
if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
730-
if (buffer.ndim == 2) {
731-
rows = buffer.shape[0];
732-
features = buffer.shape[1];
733-
} else {
734-
rows = 1;
735-
features = buffer.shape[0];
736-
}
735+
get_input_array_shapes(buffer, &rows, &features);
737736

738737
if (features != dim)
739738
throw std::runtime_error("wrong dimensionality of the vectors");
740739

741-
std::vector<size_t> ids;
742-
743-
if (!ids_.is_none()) {
744-
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
745-
auto ids_numpy = items.request();
746-
if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) {
747-
std::vector<size_t> ids1(ids_numpy.shape[0]);
748-
for (size_t i = 0; i < ids1.size(); i++) {
749-
ids1[i] = items.data()[i];
750-
}
751-
ids.swap(ids1);
752-
} else if (ids_numpy.ndim == 0 && rows == 1) {
753-
ids.push_back(*items.data());
754-
} else {
755-
throw std::runtime_error("wrong dimensionality of the labels");
756-
}
757-
}
740+
std::vector<size_t> ids = get_input_ids_and_check_shapes(ids_, rows);
758741

759742
{
760743
for (size_t row = 0; row < rows; row++) {
@@ -802,14 +785,7 @@ class BFIndex {
802785
{
803786
py::gil_scoped_release l;
804787

805-
if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
806-
if (buffer.ndim == 2) {
807-
rows = buffer.shape[0];
808-
features = buffer.shape[1];
809-
} else {
810-
rows = 1;
811-
features = buffer.shape[0];
812-
}
788+
get_input_array_shapes(buffer, &rows, &features);
813789

814790
data_numpy_l = new hnswlib::labeltype[rows * k];
815791
data_numpy_d = new dist_t[rows * k];
@@ -836,14 +812,14 @@ class BFIndex {
836812

837813
return py::make_tuple(
838814
py::array_t<hnswlib::labeltype>(
839-
{rows, k}, // shape
840-
{k * sizeof(hnswlib::labeltype),
815+
{ rows, k }, // shape
816+
{ k * sizeof(hnswlib::labeltype),
841817
sizeof(hnswlib::labeltype)}, // C-style contiguous strides for each index
842818
data_numpy_l, // the data pointer
843819
free_when_done_l),
844820
py::array_t<dist_t>(
845-
{rows, k}, // shape
846-
{k * sizeof(dist_t), sizeof(dist_t)}, // C-style contiguous strides for each index
821+
{ rows, k }, // shape
822+
{ k * sizeof(dist_t), sizeof(dist_t) }, // C-style contiguous strides for each index
847823
data_numpy_d, // the data pointer
848824
free_when_done_d));
849825
}

0 commit comments

Comments
 (0)