@@ -42,7 +42,7 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
42
42
while (true ) {
43
43
size_t id = current.fetch_add (1 );
44
44
45
- if (( id >= end) ) {
45
+ if (id >= end) {
46
46
break ;
47
47
}
48
48
@@ -79,6 +79,45 @@ inline void assert_true(bool expr, const std::string & msg) {
79
79
}
80
80
81
81
82
+ inline void get_input_array_shapes (const py::buffer_info& buffer, size_t * rows, size_t * features) {
83
+ if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
84
+ if (buffer.ndim == 2 ) {
85
+ *rows = buffer.shape [0 ];
86
+ *features = buffer.shape [1 ];
87
+ } else {
88
+ *rows = 1 ;
89
+ *features = buffer.shape [0 ];
90
+ }
91
+ }
92
+
93
+
94
+ inline std::vector<size_t > get_input_ids_and_check_shapes (const py::object& ids_, size_t rows) {
95
+ std::vector<size_t > ids;
96
+ if (!ids_.is_none ()) {
97
+ py::array_t < size_t , py::array::c_style | py::array::forcecast > items (ids_);
98
+ auto ids_numpy = items.request ();
99
+ // check shapes
100
+ bool valid = false ;
101
+ if ((ids_numpy.ndim == 1 && ids_numpy.shape [0 ] == rows) || (ids_numpy.ndim == 0 && rows == 1 )) {
102
+ valid = true ;
103
+ }
104
+ if (!valid) throw std::runtime_error (" wrong dimensionality of the labels" );
105
+ // extract data
106
+ if (ids_numpy.ndim == 1 ) {
107
+ std::vector<size_t > ids1 (ids_numpy.shape [0 ]);
108
+ for (size_t i = 0 ; i < ids1.size (); i++) {
109
+ ids1[i] = items.data ()[i];
110
+ }
111
+ ids.swap (ids1);
112
+ } else if (ids_numpy.ndim == 0 ) {
113
+ ids.push_back (*items.data ());
114
+ }
115
+ }
116
+
117
+ return ids;
118
+ }
119
+
120
+
82
121
template <typename dist_t , typename data_t = float >
83
122
class Index {
84
123
public:
@@ -146,7 +185,7 @@ class Index {
146
185
void set_ef (size_t ef) {
147
186
default_ef = ef;
148
187
if (appr_alg)
149
- appr_alg->ef_ = ef;
188
+ appr_alg->ef_ = ef;
150
189
}
151
190
152
191
@@ -188,15 +227,7 @@ class Index {
188
227
num_threads = num_threads_default;
189
228
190
229
size_t rows, features;
191
-
192
- if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
193
- if (buffer.ndim == 2 ) {
194
- rows = buffer.shape [0 ];
195
- features = buffer.shape [1 ];
196
- } else {
197
- rows = 1 ;
198
- features = buffer.shape [0 ];
199
- }
230
+ get_input_array_shapes (buffer, &rows, &features);
200
231
201
232
if (features != dim)
202
233
throw std::runtime_error (" wrong dimensionality of the vectors" );
@@ -206,23 +237,7 @@ class Index {
206
237
num_threads = 1 ;
207
238
}
208
239
209
- std::vector<size_t > ids;
210
-
211
- if (!ids_.is_none ()) {
212
- py::array_t < size_t , py::array::c_style | py::array::forcecast > items (ids_);
213
- auto ids_numpy = items.request ();
214
- if (ids_numpy.ndim == 1 && ids_numpy.shape [0 ] == rows) {
215
- std::vector<size_t > ids1 (ids_numpy.shape [0 ]);
216
- for (size_t i = 0 ; i < ids1.size (); i++) {
217
- ids1[i] = items.data ()[i];
218
- }
219
- ids.swap (ids1);
220
- } else if (ids_numpy.ndim == 0 && rows == 1 ) {
221
- ids.push_back (*items.data ());
222
- } else {
223
- throw std::runtime_error (" wrong dimensionality of the labels" );
224
- }
225
- }
240
+ std::vector<size_t > ids = get_input_ids_and_check_shapes (ids_, rows);
226
241
227
242
{
228
243
int start = 0 ;
@@ -561,15 +576,7 @@ class Index {
561
576
562
577
{
563
578
py::gil_scoped_release l;
564
-
565
- if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
566
- if (buffer.ndim == 2 ) {
567
- rows = buffer.shape [0 ];
568
- features = buffer.shape [1 ];
569
- } else {
570
- rows = 1 ;
571
- features = buffer.shape [0 ];
572
- }
579
+ get_input_array_shapes (buffer, &rows, &features);
573
580
574
581
// avoid using threads when the number of searches is small:
575
582
if (rows <= num_threads * 4 ) {
@@ -725,36 +732,12 @@ class BFIndex {
725
732
py::array_t < dist_t , py::array::c_style | py::array::forcecast > items (input);
726
733
auto buffer = items.request ();
727
734
size_t rows, features;
728
-
729
- if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
730
- if (buffer.ndim == 2 ) {
731
- rows = buffer.shape [0 ];
732
- features = buffer.shape [1 ];
733
- } else {
734
- rows = 1 ;
735
- features = buffer.shape [0 ];
736
- }
735
+ get_input_array_shapes (buffer, &rows, &features);
737
736
738
737
if (features != dim)
739
738
throw std::runtime_error (" wrong dimensionality of the vectors" );
740
739
741
- std::vector<size_t > ids;
742
-
743
- if (!ids_.is_none ()) {
744
- py::array_t < size_t , py::array::c_style | py::array::forcecast > items (ids_);
745
- auto ids_numpy = items.request ();
746
- if (ids_numpy.ndim == 1 && ids_numpy.shape [0 ] == rows) {
747
- std::vector<size_t > ids1 (ids_numpy.shape [0 ]);
748
- for (size_t i = 0 ; i < ids1.size (); i++) {
749
- ids1[i] = items.data ()[i];
750
- }
751
- ids.swap (ids1);
752
- } else if (ids_numpy.ndim == 0 && rows == 1 ) {
753
- ids.push_back (*items.data ());
754
- } else {
755
- throw std::runtime_error (" wrong dimensionality of the labels" );
756
- }
757
- }
740
+ std::vector<size_t > ids = get_input_ids_and_check_shapes (ids_, rows);
758
741
759
742
{
760
743
for (size_t row = 0 ; row < rows; row++) {
@@ -802,14 +785,7 @@ class BFIndex {
802
785
{
803
786
py::gil_scoped_release l;
804
787
805
- if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
806
- if (buffer.ndim == 2 ) {
807
- rows = buffer.shape [0 ];
808
- features = buffer.shape [1 ];
809
- } else {
810
- rows = 1 ;
811
- features = buffer.shape [0 ];
812
- }
788
+ get_input_array_shapes (buffer, &rows, &features);
813
789
814
790
data_numpy_l = new hnswlib::labeltype[rows * k];
815
791
data_numpy_d = new dist_t [rows * k];
@@ -836,14 +812,14 @@ class BFIndex {
836
812
837
813
return py::make_tuple (
838
814
py::array_t <hnswlib::labeltype>(
839
- {rows, k}, // shape
840
- {k * sizeof (hnswlib::labeltype),
815
+ { rows, k }, // shape
816
+ { k * sizeof (hnswlib::labeltype),
841
817
sizeof (hnswlib::labeltype)}, // C-style contiguous strides for each index
842
818
data_numpy_l, // the data pointer
843
819
free_when_done_l),
844
820
py::array_t <dist_t >(
845
- {rows, k}, // shape
846
- {k * sizeof (dist_t ), sizeof (dist_t )}, // C-style contiguous strides for each index
821
+ { rows, k }, // shape
822
+ { k * sizeof (dist_t ), sizeof (dist_t ) }, // C-style contiguous strides for each index
847
823
data_numpy_d, // the data pointer
848
824
free_when_done_d));
849
825
}
0 commit comments