From 07eca1a41b95d622e8d86dab0e92998f0d4fe673 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Wed, 2 Mar 2022 15:12:55 -0800 Subject: [PATCH 1/7] Integrate filtering support for ANN --- docs/reference/search/knn-search.asciidoc | 6 ++ .../index/query/AbstractQueryBuilder.java | 4 +- .../DocumentLevelSecurityTests.java | 9 ++- .../integration/FieldLevelSecurityTests.java | 20 +++++- .../test/vectors/40_knn_search.yml | 38 +++++++++- .../action/KnnSearchRequestBuilder.java | 30 ++++++-- .../mapper/DenseVectorFieldMapper.java | 4 +- .../vectors/query/KnnVectorQueryBuilder.java | 69 +++++++++++++++++-- .../action/KnnSearchRequestBuilderTests.java | 56 ++++++++++++++- .../mapper/DenseVectorFieldTypeTests.java | 6 +- .../query/KnnVectorQueryBuilderTests.java | 65 ++++++++++++++++- 11 files changed, 280 insertions(+), 27 deletions(-) diff --git a/docs/reference/search/knn-search.asciidoc b/docs/reference/search/knn-search.asciidoc index 80638422080bf..99c418a115c26 100644 --- a/docs/reference/search/knn-search.asciidoc +++ b/docs/reference/search/knn-search.asciidoc @@ -122,6 +122,12 @@ shard, then merges them to find the top `k` results. Increasing `num_candidates` tends to improve the accuracy of the final `k` results. ==== +`filter`:: +(Optional, <>) Query to filter the documents that +can match. The kNN search will return the top `k` documents that also match +this filter. The value can be a single query or a list of queries. If `filter` +is not provided, all documents are allowed to match. + include::{es-repo-dir}/search/search.asciidoc[tag=docvalue-fields-def] include::{es-repo-dir}/search/search.asciidoc[tag=fields-param-def] include::{es-repo-dir}/search/search.asciidoc[tag=source-filtering-def] diff --git a/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java index f1a31db6125d2..6a7cc4243ff9c 100644 --- a/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java @@ -251,14 +251,14 @@ public String getName() { return getWriteableName(); } - static void writeQueries(StreamOutput out, List queries) throws IOException { + protected static void writeQueries(StreamOutput out, List queries) throws IOException { out.writeVInt(queries.size()); for (QueryBuilder query : queries) { out.writeNamedWriteable(query); } } - static List readQueries(StreamInput in) throws IOException { + protected static List readQueries(StreamInput in) throws IOException { int size = in.readVInt(); List queries = new ArrayList<>(size); for (int i = 0; i < size; i++) { diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java index a0cddd6aa6a9a..c7921582b0ab1 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java @@ -37,6 +37,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.TermsQueryBuilder; +import org.elasticsearch.index.query.WildcardQueryBuilder; import org.elasticsearch.indices.IndicesRequestCache; import org.elasticsearch.indices.TermsLookup; import org.elasticsearch.join.ParentJoinPlugin; @@ -889,8 +890,8 @@ public void testKnnSearch() throws Exception { assertAcked(client().admin().indices().prepareCreate("test").setSettings(indexSettings).setMapping(builder)); for (int i = 0; i < 5; i++) { - client().prepareIndex("test").setSource("field1", "value1", "vector", new float[] { i, i, i }).get(); - client().prepareIndex("test").setSource("field2", "value2", "vector", new float[] { i, i, i }).get(); + client().prepareIndex("test").setSource("field1", "value1", "other", "valueA", "vector", new float[] { i, i, i }).get(); + client().prepareIndex("test").setSource("field2", "value2", "other", "valueB", "vector", new float[] { i, i, i }).get(); } client().admin().indices().prepareRefresh("test").get(); @@ -900,6 +901,10 @@ public void testKnnSearch() throws Exception { float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f }; KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50); + if (randomBoolean()) { + query.filterQuery(new WildcardQueryBuilder("other", "value*")); + } + // user1 should only be able to see docs with field1: value1 SearchResponse response = client().filterWithHeader( Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)) diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java index 41a96fee77231..8e1a5914106c5 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java @@ -284,7 +284,7 @@ public void testQuery() { .get(); assertHitCount(response, 1); - // user1 has no access to field1, so the query should not match with the document: + // user1 has no access to field2, so the query should not match with the document: response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD))) .prepareSearch("test") .setQuery(matchQuery("field2", "value2")) @@ -399,7 +399,7 @@ public void testKnnSearch() throws IOException { assertAcked(client().admin().indices().prepareCreate("test").setMapping(builder)); client().prepareIndex("test") - .setSource("field1", "value1", "vector", new float[] { 0.0f, 0.0f, 0.0f }) + .setSource("field1", "value1", "field2", "value2", "vector", new float[] { 0.0f, 0.0f, 0.0f }) .setRefreshPolicy(IMMEDIATE) .get(); @@ -430,6 +430,22 @@ public void testKnnSearch() throws IOException { .get(); assertHitCount(response, 1); assertNull(response.getHits().getAt(0).field("vector")); + + // user1 can access field1, so the filtered query should match with the document: + QueryBuilder matchQuery = QueryBuilders.matchQuery("field1", "value1"); + response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD))) + .prepareSearch("test") + .setQuery(query.filterQuery(matchQuery)) + .get(); + assertHitCount(response, 1); + + // user1 cannot access field2, so the filtered query should not match with the document: + matchQuery = QueryBuilders.matchQuery("field2", "value2"); + response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD))) + .prepareSearch("test") + .setQuery(query.filterQuery(matchQuery)) + .get(); + assertHitCount(response, 0); } public void testPercolateQueryWithIndexedDocWithFLS() { diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml index 18aaf2ab8264e..5b400db65a167 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml @@ -60,6 +60,43 @@ setup: - match: {hits.hits.1._id: "3"} - match: {hits.hits.1.fields.name.0: "rabbit.jpg"} +--- +"kNN search with filter": + - do: + knn_search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 2 + num_candidates: 3 + filter: + term: + name: "rabbit.jpg" + + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + + - do: + knn_search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 2 + num_candidates: 3 + filter: + - term: + name: "rabbit.jpg" + - term: + _id: 2 + + - match: {hits.total.value: 0} + --- "Test nonexistent field": - do: @@ -81,7 +118,6 @@ setup: - do: catch: bad_request search: - rest_total_hits_as_int: true index: test-index body: query: diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilder.java index a23f31a401376..9ddca9f803f0c 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilder.java @@ -9,6 +9,8 @@ import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.common.Strings; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.StoredFieldsContext; @@ -37,11 +39,18 @@ class KnnSearchRequestBuilder { static final String ROUTING_PARAM = "routing"; static final ParseField KNN_SECTION_FIELD = new ParseField("knn"); + static final ParseField FILTER_FIELD = new ParseField("filter"); private static final ObjectParser PARSER; static { PARSER = new ObjectParser<>("knn-search"); PARSER.declareField(KnnSearchRequestBuilder::knnSearch, KnnSearch::parse, KNN_SECTION_FIELD, ObjectParser.ValueType.OBJECT); + PARSER.declareFieldArray( + KnnSearchRequestBuilder::filter, + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p), + FILTER_FIELD, + ObjectParser.ValueType.OBJECT_ARRAY + ); PARSER.declareField( (p, request, c) -> request.fetchSource(FetchSourceContext.fromXContent(p)), SearchSourceBuilder._SOURCE_FIELD, @@ -86,6 +95,7 @@ static KnnSearchRequestBuilder parseRestRequest(RestRequest restRequest) throws private final String[] indices; private String routing; private KnnSearch knnSearch; + private List filters; private FetchSourceContext fetchSource; private List fields; @@ -103,6 +113,10 @@ private void knnSearch(KnnSearch knnSearch) { this.knnSearch = knnSearch; } + private void filter(List filter) { + this.filters = filter; + } + /** * A comma separated list of routing values to control the shards the search will be executed on. */ @@ -152,17 +166,22 @@ public void build(SearchRequestBuilder builder) { if (knnSearch == null) { throw new IllegalArgumentException("missing required [" + KNN_SECTION_FIELD.getPreferredName() + "] section in search body"); } - knnSearch.build(sourceBuilder); + + KnnVectorQueryBuilder queryBuilder = knnSearch.buildQuery(); + if (filters != null) { + queryBuilder.filterQueries(this.filters); + } + + sourceBuilder.query(queryBuilder); + sourceBuilder.size(knnSearch.k); sourceBuilder.fetchSource(fetchSource); sourceBuilder.storedFields(storedFields); - if (fields != null) { for (FieldAndFormat field : fields) { sourceBuilder.fetchField(field); } } - if (docValueFields != null) { for (FieldAndFormat field : docValueFields) { sourceBuilder.docValueField(field.field, field.format); @@ -221,7 +240,7 @@ public static KnnSearch parse(XContentParser parser) throws IOException { this.numCands = numCands; } - void build(SearchSourceBuilder builder) { + public KnnVectorQueryBuilder buildQuery() { // We perform validation here instead of the constructor because it makes the errors // much clearer. Otherwise, the error message is deeply nested under parsing exceptions. if (k < 1) { @@ -236,8 +255,7 @@ void build(SearchSourceBuilder builder) { throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); } - builder.query(new KnnVectorQueryBuilder(field, queryVector, numCands)); - builder.size(k); + return new KnnVectorQueryBuilder(field, queryVector, numCands); } @Override diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java index bfad97032f6b6..f152f8ba732dd 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java @@ -301,7 +301,7 @@ public Query termQuery(Object value, SearchExecutionContext context) { throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries"); } - public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands) { + public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands, Query filter) { if (isIndexed() == false) { throw new IllegalArgumentException( "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" @@ -321,7 +321,7 @@ public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands) { } checkVectorMagnitude(queryVector, squaredMagnitude); } - return new KnnVectorQuery(name(), queryVector, numCands); + return new KnnVectorQuery(name(), queryVector, numCands, filter); } private void checkVectorMagnitude(float[] vector, float squaredMagnitude) { diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java index 3def93ed0d461..c40ddd67ef52f 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java @@ -7,19 +7,26 @@ package org.elasticsearch.xpack.vectors.query; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper; import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper.DenseVectorFieldType; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Objects; public class KnnVectorQueryBuilder extends AbstractQueryBuilder { @@ -28,11 +35,13 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder filterQueries; public KnnVectorQueryBuilder(String fieldName, float[] queryVector, int numCands) { this.fieldName = fieldName; this.queryVector = queryVector; this.numCands = numCands; + this.filterQueries = List.of(); } public KnnVectorQueryBuilder(StreamInput in) throws IOException { @@ -40,6 +49,11 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { this.fieldName = in.readString(); this.numCands = in.readVInt(); this.queryVector = in.readFloatArray(); + if (in.getVersion().before(Version.V_8_2_0)) { + this.filterQueries = List.of(); + } else { + this.filterQueries = readQueries(in); + } } public String getFieldName() { @@ -54,11 +68,28 @@ public int numCands() { return numCands; } + public List filterQueries() { + return filterQueries; + } + + public KnnVectorQueryBuilder filterQuery(QueryBuilder filterQuery) { + this.filterQueries = List.of(filterQuery); + return this; + } + + public KnnVectorQueryBuilder filterQueries(List filterQueries) { + this.filterQueries = filterQueries; + return this; + } + @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeVInt(numCands); out.writeFloatArray(queryVector); + if (out.getVersion().onOrAfter(Version.V_8_2_0)) { + writeQueries(out, filterQueries); + } } @Override @@ -73,7 +104,27 @@ public String getWriteableName() { } @Override - protected Query doToQuery(SearchExecutionContext context) { + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + boolean changed = false; + List rewrittenQueries = new ArrayList<>(filterQueries.size()); + for (QueryBuilder query : filterQueries) { + QueryBuilder rewrittenQuery = query.rewrite(queryRewriteContext); + if (rewrittenQuery instanceof MatchNoneQueryBuilder) { + return rewrittenQuery; + } + if (rewrittenQuery != query) { + changed = true; + } + rewrittenQueries.add(rewrittenQuery); + } + if (changed) { + return new KnnVectorQueryBuilder(fieldName, queryVector, numCands).filterQueries(rewrittenQueries); + } + return this; + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { MappedFieldType fieldType = context.getFieldType(fieldName); if (fieldType == null) { throw new IllegalArgumentException("field [" + fieldName + "] does not exist in the mapping"); @@ -85,18 +136,28 @@ protected Query doToQuery(SearchExecutionContext context) { ); } + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + for (QueryBuilder query : this.filterQueries) { + builder.add(query.toQuery(context), BooleanClause.Occur.FILTER); + } + BooleanQuery booleanQuery = builder.build(); + Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery; + DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType; - return vectorFieldType.createKnnQuery(queryVector, numCands); + return vectorFieldType.createKnnQuery(queryVector, numCands, filterQuery); } @Override protected int doHashCode() { - return Objects.hash(fieldName, Arrays.hashCode(queryVector), numCands); + return Objects.hash(fieldName, Arrays.hashCode(queryVector), numCands, filterQueries); } @Override protected boolean doEquals(KnnVectorQueryBuilder other) { - return Objects.equals(fieldName, other.fieldName) && Arrays.equals(queryVector, other.queryVector) && numCands == other.numCands; + return Objects.equals(fieldName, other.fieldName) + && Arrays.equals(queryVector, other.queryVector) + && numCands == other.numCands + && Objects.equals(filterQueries, other.filterQueries); } @Override diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilderTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilderTests.java index 10508c9078335..28f2462b3fa21 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilderTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilderTests.java @@ -11,27 +11,49 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.vectors.action.KnnSearchRequestBuilder.KnnSearch; import org.elasticsearch.xpack.vectors.query.KnnVectorQueryBuilder; +import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; +import static java.util.Collections.emptyList; import static org.elasticsearch.search.RandomSearchRequestGenerator.randomSearchSourceBuilder; import static org.hamcrest.Matchers.containsString; public class KnnSearchRequestBuilderTests extends ESTestCase { + private NamedXContentRegistry namedXContentRegistry; + + @Before + public void registerNamedXContents() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, emptyList()); + List namedXContents = searchModule.getNamedXContents(); + namedXContentRegistry = new NamedXContentRegistry(namedXContents); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return namedXContentRegistry; + } public void testBuildSearchRequest() throws IOException { // Choose random REST parameters @@ -47,6 +69,7 @@ public void testBuildSearchRequest() throws IOException { // Create random request body KnnSearch knnSearch = randomKnnSearch(); + List filterQueries = randomFilterQueries(); SearchSourceBuilder searchSource = randomSearchSourceBuilder( () -> null, () -> null, @@ -55,7 +78,7 @@ public void testBuildSearchRequest() throws IOException { () -> null, () -> null ); - XContentBuilder builder = createRequestBody(knnSearch, searchSource); + XContentBuilder builder = createRequestBody(knnSearch, filterQueries, searchSource); // Convert the REST request to a search request and check the components SearchRequestBuilder searchRequestBuilder = buildSearchRequest(builder, params); @@ -64,7 +87,10 @@ public void testBuildSearchRequest() throws IOException { assertArrayEquals(indices, searchRequest.indices()); assertEquals(routing, searchRequest.routing()); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(knnSearch.field, knnSearch.queryVector, knnSearch.numCands); + KnnVectorQueryBuilder query = knnSearch.buildQuery(); + if (filterQueries.isEmpty() == false) { + query.filterQueries(filterQueries); + } assertEquals(query, searchRequest.source().query()); assertEquals(knnSearch.k, searchRequest.source().size()); @@ -215,7 +241,18 @@ private KnnSearch randomKnnSearch() { return new KnnSearch(field, vector, k, numCands); } - private XContentBuilder createRequestBody(KnnSearch knnSearch, SearchSourceBuilder searchSource) throws IOException { + private List randomFilterQueries() { + List filters = new ArrayList<>(); + int numFilters = randomIntBetween(0, 3); + for (int i = 0; i < numFilters; i++) { + QueryBuilder filter = QueryBuilders.termQuery(randomAlphaOfLength(5), randomAlphaOfLength(10)); + filters.add(filter); + } + return filters; + } + + private XContentBuilder createRequestBody(KnnSearch knnSearch, List filters, SearchSourceBuilder searchSource) + throws IOException { XContentType xContentType = randomFrom(XContentType.values()); XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()); builder.startObject(); @@ -227,6 +264,19 @@ private XContentBuilder createRequestBody(KnnSearch knnSearch, SearchSourceBuild .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), knnSearch.queryVector) .endObject(); + if (filters.isEmpty() == false) { + builder.field(KnnSearchRequestBuilder.FILTER_FIELD.getPreferredName()); + if (filters.size() > 1) { + builder.startArray(); + } + for (QueryBuilder filter : filters) { + filter.toXContent(builder, ToXContent.EMPTY_PARAMS); + } + if (filters.size() > 1) { + builder.endArray(); + } + } + if (searchSource.fetchSource() != null) { builder.field(SearchSourceBuilder._SOURCE_FIELD.getPreferredName()); searchSource.fetchSource().toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java index 81622112578e6..2f8571809297a 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java @@ -76,7 +76,7 @@ public void testCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10) + () -> unindexedField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -88,7 +88,7 @@ public void testCreateKnnQuery() { VectorSimilarity.dot_product, Collections.emptyMap() ); - e = expectThrows(IllegalArgumentException.class, () -> dotProductField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10)); + e = expectThrows(IllegalArgumentException.class, () -> dotProductField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null)); assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors.")); DenseVectorFieldType cosineField = new DenseVectorFieldType( @@ -99,7 +99,7 @@ public void testCreateKnnQuery() { VectorSimilarity.cosine, Collections.emptyMap() ); - e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f }, 10)); + e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f }, 10, null)); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilderTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilderTests.java index eba35f342d21a..f234fc90a0b6b 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilderTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilderTests.java @@ -9,24 +9,36 @@ import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; +import org.elasticsearch.Version; import org.elasticsearch.common.Strings; import org.elasticsearch.common.compress.CompressedXContent; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.AbstractBuilderTestCase; import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.test.TestGeoShapeFieldMapperPlugin; +import org.elasticsearch.test.VersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.vectors.DenseVectorPlugin; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.List; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; public class KnnVectorQueryBuilderTests extends AbstractQueryTestCase { private static final String VECTOR_FIELD = "vector"; @@ -65,13 +77,24 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws @Override protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD; - float[] vector = new float[VECTOR_DIMENSION]; for (int i = 0; i < vector.length; i++) { vector[i] = randomFloat(); } int numCands = randomIntBetween(1, 1000); - return new KnnVectorQueryBuilder(fieldName, vector, numCands); + + KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(fieldName, vector, numCands); + + if (randomBoolean()) { + List filters = new ArrayList<>(); + int numFilters = randomIntBetween(1, 5); + for (int i = 0; i < numFilters; i++) { + String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME; + filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10))); + } + queryBuilder.filterQueries(filters); + } + return queryBuilder; } @Override @@ -126,6 +149,44 @@ public void testValidOutput() { assertEquals(expected, query.toString()); } + @Override + public void testMustRewrite() throws IOException { + SearchExecutionContext context = createSearchExecutionContext(); + context.setAllowUnmappedFields(true); + TermQueryBuilder termQuery = new TermQueryBuilder("unmapped_field", 42); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, VECTOR_DIMENSION); + query.filterQuery(termQuery); + + IllegalStateException e = expectThrows(IllegalStateException.class, () -> query.toQuery(context)); + assertEquals("Rewrite first", e.getMessage()); + + QueryBuilder rewrittenQuery = query.rewrite(context); + assertThat(rewrittenQuery, instanceOf(MatchNoneQueryBuilder.class)); + } + + public void testOldVersionSerialization() throws IOException { + KnnVectorQueryBuilder query = createTestQueryBuilder(); + KnnVectorQueryBuilder queryWithNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), query.queryVector(), query.numCands()); + queryWithNoFilters.queryName(query.queryName()).boost(query.boost()); + + Version newVersion = VersionUtils.randomVersionBetween(random(), Version.V_8_2_0, Version.CURRENT); + Version oldVersion = VersionUtils.randomVersionBetween(random(), Version.V_8_0_0, Version.V_8_1_0); + + assertSerialization(query, newVersion); + assertSerialization(queryWithNoFilters, oldVersion); + + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setVersion(newVersion); + output.writeNamedWriteable(query); + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) { + in.setVersion(oldVersion); + KnnVectorQueryBuilder deserializedQuery = (KnnVectorQueryBuilder) in.readNamedWriteable(QueryBuilder.class); + assertEquals(queryWithNoFilters, deserializedQuery); + assertEquals(queryWithNoFilters.hashCode(), deserializedQuery.hashCode()); + } + } + } + @Override public void testUnknownObjectException() throws IOException { // Test isn't relevant, since query is never parsed from xContent From 86774715cab9c6dd40bad88d2cf9c8f246dda1e7 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Mon, 7 Mar 2022 13:12:35 -0800 Subject: [PATCH 2/7] Update docs/changelog/84734.yaml --- docs/changelog/84734.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/84734.yaml diff --git a/docs/changelog/84734.yaml b/docs/changelog/84734.yaml new file mode 100644 index 0000000000000..3c0b7cb656fa6 --- /dev/null +++ b/docs/changelog/84734.yaml @@ -0,0 +1,6 @@ +pr: 84734 +summary: Integrate filtering support for ANN +area: Search +type: enhancement +issues: + - 81788 From 8587492f7755ea0a60a117e9ae3dd7a298a4faed Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Wed, 9 Mar 2022 10:11:36 -0800 Subject: [PATCH 3/7] Improve REST yml test --- .../resources/rest-api-spec/test/vectors/40_knn_search.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml index 5b400db65a167..a27b406064d7d 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml @@ -54,6 +54,8 @@ setup: k: 2 num_candidates: 3 + - match: {hits.total.value: 2} + - match: {hits.hits.0._id: "2"} - match: {hits.hits.0.fields.name.0: "moose.jpg"} @@ -76,6 +78,7 @@ setup: term: name: "rabbit.jpg" + - match: {hits.total.value: 1} - match: {hits.hits.0._id: "3"} - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} From f65fcc04cd55c1d665d8d0e5520ba0f0eddcee40 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Wed, 9 Mar 2022 10:27:54 -0800 Subject: [PATCH 4/7] Add an example to docs --- docs/reference/search/knn-search.asciidoc | 102 +++++++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/docs/reference/search/knn-search.asciidoc b/docs/reference/search/knn-search.asciidoc index 99c418a115c26..ddab022559ceb 100644 --- a/docs/reference/search/knn-search.asciidoc +++ b/docs/reference/search/knn-search.asciidoc @@ -42,7 +42,7 @@ GET my-index/_knn_search "k": 10, "num_candidates": 100 }, - "_source": ["name", "date"] + "_source": ["name", "file_type"] } ---- // TEST[continued] @@ -128,6 +128,8 @@ can match. The kNN search will return the top `k` documents that also match this filter. The value can be a single query or a list of queries. If `filter` is not provided, all documents are allowed to match. + + include::{es-repo-dir}/search/search.asciidoc[tag=docvalue-fields-def] include::{es-repo-dir}/search/search.asciidoc[tag=fields-param-def] include::{es-repo-dir}/search/search.asciidoc[tag=source-filtering-def] @@ -147,3 +149,101 @@ the similarity between the query and document vector. See * The `hits.total` object contains the total number of nearest neighbor candidates considered, which is `num_candidates * num_shards`. The `hits.total.relation` will always be `eq`, indicating an exact value. + +[[knn-search-api-example]] +==== {api-examples-title} + +The following requests create a `dense_vector` field with indexing enabled and +add sample documents: + +[source,console] +---- +PUT my-index +{ + "mappings": { + "properties": { + "image_vector": { + "type": "dense_vector", + "dims": 3, + "index": true, + "similarity": "l2_norm" + }, + "name": { + "type": "keyword" + }, + "file_type": { + "type": "keyword" + } + } + } +} + +PUT my-index/_doc/1?refresh +{ + "image_vector" : [0.5, 0.1, 2.6], + "name": "moose family", + "file_type": "jpeg" +} + +PUT my-index/_doc/2?refresh +{ + "image_vector" : [1.0, 0.8, -0.2], + "name": "alpine lake", + "file_type": "svg" +} +---- + +The next request performs a kNN search filtered by the `file_type` field: + +[source,console] +---- +GET my-index/_knn_search +{ + "knn": { + "field": "image_vector", + "query_vector": [0.3, 0.1, 1.2], + "k": 5, + "num_candidates": 50 + }, + "filter": { + "term": { + "file_type": "svg" + } + }, + "_source": ["name"] +} +---- +// TEST[continued] + +[source,console-result] +---- +{ + "took": 5, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 1, + "relation": "eq" + }, + "max_score": 0.2538071, + "hits": [ + { + "_index": "my-index", + "_id": "2", + "_score": 0.2538071, + "_source": { + "name": "alpine lake" + } + } + ] + } +} +---- +// TESTRESPONSE[s/"took": 5/"took": $body.took/] +// TESTRESPONSE[s/,\n \.\.\.//] From 0f5a863d3d960b330769f43fb90103cbcd344acc Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Wed, 9 Mar 2022 14:02:53 -0800 Subject: [PATCH 5/7] Remove invalid check from REST test --- .../resources/rest-api-spec/test/vectors/40_knn_search.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml index a27b406064d7d..80d2ee0c762fb 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/40_knn_search.yml @@ -54,8 +54,6 @@ setup: k: 2 num_candidates: 3 - - match: {hits.total.value: 2} - - match: {hits.hits.0._id: "2"} - match: {hits.hits.0.fields.name.0: "moose.jpg"} From b401564c0e33b5c6503a66816c671104410dd52b Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Thu, 10 Mar 2022 13:50:11 -0800 Subject: [PATCH 6/7] Improve javadoc on KnnVectorQueryBuilder --- .../xpack/vectors/query/KnnVectorQueryBuilder.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java index c40ddd67ef52f..3398ed6991596 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java @@ -9,6 +9,7 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; @@ -29,6 +30,11 @@ import java.util.List; import java.util.Objects; +/** + * A query that performs kNN search using Lucene's {@link KnnVectorQuery}. + * + * NOTE: this is an internal class and should not be used outside of core Elasticsearch code. + */ public class KnnVectorQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "knn"; From 03d3f7be7c2a6c0d4221bb1a2a9f91a5da7781bc Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Thu, 10 Mar 2022 13:59:59 -0800 Subject: [PATCH 7/7] Make KnnVectorQueryBuilder#filterQueries final --- .../DocumentLevelSecurityTests.java | 2 +- .../integration/FieldLevelSecurityTests.java | 12 ++++++++---- .../action/KnnSearchRequestBuilder.java | 2 +- .../vectors/query/KnnVectorQueryBuilder.java | 18 ++++++++++-------- .../action/KnnSearchRequestBuilderTests.java | 2 +- .../query/KnnVectorQueryBuilderTests.java | 4 ++-- 6 files changed, 23 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java index c7921582b0ab1..36195ea7a2d16 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java @@ -902,7 +902,7 @@ public void testKnnSearch() throws Exception { KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50); if (randomBoolean()) { - query.filterQuery(new WildcardQueryBuilder("other", "value*")); + query.addFilterQuery(new WildcardQueryBuilder("other", "value*")); } // user1 should only be able to see docs with field1: value1 diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java index 8e1a5914106c5..8a377e1f55efd 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java @@ -432,18 +432,22 @@ public void testKnnSearch() throws IOException { assertNull(response.getHits().getAt(0).field("vector")); // user1 can access field1, so the filtered query should match with the document: - QueryBuilder matchQuery = QueryBuilders.matchQuery("field1", "value1"); + KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10).addFilterQuery( + QueryBuilders.matchQuery("field1", "value1") + ); response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD))) .prepareSearch("test") - .setQuery(query.filterQuery(matchQuery)) + .setQuery(filterQuery1) .get(); assertHitCount(response, 1); // user1 cannot access field2, so the filtered query should not match with the document: - matchQuery = QueryBuilders.matchQuery("field2", "value2"); + KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10).addFilterQuery( + QueryBuilders.matchQuery("field2", "value2") + ); response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD))) .prepareSearch("test") - .setQuery(query.filterQuery(matchQuery)) + .setQuery(filterQuery2) .get(); assertHitCount(response, 0); } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilder.java index 9ddca9f803f0c..3eae8c7c0927b 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilder.java @@ -169,7 +169,7 @@ public void build(SearchRequestBuilder builder) { KnnVectorQueryBuilder queryBuilder = knnSearch.buildQuery(); if (filters != null) { - queryBuilder.filterQueries(this.filters); + queryBuilder.addFilterQueries(this.filters); } sourceBuilder.query(queryBuilder); diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java index 3398ed6991596..cb690c993804d 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilder.java @@ -41,13 +41,13 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder filterQueries; + private final List filterQueries; public KnnVectorQueryBuilder(String fieldName, float[] queryVector, int numCands) { this.fieldName = fieldName; this.queryVector = queryVector; this.numCands = numCands; - this.filterQueries = List.of(); + this.filterQueries = new ArrayList<>(); } public KnnVectorQueryBuilder(StreamInput in) throws IOException { @@ -56,7 +56,7 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { this.numCands = in.readVInt(); this.queryVector = in.readFloatArray(); if (in.getVersion().before(Version.V_8_2_0)) { - this.filterQueries = List.of(); + this.filterQueries = new ArrayList<>(); } else { this.filterQueries = readQueries(in); } @@ -78,13 +78,15 @@ public List filterQueries() { return filterQueries; } - public KnnVectorQueryBuilder filterQuery(QueryBuilder filterQuery) { - this.filterQueries = List.of(filterQuery); + public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) { + Objects.requireNonNull(filterQuery); + this.filterQueries.add(filterQuery); return this; } - public KnnVectorQueryBuilder filterQueries(List filterQueries) { - this.filterQueries = filterQueries; + public KnnVectorQueryBuilder addFilterQueries(List filterQueries) { + Objects.requireNonNull(filterQueries); + this.filterQueries.addAll(filterQueries); return this; } @@ -124,7 +126,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws rewrittenQueries.add(rewrittenQuery); } if (changed) { - return new KnnVectorQueryBuilder(fieldName, queryVector, numCands).filterQueries(rewrittenQueries); + return new KnnVectorQueryBuilder(fieldName, queryVector, numCands).addFilterQueries(rewrittenQueries); } return this; } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilderTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilderTests.java index 28f2462b3fa21..09958a2809b97 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilderTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/action/KnnSearchRequestBuilderTests.java @@ -89,7 +89,7 @@ public void testBuildSearchRequest() throws IOException { KnnVectorQueryBuilder query = knnSearch.buildQuery(); if (filterQueries.isEmpty() == false) { - query.filterQueries(filterQueries); + query.addFilterQueries(filterQueries); } assertEquals(query, searchRequest.source().query()); assertEquals(knnSearch.k, searchRequest.source().size()); diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilderTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilderTests.java index f234fc90a0b6b..2f6f76803f0f3 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilderTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnVectorQueryBuilderTests.java @@ -92,7 +92,7 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME; filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10))); } - queryBuilder.filterQueries(filters); + queryBuilder.addFilterQueries(filters); } return queryBuilder; } @@ -155,7 +155,7 @@ public void testMustRewrite() throws IOException { context.setAllowUnmappedFields(true); TermQueryBuilder termQuery = new TermQueryBuilder("unmapped_field", 42); KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, VECTOR_DIMENSION); - query.filterQuery(termQuery); + query.addFilterQuery(termQuery); IllegalStateException e = expectThrows(IllegalStateException.class, () -> query.toQuery(context)); assertEquals("Rewrite first", e.getMessage());