-
Notifications
You must be signed in to change notification settings - Fork 739
Filter elements with an optional filtering function #402
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
765c4ab
Filter elements with an optional filtering function.
kishorenc ad3440c
Filter function should be sent the label and not the internal ID.
kishorenc 4f6dcc3
Ensure that results are not empty when reading from top results.
kishorenc 1c833a7
Make allowAllIds static.
kishorenc aaee13a
Use functor for filtering.
kishorenc b87f623
Explicitly check for filter functor being default.
kishorenc f0dedf3
Remove duplicate assignment.
kishorenc de22860
Merge branch 'develop' into filter-elements
kishorenc e4705fd
Add search with filter test to CI.
kishorenc 7f419ea
Remove constexpr for functor in test.
kishorenc 1fe7baf
Add check for is_filter_disabled.
kishorenc c9897b0
Add assert header.
kishorenc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
// This is a test file for testing the filtering feature | ||
|
||
#include "../hnswlib/hnswlib.h" | ||
|
||
#include <assert.h> | ||
|
||
#include <vector> | ||
#include <iostream> | ||
|
||
namespace | ||
{ | ||
|
||
using idx_t = hnswlib::labeltype; | ||
|
||
bool pickIdsDivisibleByThree(unsigned int label_id) { | ||
return label_id % 3 == 0; | ||
} | ||
|
||
bool pickIdsDivisibleBySeven(unsigned int label_id) { | ||
return label_id % 7 == 0; | ||
} | ||
|
||
bool pickNothing(unsigned int label_id) { | ||
return false; | ||
} | ||
|
||
template<typename filter_func_t> | ||
void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t label_id_start) { | ||
int d = 4; | ||
idx_t n = 100; | ||
idx_t nq = 10; | ||
size_t k = 10; | ||
|
||
std::vector<float> data(n * d); | ||
std::vector<float> query(nq * d); | ||
|
||
std::mt19937 rng; | ||
rng.seed(47); | ||
std::uniform_real_distribution<> distrib; | ||
|
||
for (idx_t i = 0; i < n * d; ++i) { | ||
data[i] = distrib(rng); | ||
} | ||
for (idx_t i = 0; i < nq * d; ++i) { | ||
query[i] = distrib(rng); | ||
} | ||
|
||
hnswlib::L2Space space(d); | ||
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float,filter_func_t>(&space, 2 * n); | ||
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float,filter_func_t>(&space, 2 * n); | ||
|
||
for (size_t i = 0; i < n; ++i) { | ||
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs | ||
alg_brute->addPoint(data.data() + d * i, label_id_start + i); | ||
alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); | ||
} | ||
|
||
// test searchKnnCloserFirst of BruteforceSearch with filtering | ||
for (size_t j = 0; j < nq; ++j) { | ||
const void* p = query.data() + j * d; | ||
auto gd = alg_brute->searchKnn(p, k, filter_func); | ||
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); | ||
assert(gd.size() == res.size()); | ||
size_t t = gd.size(); | ||
while (!gd.empty()) { | ||
assert(gd.top() == res[--t]); | ||
assert((gd.top().second % div_num) == 0); | ||
gd.pop(); | ||
} | ||
} | ||
|
||
// test searchKnnCloserFirst of hnsw with filtering | ||
for (size_t j = 0; j < nq; ++j) { | ||
const void* p = query.data() + j * d; | ||
auto gd = alg_hnsw->searchKnn(p, k, filter_func); | ||
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); | ||
assert(gd.size() == res.size()); | ||
size_t t = gd.size(); | ||
while (!gd.empty()) { | ||
assert(gd.top() == res[--t]); | ||
assert((gd.top().second % div_num) == 0); | ||
gd.pop(); | ||
} | ||
} | ||
|
||
delete alg_brute; | ||
delete alg_hnsw; | ||
} | ||
|
||
template<typename filter_func_t> | ||
void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { | ||
int d = 4; | ||
idx_t n = 100; | ||
idx_t nq = 10; | ||
size_t k = 10; | ||
|
||
std::vector<float> data(n * d); | ||
std::vector<float> query(nq * d); | ||
|
||
std::mt19937 rng; | ||
rng.seed(47); | ||
std::uniform_real_distribution<> distrib; | ||
|
||
for (idx_t i = 0; i < n * d; ++i) { | ||
data[i] = distrib(rng); | ||
} | ||
for (idx_t i = 0; i < nq * d; ++i) { | ||
query[i] = distrib(rng); | ||
} | ||
|
||
hnswlib::L2Space space(d); | ||
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float,filter_func_t>(&space, 2 * n); | ||
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float,filter_func_t>(&space, 2 * n); | ||
|
||
for (size_t i = 0; i < n; ++i) { | ||
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs | ||
alg_brute->addPoint(data.data() + d * i, label_id_start + i); | ||
alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); | ||
} | ||
|
||
// test searchKnnCloserFirst of BruteforceSearch with filtering | ||
for (size_t j = 0; j < nq; ++j) { | ||
const void* p = query.data() + j * d; | ||
auto gd = alg_brute->searchKnn(p, k, filter_func); | ||
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); | ||
assert(gd.size() == res.size()); | ||
assert(0 == gd.size()); | ||
} | ||
|
||
// test searchKnnCloserFirst of hnsw with filtering | ||
for (size_t j = 0; j < nq; ++j) { | ||
const void* p = query.data() + j * d; | ||
auto gd = alg_hnsw->searchKnn(p, k, filter_func); | ||
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); | ||
assert(gd.size() == res.size()); | ||
assert(0 == gd.size()); | ||
} | ||
|
||
delete alg_brute; | ||
delete alg_hnsw; | ||
} | ||
|
||
} // namespace | ||
|
||
class CustomFilterFunctor: public hnswlib::FilterFunctor { | ||
std::unordered_set<unsigned int> allowed_values; | ||
|
||
public: | ||
explicit CustomFilterFunctor(const std::unordered_set<unsigned int>& values) : allowed_values(values) {} | ||
|
||
bool operator()(unsigned int id) { | ||
return allowed_values.count(id) != 0; | ||
} | ||
}; | ||
|
||
int main() { | ||
std::cout << "Testing ..." << std::endl; | ||
|
||
// some of the elements are filtered | ||
test_some_filtering(pickIdsDivisibleByThree, 3, 17); | ||
test_some_filtering(pickIdsDivisibleBySeven, 7, 17); | ||
|
||
// all of the elements are filtered | ||
test_none_filtering(pickNothing, 17); | ||
|
||
// functor style which can capture context | ||
CustomFilterFunctor pickIdsDivisibleByThirteen({26, 39, 52, 65}); | ||
test_some_filtering(pickIdsDivisibleByThirteen, 13, 21); | ||
|
||
std::cout << "Test ok" << std::endl; | ||
|
||
return 0; | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,8 @@ namespace hnswlib { | |
typedef unsigned int tableint; | ||
typedef unsigned int linklistsizeint; | ||
|
||
template<typename dist_t> | ||
class HierarchicalNSW : public AlgorithmInterface<dist_t> { | ||
template<typename dist_t, typename filter_func_t=FilterFunctor> | ||
class HierarchicalNSW : public AlgorithmInterface<dist_t,filter_func_t> { | ||
public: | ||
static const tableint max_update_element_locks = 65536; | ||
HierarchicalNSW(SpaceInterface<dist_t> *s) { | ||
|
@@ -238,7 +238,7 @@ namespace hnswlib { | |
|
||
template <bool has_deletions, bool collect_metrics=false> | ||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> | ||
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { | ||
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const { | ||
VisitedList *vl = visited_list_pool_->getFreeVisitedList(); | ||
vl_type *visited_array = vl->mass; | ||
vl_type visited_array_tag = vl->curV; | ||
|
@@ -247,7 +247,8 @@ namespace hnswlib { | |
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set; | ||
|
||
dist_t lowerBound; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add the following micro optimisation to not call isIdAllowed at all if filtering is disabled:
|
||
if (!has_deletions || !isMarkedDeleted(ep_id)) { | ||
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value; | ||
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) { | ||
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); | ||
lowerBound = dist; | ||
top_candidates.emplace(dist, ep_id); | ||
|
@@ -307,7 +308,7 @@ namespace hnswlib { | |
_MM_HINT_T0);//////////////////////// | ||
#endif | ||
|
||
if (!has_deletions || !isMarkedDeleted(candidate_id)) | ||
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id)))) | ||
top_candidates.emplace(dist, candidate_id); | ||
|
||
if (top_candidates.size() > ef) | ||
|
@@ -1111,7 +1112,7 @@ namespace hnswlib { | |
}; | ||
|
||
std::priority_queue<std::pair<dist_t, labeltype >> | ||
searchKnn(const void *query_data, size_t k) const { | ||
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { | ||
std::priority_queue<std::pair<dist_t, labeltype >> result; | ||
if (cur_element_count == 0) return result; | ||
|
||
|
@@ -1148,11 +1149,11 @@ namespace hnswlib { | |
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; | ||
if (num_deleted_) { | ||
top_candidates=searchBaseLayerST<true,true>( | ||
currObj, query_data, std::max(ef_, k)); | ||
currObj, query_data, std::max(ef_, k), isIdAllowed); | ||
} | ||
else{ | ||
top_candidates=searchBaseLayerST<false,true>( | ||
currObj, query_data, std::max(ef_, k)); | ||
currObj, query_data, std::max(ef_, k), isIdAllowed); | ||
} | ||
|
||
while (top_candidates.size() > k) { | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add the same flag as in
hnswlib/hnswalg.h
and check ofk