From 765c4ab4ba00e2e8a54b349c5df1f028b08953ed Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Mon, 15 Aug 2022 15:55:57 +0530 Subject: [PATCH 01/11] Filter elements with an optional filtering function. --- CMakeLists.txt | 3 + examples/searchKnnWithFilter_test.cpp | 95 +++++++++++++++++++++++++++ hnswlib/bruteforce.h | 18 +++-- hnswlib/hnswalg.h | 16 ++--- hnswlib/hnswlib.h | 20 ++++-- 5 files changed, 131 insertions(+), 21 deletions(-) create mode 100644 examples/searchKnnWithFilter_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e2f3d716..e42d6cee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,9 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp) target_link_libraries(searchKnnCloserFirst_test hnswlib) + add_executable(searchKnnWithFilter_test examples/searchKnnWithFilter_test.cpp) + target_link_libraries(searchKnnWithFilter_test hnswlib) + add_executable(main main.cpp sift_1b.cpp) target_link_libraries(main hnswlib) endif() diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp new file mode 100644 index 00000000..290054d3 --- /dev/null +++ b/examples/searchKnnWithFilter_test.cpp @@ -0,0 +1,95 @@ +// This is a test file for testing the filtering feature + +#include "../hnswlib/hnswlib.h" + +#include + +#include +#include + +namespace +{ + +using idx_t = hnswlib::labeltype; + +bool pickIdsDivisibleByThree(unsigned int ep_id) { + return ep_id % 3 == 0; +} + +bool pickIdsDivisibleBySeven(unsigned int ep_id) { + return ep_id % 7 == 0; +} + +template +void test(filter_func_t filter_func, size_t div_num) { + int d = 4; + idx_t n = 100; + idx_t nq = 10; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; ++i) { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) { + query[i] = distrib(rng); + } + + + hnswlib::L2Space space(d); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + + for (size_t i = 0; i < n; ++i) { + alg_brute->addPoint(data.data() + d * i, i); + alg_hnsw->addPoint(data.data() + d * i, i); + } + + // test searchKnnCloserFirst of BruteforceSearch with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_brute->searchKnn(p, k, filter_func); + auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); + assert(gd.size() == res.size()); + size_t t = gd.size(); + while (!gd.empty()) { + assert(gd.top() == res[--t]); + assert((gd.top().second % div_num) == 0); + gd.pop(); + } + } + + // test searchKnnCloserFirst of hnsw with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_hnsw->searchKnn(p, k, filter_func); + auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); + assert(gd.size() == res.size()); + size_t t = gd.size(); + while (!gd.empty()) { + assert(gd.top() == res[--t]); + assert((gd.top().second % div_num) == 0); + gd.pop(); + } + } + + delete alg_brute; + delete alg_hnsw; +} + +} // namespace + +int main() { + std::cout << "Testing ..." << std::endl; + test(pickIdsDivisibleByThree, 3); + test(pickIdsDivisibleBySeven, 7); + std::cout << "Test ok" << std::endl; + + return 0; +} diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index f8e0aeb3..c16c19a0 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -5,8 +5,8 @@ #include namespace hnswlib { - template - class BruteforceSearch : public AlgorithmInterface { + template + class BruteforceSearch : public AlgorithmInterface { public: BruteforceSearch(SpaceInterface *s) : data_(nullptr), maxelements_(0), cur_element_count(0), size_per_element_(0), data_size_(0), @@ -92,20 +92,24 @@ namespace hnswlib { std::priority_queue> - searchKnn(const void *query_data, size_t k) const { + searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { std::priority_queue> topResults; if (cur_element_count == 0) return topResults; for (int i = 0; i < k; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + - data_size_)))); + labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); + if(isIdAllowed(label)) { + topResults.push(std::pair(dist, label)); + } } dist_t lastdist = topResults.top().first; for (int i = k; i < cur_element_count; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); if (dist <= lastdist) { - topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + - data_size_)))); + labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); + if(isIdAllowed(label)) { + topResults.push(std::pair(dist, label)); + } if (topResults.size() > k) topResults.pop(); lastdist = topResults.top().first; diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 8060683c..7ca41f9a 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -13,8 +13,8 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; - template - class HierarchicalNSW : public AlgorithmInterface { + template + class HierarchicalNSW : public AlgorithmInterface { public: static const tableint max_update_element_locks = 65536; HierarchicalNSW(SpaceInterface *s) { @@ -238,7 +238,7 @@ namespace hnswlib { template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t isIdAllowed) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -247,7 +247,7 @@ namespace hnswlib { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if (!has_deletions || !isMarkedDeleted(ep_id)) { + if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(ep_id)) { dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); @@ -307,7 +307,7 @@ namespace hnswlib { _MM_HINT_T0);//////////////////////// #endif - if (!has_deletions || !isMarkedDeleted(candidate_id)) + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id)) top_candidates.emplace(dist, candidate_id); if (top_candidates.size() > ef) @@ -1111,7 +1111,7 @@ namespace hnswlib { }; std::priority_queue> - searchKnn(const void *query_data, size_t k) const { + searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { std::priority_queue> result; if (cur_element_count == 0) return result; @@ -1148,11 +1148,11 @@ namespace hnswlib { std::priority_queue, std::vector>, CompareByFirst> top_candidates; if (num_deleted_) { top_candidates=searchBaseLayerST( - currObj, query_data, std::max(ef_, k)); + currObj, query_data, std::max(ef_, k), isIdAllowed); } else{ top_candidates=searchBaseLayerST( - currObj, query_data, std::max(ef_, k)); + currObj, query_data, std::max(ef_, k), isIdAllowed); } while (top_candidates.size() > k) { diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 61029e90..fc48af29 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -116,6 +116,10 @@ static bool AVX512Capable() { namespace hnswlib { typedef size_t labeltype; + bool allowAllIds(unsigned int ep_id) { + return true; + } + template class pairGreater { public: @@ -137,6 +141,7 @@ namespace hnswlib { template using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); + using FILTERFUNC = bool(*)(unsigned int); template class SpaceInterface { @@ -151,28 +156,31 @@ namespace hnswlib { virtual ~SpaceInterface() {} }; - template + template class AlgorithmInterface { public: virtual void addPoint(const void *datapoint, labeltype label)=0; - virtual std::priority_queue> searchKnn(const void *, size_t) const = 0; + + virtual std::priority_queue> + searchKnn(const void*, size_t, filter_func_t isIdAllowed=allowAllIds) const = 0; // Return k nearest neighbor in the order of closer fist virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k) const; + searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const; virtual void saveIndex(const std::string &location)=0; virtual ~AlgorithmInterface(){ } }; - template + template std::vector> - AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k) const { + AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, + filter_func_t isIdAllowed) const { std::vector> result; // here searchKnn returns the result in the order of further first - auto ret = searchKnn(query_data, k); + auto ret = searchKnn(query_data, k, isIdAllowed); { size_t sz = ret.size(); result.resize(sz); From ad3440c83555d9a76eef0e23bc6505c86b026716 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 19 Aug 2022 09:02:11 +0530 Subject: [PATCH 02/11] Filter function should be sent the label and not the internal ID. --- examples/searchKnnWithFilter_test.cpp | 12 ++++++------ hnswlib/hnswalg.h | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index 290054d3..71a055dd 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -21,7 +21,7 @@ bool pickIdsDivisibleBySeven(unsigned int ep_id) { } template -void test(filter_func_t filter_func, size_t div_num) { +void test(filter_func_t filter_func, size_t div_num, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -40,15 +40,15 @@ void test(filter_func_t filter_func, size_t div_num) { for (idx_t i = 0; i < nq * d; ++i) { query[i] = distrib(rng); } - hnswlib::L2Space space(d); hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { - alg_brute->addPoint(data.data() + d * i, i); - alg_hnsw->addPoint(data.data() + d * i, i); + // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs + alg_brute->addPoint(data.data() + d * i, label_id_start + i); + alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); } // test searchKnnCloserFirst of BruteforceSearch with filtering @@ -87,8 +87,8 @@ void test(filter_func_t filter_func, size_t div_num) { int main() { std::cout << "Testing ..." << std::endl; - test(pickIdsDivisibleByThree, 3); - test(pickIdsDivisibleBySeven, 7); + test(pickIdsDivisibleByThree, 3, 17); + test(pickIdsDivisibleBySeven, 7, 17); std::cout << "Test ok" << std::endl; return 0; diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 7ca41f9a..d319aa7e 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -247,7 +247,7 @@ namespace hnswlib { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(ep_id)) { + if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(getExternalLabel(ep_id))) { dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); @@ -307,7 +307,7 @@ namespace hnswlib { _MM_HINT_T0);//////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id)) + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(getExternalLabel(candidate_id))) top_candidates.emplace(dist, candidate_id); if (top_candidates.size() > ef) From 4f6dcc38e8af8068cff455852e57a833d9ef6a22 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 19 Aug 2022 09:13:25 +0530 Subject: [PATCH 03/11] Ensure that results are not empty when reading from top results. --- examples/searchKnnWithFilter_test.cpp | 77 ++++++++++++++++++++++++--- hnswlib/bruteforce.h | 7 ++- 2 files changed, 75 insertions(+), 9 deletions(-) diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index 71a055dd..b048baf7 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -12,16 +12,20 @@ namespace using idx_t = hnswlib::labeltype; -bool pickIdsDivisibleByThree(unsigned int ep_id) { - return ep_id % 3 == 0; +bool pickIdsDivisibleByThree(unsigned int label_id) { + return label_id % 3 == 0; } -bool pickIdsDivisibleBySeven(unsigned int ep_id) { - return ep_id % 7 == 0; +bool pickIdsDivisibleBySeven(unsigned int label_id) { + return label_id % 7 == 0; +} + +bool pickNothing(unsigned int label_id) { + return false; } template -void test(filter_func_t filter_func, size_t div_num, size_t label_id_start) { +void test_some_filtering(filter_func_t filter_func, size_t div_num, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -83,12 +87,71 @@ void test(filter_func_t filter_func, size_t div_num, size_t label_id_start) { delete alg_hnsw; } +template +void test_none_filtering(filter_func_t filter_func, size_t label_id_start) { + int d = 4; + idx_t n = 100; + idx_t nq = 10; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; ++i) { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) { + query[i] = distrib(rng); + } + + hnswlib::L2Space space(d); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + + for (size_t i = 0; i < n; ++i) { + // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs + alg_brute->addPoint(data.data() + d * i, label_id_start + i); + alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); + } + + // test searchKnnCloserFirst of BruteforceSearch with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_brute->searchKnn(p, k, filter_func); + auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); + assert(gd.size() == res.size()); + assert(0 == gd.size()); + } + + // test searchKnnCloserFirst of hnsw with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_hnsw->searchKnn(p, k, filter_func); + auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); + assert(gd.size() == res.size()); + assert(0 == gd.size()); + } + + delete alg_brute; + delete alg_hnsw; +} + } // namespace int main() { std::cout << "Testing ..." << std::endl; - test(pickIdsDivisibleByThree, 3, 17); - test(pickIdsDivisibleBySeven, 7, 17); + + // some of the elements are filtered + test_some_filtering(pickIdsDivisibleByThree, 3, 17); + test_some_filtering(pickIdsDivisibleBySeven, 7, 17); + + // all of the elements are filtered + test_none_filtering(pickNothing, 17); + std::cout << "Test ok" << std::endl; return 0; diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index c16c19a0..3de18eeb 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -102,7 +102,7 @@ namespace hnswlib { topResults.push(std::pair(dist, label)); } } - dist_t lastdist = topResults.top().first; + dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; for (int i = k; i < cur_element_count; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); if (dist <= lastdist) { @@ -112,7 +112,10 @@ namespace hnswlib { } if (topResults.size() > k) topResults.pop(); - lastdist = topResults.top().first; + + if (!topResults.empty()) { + lastdist = topResults.top().first; + } } } From 1c833a73f504ab383bb7c31a036b71bf5e53a861 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 25 Aug 2022 15:56:08 +0530 Subject: [PATCH 04/11] Make allowAllIds static. --- hnswlib/hnswlib.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index fc48af29..0b6f84a6 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -116,7 +116,7 @@ static bool AVX512Capable() { namespace hnswlib { typedef size_t labeltype; - bool allowAllIds(unsigned int ep_id) { + static bool allowAllIds(unsigned int ep_id) { return true; } From aaee13a931c9b4320cfc88e5ab549199d405f787 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 26 Aug 2022 16:12:26 +0530 Subject: [PATCH 05/11] Use functor for filtering. --- examples/searchKnnWithFilter_test.cpp | 27 +++++++++++++++++++++------ hnswlib/bruteforce.h | 4 ++-- hnswlib/hnswalg.h | 6 +++--- hnswlib/hnswlib.h | 20 +++++++++++--------- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index b048baf7..9219be03 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -25,7 +25,7 @@ bool pickNothing(unsigned int label_id) { } template -void test_some_filtering(filter_func_t filter_func, size_t div_num, size_t label_id_start) { +void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -46,8 +46,8 @@ void test_some_filtering(filter_func_t filter_func, size_t div_num, size_t label } hnswlib::L2Space space(d); - hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); - hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs @@ -88,7 +88,7 @@ void test_some_filtering(filter_func_t filter_func, size_t div_num, size_t label } template -void test_none_filtering(filter_func_t filter_func, size_t label_id_start) { +void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -109,8 +109,8 @@ void test_none_filtering(filter_func_t filter_func, size_t label_id_start) { } hnswlib::L2Space space(d); - hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); - hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs @@ -142,6 +142,17 @@ void test_none_filtering(filter_func_t filter_func, size_t label_id_start) { } // namespace +class CustomFilterFunctor: public hnswlib::FilterFunctor { + std::unordered_set allowed_values; + +public: + explicit CustomFilterFunctor(const std::unordered_set& values) : allowed_values(values) {} + + constexpr bool operator()(unsigned int id) const { + return allowed_values.count(id) != 0; + } +}; + int main() { std::cout << "Testing ..." << std::endl; @@ -152,6 +163,10 @@ int main() { // all of the elements are filtered test_none_filtering(pickNothing, 17); + // functor style which can capture context + CustomFilterFunctor pickIdsDivisibleByThirteen({26, 39, 52, 65}); + test_some_filtering(pickIdsDivisibleByThirteen, 13, 21); + std::cout << "Test ok" << std::endl; return 0; diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 3de18eeb..a56f75f1 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -5,7 +5,7 @@ #include namespace hnswlib { - template + template class BruteforceSearch : public AlgorithmInterface { public: BruteforceSearch(SpaceInterface *s) : data_(nullptr), maxelements_(0), @@ -92,7 +92,7 @@ namespace hnswlib { std::priority_queue> - searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { + searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { std::priority_queue> topResults; if (cur_element_count == 0) return topResults; for (int i = 0; i < k; i++) { diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index d319aa7e..23fddcc1 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -13,7 +13,7 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; - template + template class HierarchicalNSW : public AlgorithmInterface { public: static const tableint max_update_element_locks = 65536; @@ -238,7 +238,7 @@ namespace hnswlib { template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t isIdAllowed) const { + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -1111,7 +1111,7 @@ namespace hnswlib { }; std::priority_queue> - searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { + searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { std::priority_queue> result; if (cur_element_count == 0) return result; diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 0b6f84a6..b1c88df5 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -116,9 +116,13 @@ static bool AVX512Capable() { namespace hnswlib { typedef size_t labeltype; - static bool allowAllIds(unsigned int ep_id) { - return true; - } + // This can be extended to store state for filtering (e.g. from a std::set) + struct FilterFunctor { + template + bool operator()(Args&&...) { return true; } + }; + + FilterFunctor allowAllIds; template class pairGreater { @@ -141,8 +145,6 @@ namespace hnswlib { template using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); - using FILTERFUNC = bool(*)(unsigned int); - template class SpaceInterface { public: @@ -156,17 +158,17 @@ namespace hnswlib { virtual ~SpaceInterface() {} }; - template + template class AlgorithmInterface { public: virtual void addPoint(const void *datapoint, labeltype label)=0; virtual std::priority_queue> - searchKnn(const void*, size_t, filter_func_t isIdAllowed=allowAllIds) const = 0; + searchKnn(const void*, size_t, filter_func_t& isIdAllowed=allowAllIds) const = 0; // Return k nearest neighbor in the order of closer fist virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const; + searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const; virtual void saveIndex(const std::string &location)=0; virtual ~AlgorithmInterface(){ @@ -176,7 +178,7 @@ namespace hnswlib { template std::vector> AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, - filter_func_t isIdAllowed) const { + filter_func_t& isIdAllowed) const { std::vector> result; // here searchKnn returns the result in the order of further first From b87f6230dbe59e874b3099cfcab689b42e887a20 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sat, 27 Aug 2022 13:20:23 +0530 Subject: [PATCH 06/11] Explicitly check for filter functor being default. --- hnswlib/hnswalg.h | 6 ++++-- hnswlib/hnswlib.h | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 23fddcc1..d7fd385f 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -247,7 +247,8 @@ namespace hnswlib { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(getExternalLabel(ep_id))) { + bool is_filter_disabled = std::is_same::value; + if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) { dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); @@ -307,7 +308,8 @@ namespace hnswlib { _MM_HINT_T0);//////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(getExternalLabel(candidate_id))) + is_filter_disabled = std::is_same::value; + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id)))) top_candidates.emplace(dist, candidate_id); if (top_candidates.size() > ef) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index b1c88df5..d8997044 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -122,7 +122,7 @@ namespace hnswlib { bool operator()(Args&&...) { return true; } }; - FilterFunctor allowAllIds; + static FilterFunctor allowAllIds; template class pairGreater { From f0dedf3956de1762fa0b0611ac03b4c29d236bf5 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 28 Aug 2022 15:05:41 +0530 Subject: [PATCH 07/11] Remove duplicate assignment. --- hnswlib/hnswalg.h | 1 - 1 file changed, 1 deletion(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index d7fd385f..57fba444 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -308,7 +308,6 @@ namespace hnswlib { _MM_HINT_T0);//////////////////////// #endif - is_filter_disabled = std::is_same::value; if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id)))) top_candidates.emplace(dist, candidate_id); From e4705fd3f09dd56d05278aebb0f6e1a25383e4f5 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 28 Aug 2022 15:11:22 +0530 Subject: [PATCH 08/11] Add search with filter test to CI. --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f8fde085..5e0c1f9d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -46,5 +46,6 @@ jobs: run: | cd build ./searchKnnCloserFirst_test + ./searchKnnWithFilter_test ./test_updates ./test_updates update From 7f419eaaa36c83b22e623a99c7be0d22ce47f4f4 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 28 Aug 2022 18:01:53 +0530 Subject: [PATCH 09/11] Remove constexpr for functor in test. --- examples/searchKnnWithFilter_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index 9219be03..ead0c6fd 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -148,7 +148,7 @@ class CustomFilterFunctor: public hnswlib::FilterFunctor { public: explicit CustomFilterFunctor(const std::unordered_set& values) : allowed_values(values) {} - constexpr bool operator()(unsigned int id) const { + bool operator()(unsigned int id) { return allowed_values.count(id) != 0; } }; From 1fe7baf525b8c07b08e1d048b8e258f4804069b1 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 6 Sep 2022 15:37:09 +0530 Subject: [PATCH 10/11] Add check for is_filter_disabled. --- hnswlib/bruteforce.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index a56f75f1..33130273 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -93,12 +93,14 @@ namespace hnswlib { std::priority_queue> searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { + assert(k <= cur_element_count); std::priority_queue> topResults; if (cur_element_count == 0) return topResults; + bool is_filter_disabled = std::is_same::value; for (int i = 0; i < k; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); - if(isIdAllowed(label)) { + if(is_filter_disabled || isIdAllowed(label)) { topResults.push(std::pair(dist, label)); } } @@ -107,7 +109,7 @@ namespace hnswlib { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); if (dist <= lastdist) { labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); - if(isIdAllowed(label)) { + if(is_filter_disabled || isIdAllowed(label)) { topResults.push(std::pair(dist, label)); } if (topResults.size() > k) From c9897b0f730c48428b587a791cbf901a4550add8 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 6 Sep 2022 16:52:54 +0530 Subject: [PATCH 11/11] Add assert header. --- hnswlib/bruteforce.h | 1 + 1 file changed, 1 insertion(+) diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 33130273..9fe97c09 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace hnswlib { template