Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/84734.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 84734
summary: Integrate filtering support for ANN
area: Search
type: enhancement
issues:
- 81788
108 changes: 107 additions & 1 deletion docs/reference/search/knn-search.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ GET my-index/_knn_search
"k": 10,
"num_candidates": 100
},
"_source": ["name", "date"]
"_source": ["name", "file_type"]
}
----
// TEST[continued]
Expand Down Expand Up @@ -122,6 +122,14 @@ shard, then merges them to find the top `k` results. Increasing
`num_candidates` tends to improve the accuracy of the final `k` results.
====

`filter`::
(Optional, <<query-dsl,Query DSL object>>) Query to filter the documents that
can match. The kNN search will return the top `k` documents that also match
this filter. The value can be a single query or a list of queries. If `filter`
is not provided, all documents are allowed to match.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have some json example as with filter as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I'll actually add an Examples section at the end of these docs.



include::{es-repo-dir}/search/search.asciidoc[tag=docvalue-fields-def]
include::{es-repo-dir}/search/search.asciidoc[tag=fields-param-def]
include::{es-repo-dir}/search/search.asciidoc[tag=source-filtering-def]
Expand All @@ -141,3 +149,101 @@ the similarity between the query and document vector. See
* The `hits.total` object contains the total number of nearest neighbor
candidates considered, which is `num_candidates * num_shards`. The
`hits.total.relation` will always be `eq`, indicating an exact value.

[[knn-search-api-example]]
==== {api-examples-title}

The following requests create a `dense_vector` field with indexing enabled and
add sample documents:

[source,console]
----
PUT my-index
{
"mappings": {
"properties": {
"image_vector": {
"type": "dense_vector",
"dims": 3,
"index": true,
"similarity": "l2_norm"
},
"name": {
"type": "keyword"
},
"file_type": {
"type": "keyword"
}
}
}
}

PUT my-index/_doc/1?refresh
{
"image_vector" : [0.5, 0.1, 2.6],
"name": "moose family",
"file_type": "jpeg"
}

PUT my-index/_doc/2?refresh
{
"image_vector" : [1.0, 0.8, -0.2],
"name": "alpine lake",
"file_type": "svg"
}
----

The next request performs a kNN search filtered by the `file_type` field:

[source,console]
----
GET my-index/_knn_search
{
"knn": {
"field": "image_vector",
"query_vector": [0.3, 0.1, 1.2],
"k": 5,
"num_candidates": 50
},
"filter": {
"term": {
"file_type": "svg"
}
},
"_source": ["name"]
}
----
// TEST[continued]

[source,console-result]
----
{
"took": 5,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 1,
"relation": "eq"
},
"max_score": 0.2538071,
"hits": [
{
"_index": "my-index",
"_id": "2",
"_score": 0.2538071,
"_source": {
"name": "alpine lake"
}
}
]
}
}
----
// TESTRESPONSE[s/"took": 5/"took": $body.took/]
// TESTRESPONSE[s/,\n \.\.\.//]
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,14 @@ public String getName() {
return getWriteableName();
}

static void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries) throws IOException {
protected static void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries) throws IOException {
out.writeVInt(queries.size());
for (QueryBuilder query : queries) {
out.writeNamedWriteable(query);
}
}

static List<QueryBuilder> readQueries(StreamInput in) throws IOException {
protected static List<QueryBuilder> readQueries(StreamInput in) throws IOException {
int size = in.readVInt();
List<QueryBuilder> queries = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.index.query.WildcardQueryBuilder;
import org.elasticsearch.indices.IndicesRequestCache;
import org.elasticsearch.indices.TermsLookup;
import org.elasticsearch.join.ParentJoinPlugin;
Expand Down Expand Up @@ -889,8 +890,8 @@ public void testKnnSearch() throws Exception {
assertAcked(client().admin().indices().prepareCreate("test").setSettings(indexSettings).setMapping(builder));

for (int i = 0; i < 5; i++) {
client().prepareIndex("test").setSource("field1", "value1", "vector", new float[] { i, i, i }).get();
client().prepareIndex("test").setSource("field2", "value2", "vector", new float[] { i, i, i }).get();
client().prepareIndex("test").setSource("field1", "value1", "other", "valueA", "vector", new float[] { i, i, i }).get();
client().prepareIndex("test").setSource("field2", "value2", "other", "valueB", "vector", new float[] { i, i, i }).get();
}

client().admin().indices().prepareRefresh("test").get();
Expand All @@ -900,6 +901,10 @@ public void testKnnSearch() throws Exception {
float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f };
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50);

if (randomBoolean()) {
query.addFilterQuery(new WildcardQueryBuilder("other", "value*"));
}

// user1 should only be able to see docs with field1: value1
SearchResponse response = client().filterWithHeader(
Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ public void testQuery() {
.get();
assertHitCount(response, 1);

// user1 has no access to field1, so the query should not match with the document:
// user1 has no access to field2, so the query should not match with the document:
response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
.prepareSearch("test")
.setQuery(matchQuery("field2", "value2"))
Expand Down Expand Up @@ -399,7 +399,7 @@ public void testKnnSearch() throws IOException {
assertAcked(client().admin().indices().prepareCreate("test").setMapping(builder));

client().prepareIndex("test")
.setSource("field1", "value1", "vector", new float[] { 0.0f, 0.0f, 0.0f })
.setSource("field1", "value1", "field2", "value2", "vector", new float[] { 0.0f, 0.0f, 0.0f })
.setRefreshPolicy(IMMEDIATE)
.get();

Expand Down Expand Up @@ -430,6 +430,26 @@ public void testKnnSearch() throws IOException {
.get();
assertHitCount(response, 1);
assertNull(response.getHits().getAt(0).field("vector"));

// user1 can access field1, so the filtered query should match with the document:
KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10).addFilterQuery(
QueryBuilders.matchQuery("field1", "value1")
);
response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
.prepareSearch("test")
.setQuery(filterQuery1)
.get();
assertHitCount(response, 1);

// user1 cannot access field2, so the filtered query should not match with the document:
KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10).addFilterQuery(
QueryBuilders.matchQuery("field2", "value2")
);
response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
.prepareSearch("test")
.setQuery(filterQuery2)
.get();
assertHitCount(response, 0);
}

public void testPercolateQueryWithIndexedDocWithFLS() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,44 @@ setup:
- match: {hits.hits.1._id: "3"}
- match: {hits.hits.1.fields.name.0: "rabbit.jpg"}

---
"kNN search with filter":
- do:
knn_search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
k: 2
num_candidates: 3
filter:
term:
name: "rabbit.jpg"

- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

- do:
knn_search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
k: 2
num_candidates: 3
filter:
- term:
name: "rabbit.jpg"
- term:
_id: 2

- match: {hits.total.value: 0}

---
"Test nonexistent field":
- do:
Expand All @@ -81,7 +119,6 @@ setup:
- do:
catch: bad_request
search:
rest_total_hits_as_int: true
index: test-index
body:
query:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
Expand Down Expand Up @@ -37,11 +39,18 @@ class KnnSearchRequestBuilder {
static final String ROUTING_PARAM = "routing";

static final ParseField KNN_SECTION_FIELD = new ParseField("knn");
static final ParseField FILTER_FIELD = new ParseField("filter");
private static final ObjectParser<KnnSearchRequestBuilder, Void> PARSER;

static {
PARSER = new ObjectParser<>("knn-search");
PARSER.declareField(KnnSearchRequestBuilder::knnSearch, KnnSearch::parse, KNN_SECTION_FIELD, ObjectParser.ValueType.OBJECT);
PARSER.declareFieldArray(
KnnSearchRequestBuilder::filter,
(p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p),
FILTER_FIELD,
ObjectParser.ValueType.OBJECT_ARRAY
);
PARSER.declareField(
(p, request, c) -> request.fetchSource(FetchSourceContext.fromXContent(p)),
SearchSourceBuilder._SOURCE_FIELD,
Expand Down Expand Up @@ -86,6 +95,7 @@ static KnnSearchRequestBuilder parseRestRequest(RestRequest restRequest) throws
private final String[] indices;
private String routing;
private KnnSearch knnSearch;
private List<QueryBuilder> filters;

private FetchSourceContext fetchSource;
private List<FieldAndFormat> fields;
Expand All @@ -103,6 +113,10 @@ private void knnSearch(KnnSearch knnSearch) {
this.knnSearch = knnSearch;
}

private void filter(List<QueryBuilder> filter) {
this.filters = filter;
}

/**
* A comma separated list of routing values to control the shards the search will be executed on.
*/
Expand Down Expand Up @@ -152,17 +166,22 @@ public void build(SearchRequestBuilder builder) {
if (knnSearch == null) {
throw new IllegalArgumentException("missing required [" + KNN_SECTION_FIELD.getPreferredName() + "] section in search body");
}
knnSearch.build(sourceBuilder);

KnnVectorQueryBuilder queryBuilder = knnSearch.buildQuery();
if (filters != null) {
queryBuilder.addFilterQueries(this.filters);
}

sourceBuilder.query(queryBuilder);
sourceBuilder.size(knnSearch.k);

sourceBuilder.fetchSource(fetchSource);
sourceBuilder.storedFields(storedFields);

if (fields != null) {
for (FieldAndFormat field : fields) {
sourceBuilder.fetchField(field);
}
}

if (docValueFields != null) {
for (FieldAndFormat field : docValueFields) {
sourceBuilder.docValueField(field.field, field.format);
Expand Down Expand Up @@ -221,7 +240,7 @@ public static KnnSearch parse(XContentParser parser) throws IOException {
this.numCands = numCands;
}

void build(SearchSourceBuilder builder) {
public KnnVectorQueryBuilder buildQuery() {
// We perform validation here instead of the constructor because it makes the errors
// much clearer. Otherwise, the error message is deeply nested under parsing exceptions.
if (k < 1) {
Expand All @@ -236,8 +255,7 @@ void build(SearchSourceBuilder builder) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}

builder.query(new KnnVectorQueryBuilder(field, queryVector, numCands));
builder.size(k);
return new KnnVectorQueryBuilder(field, queryVector, numCands);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ public Query termQuery(Object value, SearchExecutionContext context) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries");
}

public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands) {
public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands, Query filter) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
Expand All @@ -321,7 +321,7 @@ public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands) {
}
checkVectorMagnitude(queryVector, squaredMagnitude);
}
return new KnnVectorQuery(name(), queryVector, numCands);
return new KnnVectorQuery(name(), queryVector, numCands, filter);
}

private void checkVectorMagnitude(float[] vector, float squaredMagnitude) {
Expand Down
Loading