1
1
#include " ../hnswlib/hnswlib.h"
2
2
#include < thread>
3
+
4
+
3
5
class StopW
4
6
{
5
7
std::chrono::steady_clock::time_point time_begin;
@@ -22,6 +24,7 @@ class StopW
22
24
}
23
25
};
24
26
27
+
25
28
/*
26
29
* replacement for the openmp '#pragma omp parallel for' directive
27
30
* only handles a subset of functionality (no reductions etc)
@@ -81,8 +84,6 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
81
84
std::rethrow_exception (lastException);
82
85
}
83
86
}
84
-
85
-
86
87
}
87
88
88
89
@@ -94,7 +95,7 @@ std::vector<datatype> load_batch(std::string path, int size)
94
95
assert (sizeof (datatype) == 4 );
95
96
96
97
std::ifstream file;
97
- file.open (path);
98
+ file.open (path, std::ios::binary );
98
99
if (!file.is_open ())
99
100
{
100
101
std::cout << " Cannot open " << path << " \n " ;
@@ -107,6 +108,7 @@ std::vector<datatype> load_batch(std::string path, int size)
107
108
return batch;
108
109
}
109
110
111
+
110
112
template <typename d_type>
111
113
static float
112
114
test_approx (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
@@ -137,6 +139,7 @@ test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<
137
139
return 1 .0f * correct / total;
138
140
}
139
141
142
+
140
143
static void
141
144
test_vs_recall (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<float > &appr_alg, size_t vecdim,
142
145
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
@@ -155,6 +158,8 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
155
158
efs.push_back (i);
156
159
}
157
160
std::cout << " ef\t recall\t time\t hops\t distcomp\n " ;
161
+
162
+ bool test_passed = false ;
158
163
for (size_t ef : efs)
159
164
{
160
165
appr_alg.setEf (ef);
@@ -171,20 +176,24 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
171
176
std::cout << ef << " \t " << recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
172
177
if (recall > 0.99 )
173
178
{
179
+ test_passed = true ;
174
180
std::cout << " Recall is over 0.99! " <<recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
175
181
break ;
176
182
}
177
183
}
184
+ if (!test_passed)
185
+ {
186
+ std::cerr << " Test failed\n " ;
187
+ exit (1 );
188
+ }
178
189
}
179
190
191
+
180
192
int main (int argc, char **argv)
181
193
{
182
-
183
194
int M = 16 ;
184
195
int efConstruction = 200 ;
185
196
int num_threads = std::thread::hardware_concurrency ();
186
-
187
-
188
197
189
198
bool update = false ;
190
199
@@ -207,7 +216,6 @@ int main(int argc, char **argv)
207
216
208
217
std::string path = " ../examples/data/" ;
209
218
210
-
211
219
int N;
212
220
int dummy_data_multiplier;
213
221
int N_queries;
@@ -240,7 +248,6 @@ int main(int argc, char **argv)
240
248
if (update)
241
249
{
242
250
std::cout << " Update iteration 0\n " ;
243
-
244
251
245
252
ParallelFor (1 , N, num_threads, [&](size_t i, size_t threadId) {
246
253
appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
@@ -295,4 +302,4 @@ int main(int argc, char **argv)
295
302
}
296
303
297
304
return 0 ;
298
- };
305
+ };
0 commit comments