Skip to content

Commit a3e399a

Browse files
authored
Merge pull request #409 from dyashuni/cpp_windows_tests
Add cpp tests for Windows in CI
2 parents 6fa8cd0 + e8b3e44 commit a3e399a

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

.github/workflows/build.yml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ jobs:
2222
run: python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py"
2323

2424
test_cpp:
25-
runs-on: ubuntu-latest
25+
runs-on: ${{matrix.os}}
26+
strategy:
27+
matrix:
28+
os: [ubuntu-latest, windows-latest]
2629
steps:
2730
- uses: actions/checkout@v3
2831
- uses: actions/setup-python@v4
@@ -34,17 +37,27 @@ jobs:
3437
mkdir build
3538
cd build
3639
cmake ..
37-
make
40+
if [ "$RUNNER_OS" == "Linux" ]; then
41+
make
42+
elif [ "$RUNNER_OS" == "Windows" ]; then
43+
cmake --build ./ --config Release
44+
fi
45+
shell: bash
3846

3947
- name: Prepare test data
4048
run: |
4149
pip install numpy
4250
cd examples
4351
python update_gen_data.py
52+
shell: bash
4453

4554
- name: Test
4655
run: |
4756
cd build
57+
if [ "$RUNNER_OS" == "Windows" ]; then
58+
cp ./Release/* ./
59+
fi
4860
./searchKnnCloserFirst_test
4961
./test_updates
5062
./test_updates update
63+
shell: bash

examples/updates_test.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "../hnswlib/hnswlib.h"
22
#include <thread>
3+
4+
35
class StopW
46
{
57
std::chrono::steady_clock::time_point time_begin;
@@ -22,6 +24,7 @@ class StopW
2224
}
2325
};
2426

27+
2528
/*
2629
* replacement for the openmp '#pragma omp parallel for' directive
2730
* only handles a subset of functionality (no reductions etc)
@@ -81,8 +84,6 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
8184
std::rethrow_exception(lastException);
8285
}
8386
}
84-
85-
8687
}
8788

8889

@@ -94,7 +95,7 @@ std::vector<datatype> load_batch(std::string path, int size)
9495
assert(sizeof(datatype) == 4);
9596

9697
std::ifstream file;
97-
file.open(path);
98+
file.open(path, std::ios::binary);
9899
if (!file.is_open())
99100
{
100101
std::cout << "Cannot open " << path << "\n";
@@ -107,6 +108,7 @@ std::vector<datatype> load_batch(std::string path, int size)
107108
return batch;
108109
}
109110

111+
110112
template <typename d_type>
111113
static float
112114
test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
@@ -137,6 +139,7 @@ test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<
137139
return 1.0f * correct / total;
138140
}
139141

142+
140143
static void
141144
test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<float> &appr_alg, size_t vecdim,
142145
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
@@ -155,6 +158,8 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
155158
efs.push_back(i);
156159
}
157160
std::cout << "ef\trecall\ttime\thops\tdistcomp\n";
161+
162+
bool test_passed = false;
158163
for (size_t ef : efs)
159164
{
160165
appr_alg.setEf(ef);
@@ -171,20 +176,24 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
171176
std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
172177
if (recall > 0.99)
173178
{
179+
test_passed = true;
174180
std::cout << "Recall is over 0.99! "<<recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
175181
break;
176182
}
177183
}
184+
if (!test_passed)
185+
{
186+
std::cerr << "Test failed\n";
187+
exit(1);
188+
}
178189
}
179190

191+
180192
int main(int argc, char **argv)
181193
{
182-
183194
int M = 16;
184195
int efConstruction = 200;
185196
int num_threads = std::thread::hardware_concurrency();
186-
187-
188197

189198
bool update = false;
190199

@@ -207,7 +216,6 @@ int main(int argc, char **argv)
207216

208217
std::string path = "../examples/data/";
209218

210-
211219
int N;
212220
int dummy_data_multiplier;
213221
int N_queries;
@@ -240,7 +248,6 @@ int main(int argc, char **argv)
240248
if (update)
241249
{
242250
std::cout << "Update iteration 0\n";
243-
244251

245252
ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
246253
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
@@ -295,4 +302,4 @@ int main(int argc, char **argv)
295302
}
296303

297304
return 0;
298-
};
305+
};

0 commit comments

Comments
 (0)