Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ hnswlib.cpython*.so
var/
.idea/
.vscode/

.vs/
10 changes: 4 additions & 6 deletions examples/searchKnnCloserFirst_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
#include <vector>
#include <iostream>

namespace
{
namespace {

using idx_t = hnswlib::labeltype;

Expand All @@ -20,7 +19,7 @@ void test() {
idx_t n = 100;
idx_t nq = 10;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);

Expand All @@ -34,7 +33,6 @@ void test() {
for (idx_t i = 0; i < nq * d; ++i) {
query[i] = distrib(rng);
}


hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
Expand Down Expand Up @@ -68,12 +66,12 @@ void test() {
gd.pop();
}
}

delete alg_brute;
delete alg_hnsw;
}

} // namespace
} // namespace

int main() {
std::cout << "Testing ..." << std::endl;
Expand Down
19 changes: 9 additions & 10 deletions examples/searchKnnWithFilter_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
#include <vector>
#include <iostream>

namespace
{
namespace {

using idx_t = hnswlib::labeltype;

Expand All @@ -30,7 +29,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
idx_t n = 100;
idx_t nq = 10;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);

Expand All @@ -46,8 +45,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
}

hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float,filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float,filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float, filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float, filter_func_t>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
Expand Down Expand Up @@ -82,7 +81,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
gd.pop();
}
}

delete alg_brute;
delete alg_hnsw;
}
Expand All @@ -109,8 +108,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
}

hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float,filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float,filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float, filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float, filter_func_t>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
Expand Down Expand Up @@ -140,12 +139,12 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
delete alg_hnsw;
}

} // namespace
} // namespace

class CustomFilterFunctor: public hnswlib::FilterFunctor {
std::unordered_set<unsigned int> allowed_values;

public:
public:
explicit CustomFilterFunctor(const std::unordered_set<unsigned int>& values) : allowed_values(values) {}

bool operator()(unsigned int id) {
Expand Down
125 changes: 49 additions & 76 deletions examples/updates_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,20 @@
#include <thread>


class StopW
{
class StopW {
std::chrono::steady_clock::time_point time_begin;

public:
StopW()
{
public:
StopW() {
time_begin = std::chrono::steady_clock::now();
}

float getElapsedTimeMicro()
{
float getElapsedTimeMicro() {
std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now();
return (std::chrono::duration_cast<std::chrono::microseconds>(time_end - time_begin).count());
}

void reset()
{
void reset() {
time_begin = std::chrono::steady_clock::now();
}
};
Expand Down Expand Up @@ -88,16 +84,14 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn


template <typename datatype>
std::vector<datatype> load_batch(std::string path, int size)
{
std::vector<datatype> load_batch(std::string path, int size) {
std::cout << "Loading " << path << "...";
// float or int32 (python)
assert(sizeof(datatype) == 4);

std::ifstream file;
file.open(path, std::ios::binary);
if (!file.is_open())
{
if (!file.is_open()) {
std::cout << "Cannot open " << path << "\n";
exit(1);
}
Expand All @@ -112,26 +106,17 @@ std::vector<datatype> load_batch(std::string path, int size)
template <typename d_type>
static float
test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K)
{
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K) {
size_t correct = 0;
size_t total = 0;
//uncomment to test in parallel mode:


for (int i = 0; i < qsize; i++)
{

for (int i = 0; i < qsize; i++) {
std::priority_queue<std::pair<d_type, hnswlib::labeltype>> result = appr_alg.searchKnn((char *)(queries.data() + vecdim * i), K);
total += K;
while (result.size())
{
if (answers[i].find(result.top().second) != answers[i].end())
{
while (result.size()) {
if (answers[i].find(result.top().second) != answers[i].end()) {
correct++;
}
else
{
} else {
}
result.pop();
}
Expand All @@ -141,76 +126,70 @@ test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<


static void
test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<float> &appr_alg, size_t vecdim,
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
{
test_vs_recall(
std::vector<float> &queries,
size_t qsize,
hnswlib::HierarchicalNSW<float> &appr_alg,
size_t vecdim,
std::vector<std::unordered_set<hnswlib::labeltype>> &answers,
size_t k) {

std::vector<size_t> efs = {1};
for (int i = k; i < 30; i++)
{
for (int i = k; i < 30; i++) {
efs.push_back(i);
}
for (int i = 30; i < 400; i+=10)
{
for (int i = 30; i < 400; i+=10) {
efs.push_back(i);
}
for (int i = 1000; i < 100000; i += 5000)
{
for (int i = 1000; i < 100000; i += 5000) {
efs.push_back(i);
}
std::cout << "ef\trecall\ttime\thops\tdistcomp\n";

bool test_passed = false;
for (size_t ef : efs)
{
for (size_t ef : efs) {
appr_alg.setEf(ef);

appr_alg.metric_hops=0;
appr_alg.metric_distance_computations=0;
appr_alg.metric_hops = 0;
appr_alg.metric_distance_computations = 0;
StopW stopw = StopW();

float recall = test_approx<float>(queries, qsize, appr_alg, vecdim, answers, k);
float time_us_per_query = stopw.getElapsedTimeMicro() / qsize;
float distance_comp_per_query = appr_alg.metric_distance_computations / (1.0f * qsize);
float hops_per_query = appr_alg.metric_hops / (1.0f * qsize);

std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
if (recall > 0.99)
{
std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t" << hops_per_query << "\t" << distance_comp_per_query << "\n";
if (recall > 0.99) {
test_passed = true;
std::cout << "Recall is over 0.99! "<<recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
std::cout << "Recall is over 0.99! " << recall << "\t" << time_us_per_query << "us \t" << hops_per_query << "\t" << distance_comp_per_query << "\n";
break;
}
}
if (!test_passed)
{
if (!test_passed) {
std::cerr << "Test failed\n";
exit(1);
}
}


int main(int argc, char **argv)
{
int main(int argc, char **argv) {
int M = 16;
int efConstruction = 200;
int num_threads = std::thread::hardware_concurrency();

bool update = false;

if (argc == 2)
{
if (std::string(argv[1]) == "update")
{
if (argc == 2) {
if (std::string(argv[1]) == "update") {
update = true;
std::cout << "Updates are on\n";
}
else {
std::cout<<"Usage ./test_updates [update]\n";
} else {
std::cout << "Usage ./test_updates [update]\n";
exit(1);
}
}
else if (argc>2){
std::cout<<"Usage ./test_updates [update]\n";
} else if (argc > 2) {
std::cout << "Usage ./test_updates [update]\n";
exit(1);
}

Expand All @@ -224,8 +203,7 @@ int main(int argc, char **argv)
{
std::ifstream configfile;
configfile.open(path + "/config.txt");
if (!configfile.is_open())
{
if (!configfile.is_open()) {
std::cout << "Cannot open config.txt\n";
return 1;
}
Expand All @@ -245,10 +223,9 @@ int main(int argc, char **argv)

StopW stopw = StopW();

if (update)
{
if (update) {
std::cout << "Update iteration 0\n";

ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
});
Expand All @@ -259,14 +236,13 @@ int main(int argc, char **argv)
});
appr_alg.checkIntegrity();

for (int b = 1; b < dummy_data_multiplier; b++)
{
for (int b = 1; b < dummy_data_multiplier; b++) {
std::cout << "Update iteration " << b << "\n";
char cpath[1024];
sprintf(cpath, "batch_dummy_%02d.bin", b);
std::vector<float> dummy_batchb = load_batch<float>(path + cpath, N * d);
ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {

ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
});
appr_alg.checkIntegrity();
Expand All @@ -275,31 +251,28 @@ int main(int argc, char **argv)

std::cout << "Inserting final elements\n";
std::vector<float> final_batch = load_batch<float>(path + "batch_final.bin", N * d);

stopw.reset();
ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
appr_alg.addPoint((void *)(final_batch.data() + i * d), i);
});
std::cout<<"Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n";
std::cout << "Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n";
std::cout << "Running tests\n";
std::vector<float> queries_batch = load_batch<float>(path + "queries.bin", N_queries * d);

std::vector<int> gt = load_batch<int>(path + "gt.bin", N_queries * K);

std::vector<std::unordered_set<hnswlib::labeltype>> answers(N_queries);
for (int i = 0; i < N_queries; i++)
{
for (int j = 0; j < K; j++)
{
for (int i = 0; i < N_queries; i++) {
for (int j = 0; j < K; j++) {
answers[i].insert(gt[i * K + j]);
}
}

for (int i = 0; i < 3; i++)
{
for (int i = 0; i < 3; i++) {
std::cout << "Test iteration " << i << "\n";
test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K);
}

return 0;
};
}
Loading