From 1f033c56f9550c8fdfbe9bbe9ef9cccd2f47e030 Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Mon, 28 Jan 2019 09:12:11 -0500 Subject: [PATCH 1/8] Distance measures for dense and sparse vectors Introduce painless functions of cosineSimilarity and dotProduct distance measures for dense and sparse vector fields. ```js { "query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "cosineSimilarity(params.queryVector, doc['my_dense_vector'].value)", "params": { "queryVector": [4, 3.4, -1.2] } } } } } ``` ```js { "query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "cosineSimilaritySparse(params.queryVector, doc['my_sparse_vector'].value)", "params": { "queryVector": {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} } } } } } ``` Closes #31615 --- .../mapping/types/dense-vector.asciidoc | 3 +- .../mapping/types/sparse-vector.asciidoc | 3 +- .../query-dsl/script-score-query.asciidoc | 125 +++++++++++ modules/mapper-extras/build.gradle | 9 + .../index/mapper/DenseVectorFieldMapper.java | 4 +- .../index/mapper/SparseVectorFieldMapper.java | 4 +- .../index/mapper/VectorEncoderDecoder.java | 41 ++-- .../query/DocValuesWhitelistExtension.java | 42 ++++ .../index/query/ScoreScriptUtils.java | 202 ++++++++++++++++++ .../index/query/VectorDVAtomicFieldData.java | 65 ++++++ .../index/query/VectorDVIndexFieldData.java | 73 +++++++ .../index/query/VectorScriptDocValues.java | 63 ++++++ ...asticsearch.painless.spi.PainlessExtension | 1 + .../index/query/docvalues_whitelist.txt | 30 +++ .../test/dense-vector/10_basic.yml | 98 +++++++++ .../test/dense-vector/10_indexing.yml | 29 --- .../test/dense-vector/20_special_cases.yml | 133 ++++++++++++ .../test/sparse-vector/10_basic.yml | 98 +++++++++ .../test/sparse-vector/10_indexing.yml | 29 --- .../test/sparse-vector/20_special_cases.yml | 184 ++++++++++++++++ .../section/GreaterThanEqualToAssertion.java | 8 + .../section/LessThanOrEqualToAssertion.java | 8 + 22 files changed, 1176 insertions(+), 76 deletions(-) create mode 100644 modules/mapper-extras/src/main/java/org/elasticsearch/index/query/DocValuesWhitelistExtension.java create mode 100644 modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java create mode 100644 modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java create mode 100644 modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java create mode 100644 modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java create mode 100644 modules/mapper-extras/src/main/resources/META-INF/services/org.elasticsearch.painless.spi.PainlessExtension create mode 100644 modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt create mode 100644 modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml delete mode 100644 modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_indexing.yml create mode 100644 modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml create mode 100644 modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml delete mode 100644 modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_indexing.yml create mode 100644 modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml diff --git a/docs/reference/mapping/types/dense-vector.asciidoc b/docs/reference/mapping/types/dense-vector.asciidoc index b97566361a05c..bd45874291c42 100644 --- a/docs/reference/mapping/types/dense-vector.asciidoc +++ b/docs/reference/mapping/types/dense-vector.asciidoc @@ -9,7 +9,8 @@ not exceed 500. The number of dimensions can be different across documents. A `dense_vector` field is a single-valued field. -These vectors can be used for document scoring. +These vectors can be used for +{ref}/query-dsl-script-score-query.html#vector-functions[document scoring]. For example, a document score can represent a distance between a given query vector and the indexed document vector. diff --git a/docs/reference/mapping/types/sparse-vector.asciidoc b/docs/reference/mapping/types/sparse-vector.asciidoc index 38561789b5d3f..739ce63d11026 100644 --- a/docs/reference/mapping/types/sparse-vector.asciidoc +++ b/docs/reference/mapping/types/sparse-vector.asciidoc @@ -9,7 +9,8 @@ not exceed 500. The number of dimensions can be different across documents. A `sparse_vector` field is a single-valued field. -These vectors can be used for document scoring. +These vectors can be used for +{ref}/query-dsl-script-score-query.html#vector-functions[document scoring]. For example, a document score can represent a distance between a given query vector and the indexed document vector. diff --git a/docs/reference/query-dsl/script-score-query.asciidoc b/docs/reference/query-dsl/script-score-query.asciidoc index cdcfd0f0a5032..d27a46f5b5051 100644 --- a/docs/reference/query-dsl/script-score-query.asciidoc +++ b/docs/reference/query-dsl/script-score-query.asciidoc @@ -74,6 +74,131 @@ to be the most efficient by using the internal mechanisms. -------------------------------------------------- // NOTCONSOLE +[[vector-functions]] +===== Distance functions for vector fields +These functions are used to calculate distances +for <> and +<> fields. + +For dense_vector fields, `cosineSimilarity` calculates the measure of +cosine similarity between a given query vector and document vectors. + +[source,js] +-------------------------------------------------- +{ + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "cosineSimilarity(params.queryVector, doc['my_dense_vector'].value)", + "params": { + "queryVector": [4, 3.4, -1.2] <1> + } + } + } + } +} +-------------------------------------------------- +// NOTCONSOLE +<1> To take advantage of the script optimizations, supply a query vector in script parameters. + +Similarly, for sparse_vector fields, `cosineSimilaritySparse` calculates cosine similarity +between a given query vector and document vectors. + +[source,js] +-------------------------------------------------- +{ + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "cosineSimilaritySparse(params.queryVector, doc['my_sparse_vector'].value)", + "params": { + "queryVector": {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + } + } + } + } +} +-------------------------------------------------- +// NOTCONSOLE + +For dense_vector fields, `dotProduct` calculates the measure of +dot product between a given query vector and document vectors. + +[source,js] +-------------------------------------------------- +{ + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "dotProduct(params.queryVector, doc['my_dense_vector'].value)", + "params": { + "queryVector": [4, 3.4, -1.2] + } + } + } + } +} +-------------------------------------------------- +// NOTCONSOLE + +Similarly, for sparse_vector fields, `dotProductSparse` calculates dot product +between a given query vector and document vectors. + +[source,js] +-------------------------------------------------- +{ + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "dotProductSparse(params.queryVector, doc['my_sparse_vector'].value)", + "params": { + "queryVector": {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + } + } + } + } +} +-------------------------------------------------- +// NOTCONSOLE + + +Be aware, if a document doesn't have a value for a vector field on which +a distance function is executed, an error will be thrown. You can +guard against missing values as shown below, and provide your own +custom score for this case. + +[source,js] +-------------------------------------------------- +{ + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "doc['my_dense_vector'].value == null ? 0 : cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)", + "params": { + "queryVector": [4, 3.4, -1.2] + } + } + } + } +} +-------------------------------------------------- +// NOTCONSOLE + [[random-functions]] ===== Random functions diff --git a/modules/mapper-extras/build.gradle b/modules/mapper-extras/build.gradle index 7831de3a68e94..73fc8901ec774 100644 --- a/modules/mapper-extras/build.gradle +++ b/modules/mapper-extras/build.gradle @@ -20,4 +20,13 @@ esplugin { description 'Adds advanced field mappers' classname 'org.elasticsearch.index.mapper.MapperExtrasPlugin' + extendedPlugins = ['lang-painless'] } + +dependencies { + compileOnly project(':modules:lang-painless') +} + +integTestCluster { + module project(':modules:lang-painless') +} \ No newline at end of file diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java index 7beddc13ca598..57c996e977532 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.XContentParser.Token; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.query.QueryShardContext; +import org.elasticsearch.index.query.VectorDVIndexFieldData; import org.elasticsearch.search.DocValueFormat; import java.io.IOException; @@ -119,8 +120,7 @@ public Query existsQuery(QueryShardContext context) { @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName) { - throw new UnsupportedOperationException( - "Field [" + name() + "] of type [" + typeName() + "] doesn't support sorting, scripting or aggregating"); + return new VectorDVIndexFieldData.Builder(); } @Override diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java index f7288d5039390..36146060ef33e 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.XContentParser.Token; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.query.QueryShardContext; +import org.elasticsearch.index.query.VectorDVIndexFieldData; import org.elasticsearch.search.DocValueFormat; import java.io.IOException; @@ -119,8 +120,7 @@ public Query existsQuery(QueryShardContext context) { @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName) { - throw new UnsupportedOperationException( - "Field [" + name() + "] of type [" + typeName() + "] doesn't support sorting, scripting or aggregating"); + return new VectorDVIndexFieldData.Builder(); } @Override diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java index c21b006c8836b..ece4574f0c4d2 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java @@ -23,7 +23,7 @@ import org.apache.lucene.util.InPlaceMergeSorter; // static utility functions for encoding and decoding dense_vector and sparse_vector fields -final class VectorEncoderDecoder { +public final class VectorEncoderDecoder { static final byte INT_BYTES = 4; static final byte SHORT_BYTES = 2; @@ -34,7 +34,8 @@ private VectorEncoderDecoder() { } * BytesRef: int[] floats encoded as integers values, 2 bytes for each dimension * @param values - values of the sparse array * @param dims - dims of the sparse array - * @param dimCount - number of the dimension + * @param dimCount - number of the dimensions, necessary as values and dims are dynamically created arrays, + * and may be over-allocated * @return BytesRef */ static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) { @@ -66,9 +67,12 @@ static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) { /** * Decodes the first part of BytesRef into sparse vector dimensions - * @param vectorBR - vector decoded in BytesRef + * @param vectorBR - sparse vector encoded in BytesRef */ - static int[] decodeSparseVectorDims(BytesRef vectorBR) { + public static int[] decodeSparseVectorDims(BytesRef vectorBR) { + if (vectorBR == null) { + throw new IllegalStateException("A document doesn't have a value for a vector field!"); + } int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES); int[] dims = new int[dimCount]; int offset = vectorBR.offset; @@ -81,9 +85,12 @@ static int[] decodeSparseVectorDims(BytesRef vectorBR) { /** * Decodes the second part of the BytesRef into sparse vector values - * @param vectorBR - vector decoded in BytesRef + * @param vectorBR - sparse vector encoded in BytesRef */ - static float[] decodeSparseVector(BytesRef vectorBR) { + public static float[] decodeSparseVector(BytesRef vectorBR) { + if (vectorBR == null) { + throw new IllegalStateException("A document doesn't have a value for a vector field!"); + } int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES); int offset = vectorBR.offset + SHORT_BYTES * dimCount; //calculate the offset from where values are encoded float[] vector = new float[dimCount]; @@ -100,10 +107,14 @@ static float[] decodeSparseVector(BytesRef vectorBR) { /** - Sort dimensions in the ascending order and - sort values in the same order as their corresponding dimensions - **/ - static void sortSparseDimsValues(int[] dims, float[] values, int n) { + * Sorts dimensions in the ascending order and + * sorts values in the same order as their corresponding dimensions + * + * @param dims - dimensions of the sparse query vector + * @param values - values for the sparse query vector + * @param n - number of dimensions + */ + public static void sortSparseDimsValues(int[] dims, float[] values, int n) { new InPlaceMergeSorter() { @Override public int compare(int i, int j) { @@ -123,8 +134,14 @@ public void swap(int i, int j) { }.sort(0, n); } - // Decodes a BytesRef into an array of floats - static float[] decodeDenseVector(BytesRef vectorBR) { + /** + * Decodes a BytesRef into an array of floats + * @param vectorBR - dense vector encoded in BytesRef + */ + public static float[] decodeDenseVector(BytesRef vectorBR) { + if (vectorBR == null) { + throw new IllegalStateException("A document doesn't have a value for a vector field!"); + } int dimCount = vectorBR.length / INT_BYTES; float[] vector = new float[dimCount]; int offset = vectorBR.offset; diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/DocValuesWhitelistExtension.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/DocValuesWhitelistExtension.java new file mode 100644 index 0000000000000..f463135d69f71 --- /dev/null +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/DocValuesWhitelistExtension.java @@ -0,0 +1,42 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + + +import org.elasticsearch.painless.spi.PainlessExtension; +import org.elasticsearch.painless.spi.Whitelist; +import org.elasticsearch.painless.spi.WhitelistLoader; +import org.elasticsearch.script.ScoreScript; +import org.elasticsearch.script.ScriptContext; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class DocValuesWhitelistExtension implements PainlessExtension { + + private static final Whitelist WHITELIST = + WhitelistLoader.loadFromResourceFiles(DocValuesWhitelistExtension.class, "docvalues_whitelist.txt"); + + @Override + public Map, List> getContextWhitelists() { + return Collections.singletonMap(ScoreScript.CONTEXT, Collections.singletonList(WHITELIST)); + } +} diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java new file mode 100644 index 0000000000000..a889dfc5743f6 --- /dev/null +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java @@ -0,0 +1,202 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.mapper.VectorEncoderDecoder; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.index.mapper.VectorEncoderDecoder.sortSparseDimsValues; + +public class ScoreScriptUtils { + + //**************FUNCTIONS FOR DENSE VECTORS + + /** + * Calculate a dot product between two dense vectors + * + * @param queryVector the query vector parsed as {@code List} from json + * @param docVectorBR BytesRef representing encoded document vector + */ + public static float dotProduct(List queryVector, BytesRef docVectorBR){ + float[] docVector = VectorEncoderDecoder.decodeDenseVector(docVectorBR); + return intDotProduct(queryVector, docVector); + } + + /** + * Calculate cosine similarity between two dense vectors + * + * CosineSimilarity is implemented as a class to use + * painless script caching to calculate queryVectorMagnitude + * only once per script execution for all documents. + * A user will call `cosineSimilarity(params.queryVector, doc['my_vector'].getValue())` + */ + public static final class CosineSimilarity { + final float queryVectorMagnitude; + List queryVector; + + // calculate queryVectorMagnitude once per query execution + public CosineSimilarity(List queryVector) { + this.queryVector = queryVector; + float floatValue; + float dotProduct = 0f; + for (Number value : queryVector) { + floatValue = value.floatValue(); + dotProduct += floatValue * floatValue; + } + this.queryVectorMagnitude = (float) Math.sqrt(dotProduct); + } + + public float cosineSimilarity(BytesRef docVectorBR) { + float[] docVector = VectorEncoderDecoder.decodeDenseVector(docVectorBR); + + // calculate docVector magnitude + float dotProduct = 0f; + for (int dim = 0; dim < docVector.length; dim++) { + dotProduct += docVector[dim] * docVector[dim]; + } + final float docVectorMagnitude = (float) Math.sqrt(dotProduct); + + float docQueryDotProduct = intDotProduct(queryVector, docVector); + return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude); + } + } + + private static float intDotProduct(List v1, float[] v2){ + int dims = Math.min(v1.size(), v2.length); + float v1v2DotProduct = 0f; + int dim = 0; + Iterator v1Iter = v1.iterator(); + while(dim < dims) { + v1v2DotProduct += v1Iter.next().floatValue() * v2[dim]; + dim++; + } + return v1v2DotProduct; + } + + + //**************FUNCTIONS FOR SPARSE VECTORS + + /** + * Calculate a dot product between two sparse vectors + * + * DotProductSparse is implemented as a class to use + * painless script caching to prepare queryVector + * only once per script execution for all documents. + * A user will call `dotProductSparse(params.queryVector, doc['my_vector'].getValue())` + */ + public static final class DotProductSparse { + float[] queryValues; + int[] queryDims; + + // prepare queryVector once per script execution + // queryVector represents a map of dimensions to values + public DotProductSparse(Map queryVector) { + //break vector into two arrays dims and values + int n = queryVector.size(); + queryDims = new int[n]; + queryValues = new float[n]; + int i = 0; + for (Map.Entry dimValue : queryVector.entrySet()) { + queryDims[i] = Integer.parseInt(dimValue.getKey()); + queryValues[i] = dimValue.getValue().floatValue(); + i++; + } + // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions + sortSparseDimsValues(queryDims, queryValues, n); + } + + public float dotProductSparse(BytesRef docVectorBR) { + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(docVectorBR); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(docVectorBR); + return intDotProductSparse(queryValues, queryDims, docValues, docDims); + } + } + + /** + * Calculate cosine similarity between two sparse vectors + * + * CosineSimilaritySparse is implemented as a class to use + * painless script caching to prepare queryVector and calculate queryVectorMagnitude + * only once per script execution for all documents. + * A user will call `cosineSimilaritySparse(params.queryVector, doc['my_vector'].getValue())` + */ + public static final class CosineSimilaritySparse { + float[] queryValues; + int[] queryDims; + float queryVectorMagnitude; + + // prepare queryVector once per script execution + public CosineSimilaritySparse(Map queryVector) { + //break vector into two arrays dims and values + int n = queryVector.size(); + queryValues = new float[n]; + queryDims = new int[n]; + float dotProduct = 0f; + int i = 0; + for (Map.Entry dimValue : queryVector.entrySet()) { + queryDims[i] = Integer.parseInt(dimValue.getKey()); + queryValues[i] = dimValue.getValue().floatValue(); + dotProduct += queryValues[i] * queryValues[i]; + i++; + } + this.queryVectorMagnitude = (float) Math.sqrt(dotProduct); + // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions + sortSparseDimsValues(queryDims, queryValues, n); + } + + public float cosineSimilaritySparse(BytesRef docVectorBR) { + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(docVectorBR); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(docVectorBR); + + // calculate docVector magnitude + float dotProduct = 0f; + for (float value : docValues) { + dotProduct += value * value; + } + final float docVectorMagnitude = (float) Math.sqrt(dotProduct); + + float docQueryDotProduct = intDotProductSparse(queryValues, queryDims, docValues, docDims); + return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude); + } + } + + private static float intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) { + float v1v2DotProduct = 0f; + int v1Index = 0; + int v2Index = 0; + // find common dimensions among vectors v1 and v2 and calculate dotProduct based on common dimensions + while (v1Index < v1Values.length && v2Index < v2Values.length) { + if (v1Dims[v1Index] == v2Dims[v2Index]) { + v1v2DotProduct += v1Values[v1Index] * v2Values[v2Index]; + v1Index++; + v2Index++; + } else if (v1Dims[v1Index] > v2Dims[v2Index]) { + v2Index++; + } else { + v1Index++; + } + } + return v1v2DotProduct; + } +} diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java new file mode 100644 index 0000000000000..df654ae7f5113 --- /dev/null +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java @@ -0,0 +1,65 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.fielddata.AtomicFieldData; +import org.elasticsearch.index.fielddata.ScriptDocValues; +import org.elasticsearch.index.fielddata.SortedBinaryDocValues; + +import java.util.Collection; +import java.util.Collections; + +final class VectorDVAtomicFieldData implements AtomicFieldData { + + private final BinaryDocValues values; + + VectorDVAtomicFieldData(BinaryDocValues values) { + super(); + this.values = values; + } + + @Override + public long ramBytesUsed() { + return 0; // not exposed by Lucene + } + + @Override + public Collection getChildResources() { + return Collections.emptyList(); + } + + @Override + public SortedBinaryDocValues getBytesValues() { + return null; + } + + @Override + public ScriptDocValues getScriptValues() { + return new VectorScriptDocValues(values); + } + + @Override + public void close() { + // no-op + } +} diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java new file mode 100644 index 0000000000000..723342f6a1737 --- /dev/null +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java @@ -0,0 +1,73 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.SortField; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.fielddata.IndexFieldData; +import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested; +import org.elasticsearch.index.fielddata.IndexFieldDataCache; +import org.elasticsearch.index.fielddata.plain.DocValuesIndexFieldData; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.search.MultiValueMode; + +import java.io.IOException; + +public class VectorDVIndexFieldData extends DocValuesIndexFieldData implements IndexFieldData { + + public VectorDVIndexFieldData(Index index, String fieldName) { + super(index, fieldName); + } + + @Override + public SortField sortField(@Nullable Object missingValue, MultiValueMode sortMode, Nested nested, boolean reverse) { + throw new IllegalArgumentException("can't sort on the vector field"); + } + + @Override + public VectorDVAtomicFieldData load(LeafReaderContext context) { + try { + return new VectorDVAtomicFieldData(DocValues.getBinary(context.reader(), fieldName)); + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values", e); + } + } + + @Override + public VectorDVAtomicFieldData loadDirect(LeafReaderContext context) throws Exception { + return load(context); + } + + public static class Builder implements IndexFieldData.Builder { + @Override + public IndexFieldData build(IndexSettings indexSettings, MappedFieldType fieldType, IndexFieldDataCache cache, + CircuitBreakerService breakerService, MapperService mapperService) { + final String fieldName = fieldType.name(); + return new VectorDVIndexFieldData(indexSettings.getIndex(), fieldName); + } + + } +} diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java new file mode 100644 index 0000000000000..cda8d6f349544 --- /dev/null +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java @@ -0,0 +1,63 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.fielddata.ScriptDocValues; + +import java.io.IOException; + +/** + * VectorScriptDocValues represents docValues for dense and sparse vector fields + */ +public final class VectorScriptDocValues extends ScriptDocValues { + + private final BinaryDocValues in; + private BytesRef value; + + VectorScriptDocValues(BinaryDocValues in) { + this.in = in; + } + + @Override + public void setNextDocId(int docId) throws IOException { + if (in.advanceExact(docId)) { + value = in.binaryValue(); + } else { + value = null; + } + } + + public BytesRef getValue() { + return value; + } + + @Override + public BytesRef get(int index) { + throw new UnsupportedOperationException("this operation is not supported on vector doc values"); + } + + @Override + public int size() { + throw new UnsupportedOperationException("this operation is not supported on vector doc values"); + } + +} diff --git a/modules/mapper-extras/src/main/resources/META-INF/services/org.elasticsearch.painless.spi.PainlessExtension b/modules/mapper-extras/src/main/resources/META-INF/services/org.elasticsearch.painless.spi.PainlessExtension new file mode 100644 index 0000000000000..f4cc27a362e51 --- /dev/null +++ b/modules/mapper-extras/src/main/resources/META-INF/services/org.elasticsearch.painless.spi.PainlessExtension @@ -0,0 +1 @@ +org.elasticsearch.index.query.DocValuesWhitelistExtension \ No newline at end of file diff --git a/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt b/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt new file mode 100644 index 0000000000000..ea5503b80d83a --- /dev/null +++ b/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt @@ -0,0 +1,30 @@ +# +# Licensed to Elasticsearch under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +class org.elasticsearch.index.query.VectorScriptDocValues { + BytesRef getValue() +} + +static_import { + float cosineSimilarity(List, BytesRef) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilarity + float dotProduct(List, BytesRef) from_class org.elasticsearch.index.query.ScoreScriptUtils + float dotProductSparse(Map, BytesRef) bound_to org.elasticsearch.index.query.ScoreScriptUtils$DotProductSparse + float cosineSimilaritySparse(Map, BytesRef) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilaritySparse +} \ No newline at end of file diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml new file mode 100644 index 0000000000000..0378458ae1675 --- /dev/null +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml @@ -0,0 +1,98 @@ +setup: + - skip: + version: " - 6.99.99" + reason: "dense_vector field was introduced in 7.0.0" + + - do: + indices.create: + index: test-index + body: + settings: + number_of_replicas: 0 + mappings: + _doc: + properties: + my_dense_vector: + type: dense_vector + - do: + index: + index: test-index + type: _doc + id: 1 + body: + my_dense_vector: [230.0, 300.33, -34.8988, 15.555, -200.0] + + - do: + index: + index: test-index + type: _doc + id: 2 + body: + my_dense_vector: [-0.5, 100.0, -13, 14.8, -156.0] + + - do: + index: + index: test-index + type: _doc + id: 3 + body: + my_dense_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - do: + indices.refresh: {} + +--- +"Dot Product": + - do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProduct(params.query_vector, doc['my_dense_vector'].value)" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "1"} + - gte: {hits.hits.0._score: 65425.62} + - lte: {hits.hits.0._score: 65425.63} + + - match: {hits.hits.1._id: "3"} + - gte: {hits.hits.1._score: 37111.98} + - lte: {hits.hits.1._score: 37111.99} + + - match: {hits.hits.2._id: "2"} + - gte: {hits.hits.2._score: 35853.78} + - lte: {hits.hits.2._score: 35853.79} + +--- +"Cosine Similarity": + - do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "3"} + - gte: {hits.hits.0._score: 0.999} + - lte: {hits.hits.0._score: 1.001} + + - match: {hits.hits.1._id: "2"} + - gte: {hits.hits.1._score: 0.998} + - lte: {hits.hits.1._score: 1.0} + + - match: {hits.hits.2._id: "1"} + - gte: {hits.hits.2._score: 0.78} + - lte: {hits.hits.2._score: 0.791} diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_indexing.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_indexing.yml deleted file mode 100644 index ef31d0f45e240..0000000000000 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_indexing.yml +++ /dev/null @@ -1,29 +0,0 @@ -setup: - - skip: - version: " - 6.99.99" - reason: "dense_vector field was introduced in 7.0.0" - - - do: - indices.create: - index: test-index - body: - settings: - number_of_replicas: 0 - mappings: - _doc: - properties: - my_dense_vector: - type: dense_vector - - ---- -"Indexing": - - do: - index: - index: test-index - type: _doc - id: 1 - body: - my_dense_vector: [1.5, -10, 3455, 345452.4545] - - - match: { result: created } diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml new file mode 100644 index 0000000000000..875a2f20b4558 --- /dev/null +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml @@ -0,0 +1,133 @@ +setup: + - skip: + version: " - 6.99.99" + reason: "dense_vector field was introduced in 7.0.0" + + - do: + indices.create: + index: test-index + body: + settings: + number_of_replicas: 0 + mappings: + _doc: + properties: + my_dense_vector: + type: dense_vector + + +--- +"Vectors of different dimensions and data types": +# document vectors of different dimensions + - do: + index: + index: test-index + type: _doc + id: 1 + body: + my_dense_vector: [10] + + - do: + index: + index: test-index + type: _doc + id: 2 + body: + my_dense_vector: [10, 10.5] + + - do: + index: + index: test-index + type: _doc + id: 3 + body: + my_dense_vector: [10, 10.5, 100.5] + + - do: + indices.refresh: {} + +# query vector of type integer + - do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)" + params: + query_vector: [10] + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "3"} + +# query vector of type double + - do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)" + params: + query_vector: [10.0] + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "3"} + +--- +"Documents missing vector field": +- do: + index: + index: test-index + type: _doc + id: 1 + body: + my_dense_vector: [10] + +- do: + index: + index: test-index + type: _doc + id: 2 + body: + some_other_field: "random_value" + +- do: + indices.refresh: {} + +- do: + catch: /A document doesn't have a value for a vector field!/ + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)" + params: + query_vector: [10.0] + +- do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "doc['my_dense_vector'].value == null ? 0 : cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)" + params: + query_vector: [10.0] + +- match: {hits.total: 2} +- match: {hits.hits.0._id: "1"} +- match: {hits.hits.1._id: "2"} diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml new file mode 100644 index 0000000000000..f927cf21830ce --- /dev/null +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml @@ -0,0 +1,98 @@ +setup: + - skip: + version: " - 6.99.99" + reason: "sparse_vector field was introduced in 7.0.0" + + - do: + indices.create: + index: test-index + body: + settings: + number_of_replicas: 0 + mappings: + _doc: + properties: + my_sparse_vector: + type: sparse_vector + - do: + index: + index: test-index + type: _doc + id: 1 + body: + my_sparse_vector: {"2": 230.0, "10" : 300.33, "50": -34.8988, "113": 15.555, "4545": -200.0} + + - do: + index: + index: test-index + type: _doc + id: 2 + body: + my_sparse_vector: {"2": -0.5, "10" : 100.0, "50": -13, "113": 14.8, "4545": -156.0} + + - do: + index: + index: test-index + type: _doc + id: 3 + body: + my_sparse_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + + - do: + indices.refresh: {} + +--- +"Dot Product": +- do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'].value)" + params: + query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + +- match: {hits.total: 3} + +- match: {hits.hits.0._id: "1"} +- gte: {hits.hits.0._score: 65425.62} +- lte: {hits.hits.0._score: 65425.63} + +- match: {hits.hits.1._id: "3"} +- gte: {hits.hits.1._score: 37111.98} +- lte: {hits.hits.1._score: 37111.99} + +- match: {hits.hits.2._id: "2"} +- gte: {hits.hits.2._score: 35853.78} +- lte: {hits.hits.2._score: 35853.79} + +--- +"Cosine Similarity": +- do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + params: + query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + +- match: {hits.total: 3} + +- match: {hits.hits.0._id: "3"} +- gte: {hits.hits.0._score: 0.999} +- lte: {hits.hits.0._score: 1.001} + +- match: {hits.hits.1._id: "2"} +- gte: {hits.hits.1._score: 0.998} +- lte: {hits.hits.1._score: 1.0} + +- match: {hits.hits.2._id: "1"} +- gte: {hits.hits.2._score: 0.78} +- lte: {hits.hits.2._score: 0.791} diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_indexing.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_indexing.yml deleted file mode 100644 index 87d599e9cb078..0000000000000 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_indexing.yml +++ /dev/null @@ -1,29 +0,0 @@ -setup: - - skip: - version: " - 6.99.99" - reason: "sparse_vector field was introduced in 7.0.0" - - - do: - indices.create: - index: test-index - body: - settings: - number_of_replicas: 0 - mappings: - _doc: - properties: - my_sparse_vector: - type: sparse_vector - - ---- -"Indexing": - - do: - index: - index: test-index - type: _doc - id: 1 - body: - my_sparse_vector: { "50" : 1.8, "2" : -0.4, "10" : 1000.3, "4545" : -0.00004} - - - match: { result: created } diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml new file mode 100644 index 0000000000000..157a564fa0e94 --- /dev/null +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml @@ -0,0 +1,184 @@ +setup: + - skip: + version: " - 6.99.99" + reason: "sparse_vector field was introduced in 7.0.0" + + - do: + indices.create: + index: test-index + body: + settings: + number_of_replicas: 0 + mappings: + _doc: + properties: + my_sparse_vector: + type: sparse_vector + + +--- +"Vectors of different dimensions and data types": +# document vectors of different dimensions + - do: + index: + index: test-index + type: _doc + id: 1 + body: + my_sparse_vector: {"1": 10} + + - do: + index: + index: test-index + type: _doc + id: 2 + body: + my_sparse_vector: {"1": 10, "10" : 10.5} + + - do: + index: + index: test-index + type: _doc + id: 3 + body: + my_sparse_vector: {"1": 10, "10" : 10.5, "100": 100.5} + + - do: + indices.refresh: {} + +# query vector of type integer + - do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + params: + query_vector: {"1": 10} + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "3"} + +# query vector of type double + - do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + params: + query_vector: {"1": 10.0} + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "3"} + +--- +"Documents missing vector field": +- do: + index: + index: test-index + type: _doc + id: 1 + body: + my_sparse_vector: {"1": 10} + +- do: + index: + index: test-index + type: _doc + id: 2 + body: + some_other_field: "random_value" + +- do: + indices.refresh: {} + +- do: + catch: /A document doesn't have a value for a vector field!/ + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + params: + query_vector: {"1": 10.0} + +- do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "doc['my_sparse_vector'].value == null ? 0 : cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + params: + query_vector: {"1": 10.0} + +- match: {hits.total: 2} +- match: {hits.hits.0._id: "1"} +- match: {hits.hits.1._id: "2"} + +--- +"Dimensions can be sorted differently": +# All the documents' and query's vectors are the same, and should return cosineSimilarity equal to 1 +- do: + index: + index: test-index + type: _doc + id: 1 + body: + my_sparse_vector: {"2": 230.0, "11" : 300.33, "12": -34.8988, "30": 15.555, "100": -200.0} + +- do: + index: + index: test-index + type: _doc + id: 2 + body: + my_sparse_vector: {"100": -200.0, "12": -34.8988, "11" : 300.33, "113": 15.555, "2": 230.0} + +- do: + index: + index: test-index + type: _doc + id: 3 + body: + my_sparse_vector: {"100": -200.0, "30": 15.555, "12": -34.8988, "11" : 300.33, "2": 230.0} + +- do: + indices.refresh: {} + +- do: + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + params: + query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555} + +- match: {hits.total: 3} + +- gte: {hits.hits.0._score: 0.99} +- lte: {hits.hits.0._score: 1.001} +- gte: {hits.hits.1._score: 0.99} +- lte: {hits.hits.1._score: 1.001} +- gte: {hits.hits.2._score: 0.99} +- lte: {hits.hits.2._score: 1.001} diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java b/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java index e35fc4450509a..f958cad5d381d 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java @@ -62,6 +62,14 @@ protected void doAssert(Object actualValue, Object expectedValue) { assertThat("expected value of [" + getField() + "] is not comparable (got [" + expectedValue.getClass() + "])", expectedValue, instanceOf(Comparable.class)); try { + // make numbers comparable with each other: Float 1.0 can be compared to Double 1.0 + if (actualValue.getClass().equals(safeClass(expectedValue)) == false) { + if (actualValue instanceof Number && expectedValue instanceof Number) { + assertThat(errorMessage(), (Comparable) ((Number) actualValue).doubleValue(), + greaterThanOrEqualTo((Comparable) ((Number) expectedValue).doubleValue())); + return; + } + } assertThat(errorMessage(), (Comparable) actualValue, greaterThanOrEqualTo((Comparable) expectedValue)); } catch (ClassCastException e) { fail("cast error while checking (" + errorMessage() + "): " + e); diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java b/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java index 0b11b304ce60c..ad840dabe840c 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java @@ -62,6 +62,14 @@ protected void doAssert(Object actualValue, Object expectedValue) { assertThat("expected value of [" + getField() + "] is not comparable (got [" + expectedValue.getClass() + "])", expectedValue, instanceOf(Comparable.class)); try { + // make numbers comparable with each other: Float 1.0 can be compared to Double 1.0 + if (actualValue.getClass().equals(safeClass(expectedValue)) == false) { + if (actualValue instanceof Number && expectedValue instanceof Number) { + assertThat(errorMessage(), (Comparable) ((Number) actualValue).doubleValue(), + lessThanOrEqualTo((Comparable) ((Number) expectedValue).doubleValue())); + return; + } + } assertThat(errorMessage(), (Comparable) actualValue, lessThanOrEqualTo((Comparable) expectedValue)); } catch (ClassCastException e) { fail("cast error while checking (" + errorMessage() + "): " + e); From 7075c03443b2aae1f4a3809235e63110c5bc5fce Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Wed, 30 Jan 2019 17:08:08 -0500 Subject: [PATCH 2/8] Address Adrien's comments --- .../query-dsl/script-score-query.asciidoc | 35 ++------- .../index/query/ScoreScriptUtils.java | 73 ++++++++++--------- .../index/query/VectorScriptDocValues.java | 7 +- .../index/query/docvalues_whitelist.txt | 9 +-- .../test/dense-vector/10_basic.yml | 9 ++- .../test/dense-vector/20_special_cases.yml | 30 ++++---- .../test/sparse-vector/10_basic.yml | 9 ++- .../test/sparse-vector/20_special_cases.yml | 34 ++++----- .../section/GreaterThanEqualToAssertion.java | 8 -- .../section/LessThanOrEqualToAssertion.java | 8 -- 10 files changed, 95 insertions(+), 127 deletions(-) diff --git a/docs/reference/query-dsl/script-score-query.asciidoc b/docs/reference/query-dsl/script-score-query.asciidoc index d27a46f5b5051..f5865d6774a20 100644 --- a/docs/reference/query-dsl/script-score-query.asciidoc +++ b/docs/reference/query-dsl/script-score-query.asciidoc @@ -92,7 +92,7 @@ cosine similarity between a given query vector and document vectors. "match_all": {} }, "script": { - "source": "cosineSimilarity(params.queryVector, doc['my_dense_vector'].value)", + "source": "cosineSimilarity(params.queryVector, doc['my_dense_vector'])", "params": { "queryVector": [4, 3.4, -1.2] <1> } @@ -116,7 +116,7 @@ between a given query vector and document vectors. "match_all": {} }, "script": { - "source": "cosineSimilaritySparse(params.queryVector, doc['my_sparse_vector'].value)", + "source": "cosineSimilaritySparse(params.queryVector, doc['my_sparse_vector'])", "params": { "queryVector": {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} } @@ -139,7 +139,7 @@ dot product between a given query vector and document vectors. "match_all": {} }, "script": { - "source": "dotProduct(params.queryVector, doc['my_dense_vector'].value)", + "source": "dotProduct(params.queryVector, doc['my_dense_vector'])", "params": { "queryVector": [4, 3.4, -1.2] } @@ -162,7 +162,7 @@ between a given query vector and document vectors. "match_all": {} }, "script": { - "source": "dotProductSparse(params.queryVector, doc['my_sparse_vector'].value)", + "source": "dotProductSparse(params.queryVector, doc['my_sparse_vector'])", "params": { "queryVector": {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} } @@ -173,31 +173,8 @@ between a given query vector and document vectors. -------------------------------------------------- // NOTCONSOLE - -Be aware, if a document doesn't have a value for a vector field on which -a distance function is executed, an error will be thrown. You can -guard against missing values as shown below, and provide your own -custom score for this case. - -[source,js] --------------------------------------------------- -{ - "query": { - "script_score": { - "query": { - "match_all": {} - }, - "script": { - "source": "doc['my_dense_vector'].value == null ? 0 : cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)", - "params": { - "queryVector": [4, 3.4, -1.2] - } - } - } - } -} --------------------------------------------------- -// NOTCONSOLE +NOTE: If a document doesn't have a value for a vector field on which +a distance function is executed, 0 will be returned as a result. [[random-functions]] diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java index a889dfc5743f6..9c84b219e96f4 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java @@ -19,7 +19,6 @@ package org.elasticsearch.index.query; -import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.mapper.VectorEncoderDecoder; import java.util.Iterator; @@ -33,58 +32,60 @@ public class ScoreScriptUtils { //**************FUNCTIONS FOR DENSE VECTORS /** - * Calculate a dot product between two dense vectors + * Calculate a dot product between a query's dense vector and documents' dense vectors * * @param queryVector the query vector parsed as {@code List} from json - * @param docVectorBR BytesRef representing encoded document vector + * @param dvs VectorScriptDocValues representing encoded documents' vectors */ - public static float dotProduct(List queryVector, BytesRef docVectorBR){ - float[] docVector = VectorEncoderDecoder.decodeDenseVector(docVectorBR); + public static double dotProduct(List queryVector, VectorScriptDocValues dvs){ + if (dvs.getValue() == null) return 0; + float[] docVector = VectorEncoderDecoder.decodeDenseVector(dvs.getValue()); return intDotProduct(queryVector, docVector); } /** - * Calculate cosine similarity between two dense vectors + * Calculate cosine similarity between a query's dense vector and documents' dense vectors * * CosineSimilarity is implemented as a class to use * painless script caching to calculate queryVectorMagnitude * only once per script execution for all documents. - * A user will call `cosineSimilarity(params.queryVector, doc['my_vector'].getValue())` + * A user will call `cosineSimilarity(params.queryVector, doc['my_vector'])` */ public static final class CosineSimilarity { - final float queryVectorMagnitude; + final double queryVectorMagnitude; List queryVector; // calculate queryVectorMagnitude once per query execution public CosineSimilarity(List queryVector) { this.queryVector = queryVector; float floatValue; - float dotProduct = 0f; + double dotProduct = 0; for (Number value : queryVector) { floatValue = value.floatValue(); dotProduct += floatValue * floatValue; } - this.queryVectorMagnitude = (float) Math.sqrt(dotProduct); + this.queryVectorMagnitude = Math.sqrt(dotProduct); } - public float cosineSimilarity(BytesRef docVectorBR) { - float[] docVector = VectorEncoderDecoder.decodeDenseVector(docVectorBR); + public double cosineSimilarity(VectorScriptDocValues dvs) { + if (dvs.getValue() == null) return 0; + float[] docVector = VectorEncoderDecoder.decodeDenseVector(dvs.getValue()); // calculate docVector magnitude - float dotProduct = 0f; + double dotProduct = 0f; for (int dim = 0; dim < docVector.length; dim++) { dotProduct += docVector[dim] * docVector[dim]; } - final float docVectorMagnitude = (float) Math.sqrt(dotProduct); + final double docVectorMagnitude = Math.sqrt(dotProduct); - float docQueryDotProduct = intDotProduct(queryVector, docVector); + double docQueryDotProduct = intDotProduct(queryVector, docVector); return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude); } } - private static float intDotProduct(List v1, float[] v2){ + private static double intDotProduct(List v1, float[] v2){ int dims = Math.min(v1.size(), v2.length); - float v1v2DotProduct = 0f; + double v1v2DotProduct = 0; int dim = 0; Iterator v1Iter = v1.iterator(); while(dim < dims) { @@ -98,12 +99,12 @@ private static float intDotProduct(List v1, float[] v2){ //**************FUNCTIONS FOR SPARSE VECTORS /** - * Calculate a dot product between two sparse vectors + * Calculate a dot product between a query's sparse vector and documents' sparse vectors * * DotProductSparse is implemented as a class to use * painless script caching to prepare queryVector * only once per script execution for all documents. - * A user will call `dotProductSparse(params.queryVector, doc['my_vector'].getValue())` + * A user will call `dotProductSparse(params.queryVector, doc['my_vector'])` */ public static final class DotProductSparse { float[] queryValues; @@ -126,25 +127,26 @@ public DotProductSparse(Map queryVector) { sortSparseDimsValues(queryDims, queryValues, n); } - public float dotProductSparse(BytesRef docVectorBR) { - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(docVectorBR); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(docVectorBR); + public double dotProductSparse(VectorScriptDocValues dvs) { + if (dvs.getValue() == null) return 0; + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(dvs.getValue()); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(dvs.getValue()); return intDotProductSparse(queryValues, queryDims, docValues, docDims); } } /** - * Calculate cosine similarity between two sparse vectors + * Calculate cosine similarity between a query's sparse vector and documents' sparse vectors * * CosineSimilaritySparse is implemented as a class to use * painless script caching to prepare queryVector and calculate queryVectorMagnitude * only once per script execution for all documents. - * A user will call `cosineSimilaritySparse(params.queryVector, doc['my_vector'].getValue())` + * A user will call `cosineSimilaritySparse(params.queryVector, doc['my_vector'])` */ public static final class CosineSimilaritySparse { float[] queryValues; int[] queryDims; - float queryVectorMagnitude; + double queryVectorMagnitude; // prepare queryVector once per script execution public CosineSimilaritySparse(Map queryVector) { @@ -152,7 +154,7 @@ public CosineSimilaritySparse(Map queryVector) { int n = queryVector.size(); queryValues = new float[n]; queryDims = new int[n]; - float dotProduct = 0f; + double dotProduct = 0; int i = 0; for (Map.Entry dimValue : queryVector.entrySet()) { queryDims[i] = Integer.parseInt(dimValue.getKey()); @@ -160,29 +162,30 @@ public CosineSimilaritySparse(Map queryVector) { dotProduct += queryValues[i] * queryValues[i]; i++; } - this.queryVectorMagnitude = (float) Math.sqrt(dotProduct); + this.queryVectorMagnitude = Math.sqrt(dotProduct); // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions sortSparseDimsValues(queryDims, queryValues, n); } - public float cosineSimilaritySparse(BytesRef docVectorBR) { - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(docVectorBR); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(docVectorBR); + public double cosineSimilaritySparse(VectorScriptDocValues dvs) { + if (dvs.getValue() == null) return 0; + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(dvs.getValue()); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(dvs.getValue()); // calculate docVector magnitude - float dotProduct = 0f; + double dotProduct = 0; for (float value : docValues) { dotProduct += value * value; } - final float docVectorMagnitude = (float) Math.sqrt(dotProduct); + final double docVectorMagnitude = Math.sqrt(dotProduct); - float docQueryDotProduct = intDotProductSparse(queryValues, queryDims, docValues, docDims); + double docQueryDotProduct = intDotProductSparse(queryValues, queryDims, docValues, docDims); return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude); } } - private static float intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) { - float v1v2DotProduct = 0f; + private static double intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) { + double v1v2DotProduct = 0; int v1Index = 0; int v2Index = 0; // find common dimensions among vectors v1 and v2 and calculate dotProduct based on common dimensions diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java index cda8d6f349544..f83385f987611 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java @@ -46,18 +46,19 @@ public void setNextDocId(int docId) throws IOException { } } - public BytesRef getValue() { + // package private access only for {@link ScoreScriptUtils} + BytesRef getValue() { return value; } @Override public BytesRef get(int index) { - throw new UnsupportedOperationException("this operation is not supported on vector doc values"); + throw new UnsupportedOperationException("this operation is not supported on the doc values of the vector field"); } @Override public int size() { - throw new UnsupportedOperationException("this operation is not supported on vector doc values"); + throw new UnsupportedOperationException("this operation is not supported on the doc values of the vector field"); } } diff --git a/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt b/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt index ea5503b80d83a..62874ae46a7f8 100644 --- a/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt +++ b/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt @@ -19,12 +19,11 @@ class org.elasticsearch.index.query.VectorScriptDocValues { - BytesRef getValue() } static_import { - float cosineSimilarity(List, BytesRef) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilarity - float dotProduct(List, BytesRef) from_class org.elasticsearch.index.query.ScoreScriptUtils - float dotProductSparse(Map, BytesRef) bound_to org.elasticsearch.index.query.ScoreScriptUtils$DotProductSparse - float cosineSimilaritySparse(Map, BytesRef) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilaritySparse + double cosineSimilarity(List, VectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilarity + double dotProduct(List, VectorScriptDocValues) from_class org.elasticsearch.index.query.ScoreScriptUtils + double dotProductSparse(Map, VectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$DotProductSparse + double cosineSimilaritySparse(Map, VectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilaritySparse } \ No newline at end of file diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml index 0378458ae1675..fd71023b94728 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml @@ -1,5 +1,6 @@ setup: - skip: + features: headers version: " - 6.99.99" reason: "dense_vector field was introduced in 7.0.0" @@ -44,6 +45,8 @@ setup: --- "Dot Product": - do: + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -51,7 +54,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProduct(params.query_vector, doc['my_dense_vector'].value)" + source: "dotProduct(params.query_vector, doc['my_dense_vector'])" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] @@ -72,6 +75,8 @@ setup: --- "Cosine Similarity": - do: + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -79,7 +84,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)" + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml index 875a2f20b4558..7185510828eb8 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml @@ -1,5 +1,6 @@ setup: - skip: + features: headers version: " - 6.99.99" reason: "dense_vector field was introduced in 7.0.0" @@ -48,6 +49,8 @@ setup: # query vector of type integer - do: + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -55,7 +58,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)" + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" params: query_vector: [10] @@ -66,6 +69,8 @@ setup: # query vector of type double - do: + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -73,7 +78,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)" + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" params: query_vector: [10.0] @@ -83,7 +88,7 @@ setup: - match: {hits.hits.2._id: "3"} --- -"Documents missing vector field": +"Distance functions for documents missing vector field should return 0": - do: index: index: test-index @@ -104,7 +109,8 @@ setup: indices.refresh: {} - do: - catch: /A document doesn't have a value for a vector field!/ + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -112,22 +118,12 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)" - params: - query_vector: [10.0] - -- do: - search: - rest_total_hits_as_int: true - body: - query: - script_score: - query: {match_all: {} } - script: - source: "doc['my_dense_vector'].value == null ? 0 : cosineSimilarity(params.query_vector, doc['my_dense_vector'].value)" + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" params: query_vector: [10.0] - match: {hits.total: 2} - match: {hits.hits.0._id: "1"} - match: {hits.hits.1._id: "2"} +- match: {hits.hits.1._score: 0.0} + diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml index f927cf21830ce..41c480fe4275d 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml @@ -1,5 +1,6 @@ setup: - skip: + features: headers version: " - 6.99.99" reason: "sparse_vector field was introduced in 7.0.0" @@ -44,6 +45,8 @@ setup: --- "Dot Product": - do: + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -51,7 +54,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'].value)" + source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])" params: query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} @@ -72,6 +75,8 @@ setup: --- "Cosine Similarity": - do: + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -79,7 +84,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" params: query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml index 157a564fa0e94..612135a91eb78 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml @@ -1,5 +1,6 @@ setup: - skip: + features: headers version: " - 6.99.99" reason: "sparse_vector field was introduced in 7.0.0" @@ -48,6 +49,8 @@ setup: # query vector of type integer - do: + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -55,7 +58,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" params: query_vector: {"1": 10} @@ -66,6 +69,8 @@ setup: # query vector of type double - do: + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -73,7 +78,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" params: query_vector: {"1": 10.0} @@ -83,7 +88,7 @@ setup: - match: {hits.hits.2._id: "3"} --- -"Documents missing vector field": +"Distance functions for documents missing vector field should return 0": - do: index: index: test-index @@ -104,7 +109,8 @@ setup: indices.refresh: {} - do: - catch: /A document doesn't have a value for a vector field!/ + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -112,25 +118,15 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" - params: - query_vector: {"1": 10.0} - -- do: - search: - rest_total_hits_as_int: true - body: - query: - script_score: - query: {match_all: {} } - script: - source: "doc['my_sparse_vector'].value == null ? 0 : cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" params: query_vector: {"1": 10.0} - match: {hits.total: 2} - match: {hits.hits.0._id: "1"} - match: {hits.hits.1._id: "2"} +- match: {hits.hits.1._score: 0.0} + --- "Dimensions can be sorted differently": @@ -163,6 +159,8 @@ setup: indices.refresh: {} - do: + headers: + Content-Type: application/json search: rest_total_hits_as_int: true body: @@ -170,7 +168,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'].value)" + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" params: query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555} diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java b/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java index f958cad5d381d..e35fc4450509a 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java @@ -62,14 +62,6 @@ protected void doAssert(Object actualValue, Object expectedValue) { assertThat("expected value of [" + getField() + "] is not comparable (got [" + expectedValue.getClass() + "])", expectedValue, instanceOf(Comparable.class)); try { - // make numbers comparable with each other: Float 1.0 can be compared to Double 1.0 - if (actualValue.getClass().equals(safeClass(expectedValue)) == false) { - if (actualValue instanceof Number && expectedValue instanceof Number) { - assertThat(errorMessage(), (Comparable) ((Number) actualValue).doubleValue(), - greaterThanOrEqualTo((Comparable) ((Number) expectedValue).doubleValue())); - return; - } - } assertThat(errorMessage(), (Comparable) actualValue, greaterThanOrEqualTo((Comparable) expectedValue)); } catch (ClassCastException e) { fail("cast error while checking (" + errorMessage() + "): " + e); diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java b/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java index ad840dabe840c..0b11b304ce60c 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java @@ -62,14 +62,6 @@ protected void doAssert(Object actualValue, Object expectedValue) { assertThat("expected value of [" + getField() + "] is not comparable (got [" + expectedValue.getClass() + "])", expectedValue, instanceOf(Comparable.class)); try { - // make numbers comparable with each other: Float 1.0 can be compared to Double 1.0 - if (actualValue.getClass().equals(safeClass(expectedValue)) == false) { - if (actualValue instanceof Number && expectedValue instanceof Number) { - assertThat(errorMessage(), (Comparable) ((Number) actualValue).doubleValue(), - lessThanOrEqualTo((Comparable) ((Number) expectedValue).doubleValue())); - return; - } - } assertThat(errorMessage(), (Comparable) actualValue, lessThanOrEqualTo((Comparable) expectedValue)); } catch (ClassCastException e) { fail("cast error while checking (" + errorMessage() + "): " + e); From 3535e48ec82f04839b0eba83050e2f62821cf761 Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Wed, 30 Jan 2019 17:41:12 -0500 Subject: [PATCH 3/8] Removes typed calls from YAML REST tests --- .../rest-api-spec/test/dense-vector/10_basic.yml | 11 ++++------- .../test/dense-vector/20_special_cases.yml | 13 ++++--------- .../test/sparse-vector/10_basic.yml | 11 ++++------- .../test/sparse-vector/20_special_cases.yml | 16 ++++------------ 4 files changed, 16 insertions(+), 35 deletions(-) diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml index fd71023b94728..f2ad60589f044 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml @@ -6,19 +6,18 @@ setup: - do: indices.create: + include_type_name: false index: test-index body: settings: number_of_replicas: 0 mappings: - _doc: - properties: - my_dense_vector: - type: dense_vector + properties: + my_dense_vector: + type: dense_vector - do: index: index: test-index - type: _doc id: 1 body: my_dense_vector: [230.0, 300.33, -34.8988, 15.555, -200.0] @@ -26,7 +25,6 @@ setup: - do: index: index: test-index - type: _doc id: 2 body: my_dense_vector: [-0.5, 100.0, -13, 14.8, -156.0] @@ -34,7 +32,6 @@ setup: - do: index: index: test-index - type: _doc id: 3 body: my_dense_vector: [0.5, 111.3, -13.0, 14.8, -156.0] diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml index 7185510828eb8..ce9e0052d5022 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml @@ -6,15 +6,15 @@ setup: - do: indices.create: + include_type_name: false index: test-index body: settings: number_of_replicas: 0 mappings: - _doc: - properties: - my_dense_vector: - type: dense_vector + properties: + my_dense_vector: + type: dense_vector --- @@ -23,7 +23,6 @@ setup: - do: index: index: test-index - type: _doc id: 1 body: my_dense_vector: [10] @@ -31,7 +30,6 @@ setup: - do: index: index: test-index - type: _doc id: 2 body: my_dense_vector: [10, 10.5] @@ -39,7 +37,6 @@ setup: - do: index: index: test-index - type: _doc id: 3 body: my_dense_vector: [10, 10.5, 100.5] @@ -92,7 +89,6 @@ setup: - do: index: index: test-index - type: _doc id: 1 body: my_dense_vector: [10] @@ -100,7 +96,6 @@ setup: - do: index: index: test-index - type: _doc id: 2 body: some_other_field: "random_value" diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml index 41c480fe4275d..6af9d0c2ec28f 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml @@ -6,19 +6,18 @@ setup: - do: indices.create: + include_type_name: false index: test-index body: settings: number_of_replicas: 0 mappings: - _doc: - properties: - my_sparse_vector: - type: sparse_vector + properties: + my_sparse_vector: + type: sparse_vector - do: index: index: test-index - type: _doc id: 1 body: my_sparse_vector: {"2": 230.0, "10" : 300.33, "50": -34.8988, "113": 15.555, "4545": -200.0} @@ -26,7 +25,6 @@ setup: - do: index: index: test-index - type: _doc id: 2 body: my_sparse_vector: {"2": -0.5, "10" : 100.0, "50": -13, "113": 14.8, "4545": -156.0} @@ -34,7 +32,6 @@ setup: - do: index: index: test-index - type: _doc id: 3 body: my_sparse_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml index 612135a91eb78..5f6c5a3cffe9f 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml @@ -6,15 +6,15 @@ setup: - do: indices.create: + include_type_name: false index: test-index body: settings: number_of_replicas: 0 mappings: - _doc: - properties: - my_sparse_vector: - type: sparse_vector + properties: + my_sparse_vector: + type: sparse_vector --- @@ -23,7 +23,6 @@ setup: - do: index: index: test-index - type: _doc id: 1 body: my_sparse_vector: {"1": 10} @@ -31,7 +30,6 @@ setup: - do: index: index: test-index - type: _doc id: 2 body: my_sparse_vector: {"1": 10, "10" : 10.5} @@ -39,7 +37,6 @@ setup: - do: index: index: test-index - type: _doc id: 3 body: my_sparse_vector: {"1": 10, "10" : 10.5, "100": 100.5} @@ -92,7 +89,6 @@ setup: - do: index: index: test-index - type: _doc id: 1 body: my_sparse_vector: {"1": 10} @@ -100,7 +96,6 @@ setup: - do: index: index: test-index - type: _doc id: 2 body: some_other_field: "random_value" @@ -134,7 +129,6 @@ setup: - do: index: index: test-index - type: _doc id: 1 body: my_sparse_vector: {"2": 230.0, "11" : 300.33, "12": -34.8988, "30": 15.555, "100": -200.0} @@ -142,7 +136,6 @@ setup: - do: index: index: test-index - type: _doc id: 2 body: my_sparse_vector: {"100": -200.0, "12": -34.8988, "11" : 300.33, "113": 15.555, "2": 230.0} @@ -150,7 +143,6 @@ setup: - do: index: index: test-index - type: _doc id: 3 body: my_sparse_vector: {"100": -200.0, "30": 15.555, "12": -34.8988, "11" : 300.33, "2": 230.0} From ac0205cd759fccec2248206e60348bfd19384eac Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Tue, 5 Feb 2019 16:59:15 -0500 Subject: [PATCH 4/8] Address Adrien's comments 2 --- .../query-dsl/script-score-query.asciidoc | 21 +++--- .../index/mapper/VectorEncoderDecoder.java | 34 ++++++++- .../index/query/ScoreScriptUtils.java | 71 +++++++++++-------- .../index/query/VectorDVAtomicFieldData.java | 21 ++++-- .../index/query/VectorDVIndexFieldData.java | 8 +-- .../index/query/VectorScriptDocValues.java | 2 +- 6 files changed, 103 insertions(+), 54 deletions(-) diff --git a/docs/reference/query-dsl/script-score-query.asciidoc b/docs/reference/query-dsl/script-score-query.asciidoc index f5865d6774a20..ee68d3e40fe13 100644 --- a/docs/reference/query-dsl/script-score-query.asciidoc +++ b/docs/reference/query-dsl/script-score-query.asciidoc @@ -75,8 +75,8 @@ to be the most efficient by using the internal mechanisms. // NOTCONSOLE [[vector-functions]] -===== Distance functions for vector fields -These functions are used to calculate distances +===== Functions for vector fields +These functions are used for for <> and <> fields. @@ -94,7 +94,7 @@ cosine similarity between a given query vector and document vectors. "script": { "source": "cosineSimilarity(params.queryVector, doc['my_dense_vector'])", "params": { - "queryVector": [4, 3.4, -1.2] <1> + "queryVector": [4, 3.4, -0.2] <1> } } } @@ -102,7 +102,7 @@ cosine similarity between a given query vector and document vectors. } -------------------------------------------------- // NOTCONSOLE -<1> To take advantage of the script optimizations, supply a query vector in script parameters. +<1> To take advantage of the script optimizations, provide a query vector as a script parameter. Similarly, for sparse_vector fields, `cosineSimilaritySparse` calculates cosine similarity between a given query vector and document vectors. @@ -118,7 +118,7 @@ between a given query vector and document vectors. "script": { "source": "cosineSimilaritySparse(params.queryVector, doc['my_sparse_vector'])", "params": { - "queryVector": {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } } } @@ -141,7 +141,7 @@ dot product between a given query vector and document vectors. "script": { "source": "dotProduct(params.queryVector, doc['my_dense_vector'])", "params": { - "queryVector": [4, 3.4, -1.2] + "queryVector": [4, 3.4, -0.2] } } } @@ -164,7 +164,7 @@ between a given query vector and document vectors. "script": { "source": "dotProductSparse(params.queryVector, doc['my_sparse_vector'])", "params": { - "queryVector": {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } } } @@ -174,7 +174,12 @@ between a given query vector and document vectors. // NOTCONSOLE NOTE: If a document doesn't have a value for a vector field on which -a distance function is executed, 0 will be returned as a result. +a vector function is executed, 0 is returned as a result +for this document. + +NOTE: If a document's dense vector field has a number of dimensions +different from the query's vector, 0 is used for missing dimensions +in the calculations of vector functions. [[random-functions]] diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java index ece4574f0c4d2..a35ca16b01427 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java @@ -71,7 +71,7 @@ static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) { */ public static int[] decodeSparseVectorDims(BytesRef vectorBR) { if (vectorBR == null) { - throw new IllegalStateException("A document doesn't have a value for a vector field!"); + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES); int[] dims = new int[dimCount]; @@ -89,7 +89,7 @@ public static int[] decodeSparseVectorDims(BytesRef vectorBR) { */ public static float[] decodeSparseVector(BytesRef vectorBR) { if (vectorBR == null) { - throw new IllegalStateException("A document doesn't have a value for a vector field!"); + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES); int offset = vectorBR.offset + SHORT_BYTES * dimCount; //calculate the offset from where values are encoded @@ -134,13 +134,41 @@ public void swap(int i, int j) { }.sort(0, n); } + /** + * Sorts dimensions in the ascending order and + * sorts values in the same order as their corresponding dimensions + * + * @param dims - dimensions of the sparse query vector + * @param values - values for the sparse query vector + * @param n - number of dimensions + */ + public static void sortSparseDimsDoubleValues(int[] dims, double[] values, int n) { + new InPlaceMergeSorter() { + @Override + public int compare(int i, int j) { + return Integer.compare(dims[i], dims[j]); + } + + @Override + public void swap(int i, int j) { + int tempDim = dims[i]; + dims[i] = dims[j]; + dims[j] = tempDim; + + double tempValue = values[j]; + values[j] = values[i]; + values[i] = tempValue; + } + }.sort(0, n); + } + /** * Decodes a BytesRef into an array of floats * @param vectorBR - dense vector encoded in BytesRef */ public static float[] decodeDenseVector(BytesRef vectorBR) { if (vectorBR == null) { - throw new IllegalStateException("A document doesn't have a value for a vector field!"); + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } int dimCount = vectorBR.length / INT_BYTES; float[] vector = new float[dimCount]; diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java index 9c84b219e96f4..98d4b8edec6c7 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java @@ -19,13 +19,14 @@ package org.elasticsearch.index.query; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.mapper.VectorEncoderDecoder; import java.util.Iterator; import java.util.List; import java.util.Map; -import static org.elasticsearch.index.mapper.VectorEncoderDecoder.sortSparseDimsValues; +import static org.elasticsearch.index.mapper.VectorEncoderDecoder.sortSparseDimsDoubleValues; public class ScoreScriptUtils { @@ -38,8 +39,9 @@ public class ScoreScriptUtils { * @param dvs VectorScriptDocValues representing encoded documents' vectors */ public static double dotProduct(List queryVector, VectorScriptDocValues dvs){ - if (dvs.getValue() == null) return 0; - float[] docVector = VectorEncoderDecoder.decodeDenseVector(dvs.getValue()); + BytesRef value = dvs.getEncodedValue(); + if (value == null) return 0; + float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); return intDotProduct(queryVector, docVector); } @@ -58,23 +60,24 @@ public static final class CosineSimilarity { // calculate queryVectorMagnitude once per query execution public CosineSimilarity(List queryVector) { this.queryVector = queryVector; - float floatValue; + double doubleValue; double dotProduct = 0; for (Number value : queryVector) { - floatValue = value.floatValue(); - dotProduct += floatValue * floatValue; + doubleValue = value.doubleValue(); + dotProduct += doubleValue * doubleValue; } this.queryVectorMagnitude = Math.sqrt(dotProduct); } public double cosineSimilarity(VectorScriptDocValues dvs) { - if (dvs.getValue() == null) return 0; - float[] docVector = VectorEncoderDecoder.decodeDenseVector(dvs.getValue()); + BytesRef value = dvs.getEncodedValue(); + if (value == null) return 0; + float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); // calculate docVector magnitude double dotProduct = 0f; for (int dim = 0; dim < docVector.length; dim++) { - dotProduct += docVector[dim] * docVector[dim]; + dotProduct += (double) docVector[dim] * docVector[dim]; } final double docVectorMagnitude = Math.sqrt(dotProduct); @@ -89,7 +92,7 @@ private static double intDotProduct(List v1, float[] v2){ int dim = 0; Iterator v1Iter = v1.iterator(); while(dim < dims) { - v1v2DotProduct += v1Iter.next().floatValue() * v2[dim]; + v1v2DotProduct += v1Iter.next().doubleValue() * v2[dim]; dim++; } return v1v2DotProduct; @@ -107,7 +110,7 @@ private static double intDotProduct(List v1, float[] v2){ * A user will call `dotProductSparse(params.queryVector, doc['my_vector'])` */ public static final class DotProductSparse { - float[] queryValues; + double[] queryValues; int[] queryDims; // prepare queryVector once per script execution @@ -116,21 +119,26 @@ public DotProductSparse(Map queryVector) { //break vector into two arrays dims and values int n = queryVector.size(); queryDims = new int[n]; - queryValues = new float[n]; + queryValues = new double[n]; int i = 0; for (Map.Entry dimValue : queryVector.entrySet()) { - queryDims[i] = Integer.parseInt(dimValue.getKey()); - queryValues[i] = dimValue.getValue().floatValue(); + try { + queryDims[i] = Integer.parseInt(dimValue.getKey()); + } catch (final NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e); + } + queryValues[i] = dimValue.getValue().doubleValue(); i++; } // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions - sortSparseDimsValues(queryDims, queryValues, n); + sortSparseDimsDoubleValues(queryDims, queryValues, n); } public double dotProductSparse(VectorScriptDocValues dvs) { - if (dvs.getValue() == null) return 0; - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(dvs.getValue()); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(dvs.getValue()); + BytesRef value = dvs.getEncodedValue(); + if (value == null) return 0; + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(value); return intDotProductSparse(queryValues, queryDims, docValues, docDims); } } @@ -144,7 +152,7 @@ public double dotProductSparse(VectorScriptDocValues dvs) { * A user will call `cosineSimilaritySparse(params.queryVector, doc['my_vector'])` */ public static final class CosineSimilaritySparse { - float[] queryValues; + double[] queryValues; int[] queryDims; double queryVectorMagnitude; @@ -152,30 +160,35 @@ public static final class CosineSimilaritySparse { public CosineSimilaritySparse(Map queryVector) { //break vector into two arrays dims and values int n = queryVector.size(); - queryValues = new float[n]; + queryValues = new double[n]; queryDims = new int[n]; double dotProduct = 0; int i = 0; for (Map.Entry dimValue : queryVector.entrySet()) { - queryDims[i] = Integer.parseInt(dimValue.getKey()); - queryValues[i] = dimValue.getValue().floatValue(); + try { + queryDims[i] = Integer.parseInt(dimValue.getKey()); + } catch (final NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e); + } + queryValues[i] = dimValue.getValue().doubleValue(); dotProduct += queryValues[i] * queryValues[i]; i++; } this.queryVectorMagnitude = Math.sqrt(dotProduct); // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions - sortSparseDimsValues(queryDims, queryValues, n); + sortSparseDimsDoubleValues(queryDims, queryValues, n); } public double cosineSimilaritySparse(VectorScriptDocValues dvs) { - if (dvs.getValue() == null) return 0; - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(dvs.getValue()); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(dvs.getValue()); + BytesRef value = dvs.getEncodedValue(); + if (value == null) return 0; + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(value); // calculate docVector magnitude double dotProduct = 0; - for (float value : docValues) { - dotProduct += value * value; + for (float docValue : docValues) { + dotProduct += (double) docValue * docValue; } final double docVectorMagnitude = Math.sqrt(dotProduct); @@ -184,7 +197,7 @@ public double cosineSimilaritySparse(VectorScriptDocValues dvs) { } } - private static double intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) { + private static double intDotProductSparse(double[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) { double v1v2DotProduct = 0; int v1Index = 0; int v2Index = 0; diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java index df654ae7f5113..f1645e8c17843 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java @@ -20,22 +20,26 @@ package org.elasticsearch.index.query; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.LeafReader; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.fielddata.AtomicFieldData; import org.elasticsearch.index.fielddata.ScriptDocValues; import org.elasticsearch.index.fielddata.SortedBinaryDocValues; +import java.io.IOException; import java.util.Collection; import java.util.Collections; final class VectorDVAtomicFieldData implements AtomicFieldData { - private final BinaryDocValues values; + private final LeafReader reader; + private final String field; - VectorDVAtomicFieldData(BinaryDocValues values) { - super(); - this.values = values; + public VectorDVAtomicFieldData(LeafReader reader, String field) { + this.reader = reader; + this.field = field; } @Override @@ -50,12 +54,17 @@ public Collection getChildResources() { @Override public SortedBinaryDocValues getBytesValues() { - return null; + throw new UnsupportedOperationException("String representation of doc values for vector fields is not supported"); } @Override public ScriptDocValues getScriptValues() { - return new VectorScriptDocValues(values); + try { + final BinaryDocValues values = DocValues.getBinary(reader, field); + return new VectorScriptDocValues(values); + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values for vector field!", e); + } } @Override diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java index 723342f6a1737..da660c5cbc8a0 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java @@ -19,7 +19,6 @@ package org.elasticsearch.index.query; -import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.SortField; import org.elasticsearch.common.Nullable; @@ -34,7 +33,6 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.search.MultiValueMode; -import java.io.IOException; public class VectorDVIndexFieldData extends DocValuesIndexFieldData implements IndexFieldData { @@ -49,11 +47,7 @@ public SortField sortField(@Nullable Object missingValue, MultiValueMode sortMod @Override public VectorDVAtomicFieldData load(LeafReaderContext context) { - try { - return new VectorDVAtomicFieldData(DocValues.getBinary(context.reader(), fieldName)); - } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values", e); - } + return new VectorDVAtomicFieldData(context.reader(), fieldName); } @Override diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java index f83385f987611..547a81d2ce639 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java @@ -47,7 +47,7 @@ public void setNextDocId(int docId) throws IOException { } // package private access only for {@link ScoreScriptUtils} - BytesRef getValue() { + BytesRef getEncodedValue() { return value; } From e00f7d57c73ce6fe118ecff2533df1443673126c Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Tue, 5 Feb 2019 17:44:21 -0500 Subject: [PATCH 5/8] Correct check style --- .../org/elasticsearch/index/query/VectorDVAtomicFieldData.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java index f1645e8c17843..6fb8969d7aac3 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java @@ -37,7 +37,7 @@ final class VectorDVAtomicFieldData implements AtomicFieldData { private final LeafReader reader; private final String field; - public VectorDVAtomicFieldData(LeafReader reader, String field) { + VectorDVAtomicFieldData(LeafReader reader, String field) { this.reader = reader; this.field = field; } From f5a8ec40c4887b6ef75577ce96273037a0091141 Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Tue, 12 Feb 2019 10:47:53 -0500 Subject: [PATCH 6/8] Correct doc links to internal docs links --- docs/reference/mapping/types/dense-vector.asciidoc | 3 +-- docs/reference/mapping/types/sparse-vector.asciidoc | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/reference/mapping/types/dense-vector.asciidoc b/docs/reference/mapping/types/dense-vector.asciidoc index bd45874291c42..f656092e47211 100644 --- a/docs/reference/mapping/types/dense-vector.asciidoc +++ b/docs/reference/mapping/types/dense-vector.asciidoc @@ -9,8 +9,7 @@ not exceed 500. The number of dimensions can be different across documents. A `dense_vector` field is a single-valued field. -These vectors can be used for -{ref}/query-dsl-script-score-query.html#vector-functions[document scoring]. +These vectors can be used for <>. For example, a document score can represent a distance between a given query vector and the indexed document vector. diff --git a/docs/reference/mapping/types/sparse-vector.asciidoc b/docs/reference/mapping/types/sparse-vector.asciidoc index 739ce63d11026..8ed4920c4e652 100644 --- a/docs/reference/mapping/types/sparse-vector.asciidoc +++ b/docs/reference/mapping/types/sparse-vector.asciidoc @@ -9,8 +9,7 @@ not exceed 500. The number of dimensions can be different across documents. A `sparse_vector` field is a single-valued field. -These vectors can be used for -{ref}/query-dsl-script-score-query.html#vector-functions[document scoring]. +These vectors can be used for <>. For example, a document score can represent a distance between a given query vector and the indexed document vector. From f15c5107d8ad72cb7e7b2169db2c94a104bafa16 Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Wed, 13 Feb 2019 16:41:50 -0500 Subject: [PATCH 7/8] Separate docvalues for dense and sparse vectors --- .../index/mapper/DenseVectorFieldMapper.java | 2 +- .../index/mapper/SparseVectorFieldMapper.java | 2 +- .../index/query/ScoreScriptUtils.java | 8 ++--- .../index/query/VectorDVAtomicFieldData.java | 10 ++++-- .../index/query/VectorDVIndexFieldData.java | 13 ++++++-- .../index/query/VectorScriptDocValues.java | 18 +++++++++-- .../index/query/docvalues_whitelist.txt | 13 +++++--- .../test/dense-vector/10_basic.yml | 4 +-- .../test/dense-vector/20_special_cases.yml | 30 ++++++++++++++++-- .../test/sparse-vector/10_basic.yml | 4 +-- .../test/sparse-vector/20_special_cases.yml | 31 +++++++++++++++++-- 11 files changed, 108 insertions(+), 27 deletions(-) diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java index 57c996e977532..f4a61c3ebd358 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java @@ -120,7 +120,7 @@ public Query existsQuery(QueryShardContext context) { @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName) { - return new VectorDVIndexFieldData.Builder(); + return new VectorDVIndexFieldData.Builder(true); } @Override diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java index 36146060ef33e..adf46d6a60d25 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java @@ -120,7 +120,7 @@ public Query existsQuery(QueryShardContext context) { @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName) { - return new VectorDVIndexFieldData.Builder(); + return new VectorDVIndexFieldData.Builder(false); } @Override diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java index 98d4b8edec6c7..7f9d353fd700f 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java @@ -38,7 +38,7 @@ public class ScoreScriptUtils { * @param queryVector the query vector parsed as {@code List} from json * @param dvs VectorScriptDocValues representing encoded documents' vectors */ - public static double dotProduct(List queryVector, VectorScriptDocValues dvs){ + public static double dotProduct(List queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){ BytesRef value = dvs.getEncodedValue(); if (value == null) return 0; float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); @@ -69,7 +69,7 @@ public CosineSimilarity(List queryVector) { this.queryVectorMagnitude = Math.sqrt(dotProduct); } - public double cosineSimilarity(VectorScriptDocValues dvs) { + public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { BytesRef value = dvs.getEncodedValue(); if (value == null) return 0; float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); @@ -134,7 +134,7 @@ public DotProductSparse(Map queryVector) { sortSparseDimsDoubleValues(queryDims, queryValues, n); } - public double dotProductSparse(VectorScriptDocValues dvs) { + public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { BytesRef value = dvs.getEncodedValue(); if (value == null) return 0; int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value); @@ -179,7 +179,7 @@ public CosineSimilaritySparse(Map queryVector) { sortSparseDimsDoubleValues(queryDims, queryValues, n); } - public double cosineSimilaritySparse(VectorScriptDocValues dvs) { + public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { BytesRef value = dvs.getEncodedValue(); if (value == null) return 0; int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value); diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java index 6fb8969d7aac3..99e581ce4e514 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java @@ -36,10 +36,12 @@ final class VectorDVAtomicFieldData implements AtomicFieldData { private final LeafReader reader; private final String field; + private final boolean isDense; - VectorDVAtomicFieldData(LeafReader reader, String field) { + VectorDVAtomicFieldData(LeafReader reader, String field, boolean isDense) { this.reader = reader; this.field = field; + this.isDense = isDense; } @Override @@ -61,7 +63,11 @@ public SortedBinaryDocValues getBytesValues() { public ScriptDocValues getScriptValues() { try { final BinaryDocValues values = DocValues.getBinary(reader, field); - return new VectorScriptDocValues(values); + if (isDense) { + return new VectorScriptDocValues.DenseVectorScriptDocValues(values); + } else { + return new VectorScriptDocValues.SparseVectorScriptDocValues(values); + } } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for vector field!", e); } diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java index da660c5cbc8a0..9badf9f11b443 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java @@ -35,9 +35,11 @@ public class VectorDVIndexFieldData extends DocValuesIndexFieldData implements IndexFieldData { + private final boolean isDense; - public VectorDVIndexFieldData(Index index, String fieldName) { + public VectorDVIndexFieldData(Index index, String fieldName, boolean isDense) { super(index, fieldName); + this.isDense = isDense; } @Override @@ -47,7 +49,7 @@ public SortField sortField(@Nullable Object missingValue, MultiValueMode sortMod @Override public VectorDVAtomicFieldData load(LeafReaderContext context) { - return new VectorDVAtomicFieldData(context.reader(), fieldName); + return new VectorDVAtomicFieldData(context.reader(), fieldName, isDense); } @Override @@ -56,11 +58,16 @@ public VectorDVAtomicFieldData loadDirect(LeafReaderContext context) throws Exce } public static class Builder implements IndexFieldData.Builder { + private final boolean isDense; + public Builder(boolean isDense) { + this.isDense = isDense; + } + @Override public IndexFieldData build(IndexSettings indexSettings, MappedFieldType fieldType, IndexFieldDataCache cache, CircuitBreakerService breakerService, MapperService mapperService) { final String fieldName = fieldType.name(); - return new VectorDVIndexFieldData(indexSettings.getIndex(), fieldName); + return new VectorDVIndexFieldData(indexSettings.getIndex(), fieldName, isDense); } } diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java index 547a81d2ce639..59347af6f8f6c 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java @@ -28,7 +28,7 @@ /** * VectorScriptDocValues represents docValues for dense and sparse vector fields */ -public final class VectorScriptDocValues extends ScriptDocValues { +public class VectorScriptDocValues extends ScriptDocValues { private final BinaryDocValues in; private BytesRef value; @@ -53,12 +53,24 @@ BytesRef getEncodedValue() { @Override public BytesRef get(int index) { - throw new UnsupportedOperationException("this operation is not supported on the doc values of the vector field"); + throw new UnsupportedOperationException("vector fields may only be used via vector functions in scripts"); } @Override public int size() { - throw new UnsupportedOperationException("this operation is not supported on the doc values of the vector field"); + throw new UnsupportedOperationException("vector fields may only be used via vector functions in scripts"); + } + + public static final class DenseVectorScriptDocValues extends VectorScriptDocValues { + public DenseVectorScriptDocValues(BinaryDocValues in) { + super(in); + } + } + + public static final class SparseVectorScriptDocValues extends VectorScriptDocValues { + public SparseVectorScriptDocValues(BinaryDocValues in) { + super(in); + } } } diff --git a/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt b/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt index 62874ae46a7f8..3a8989e20b020 100644 --- a/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt +++ b/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt @@ -17,13 +17,16 @@ # under the License. # - class org.elasticsearch.index.query.VectorScriptDocValues { } +class org.elasticsearch.index.query.VectorScriptDocValues$DenseVectorScriptDocValues { +} +class org.elasticsearch.index.query.VectorScriptDocValues$SparseVectorScriptDocValues { +} static_import { - double cosineSimilarity(List, VectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilarity - double dotProduct(List, VectorScriptDocValues) from_class org.elasticsearch.index.query.ScoreScriptUtils - double dotProductSparse(Map, VectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$DotProductSparse - double cosineSimilaritySparse(Map, VectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilaritySparse + double cosineSimilarity(List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilarity + double dotProduct(List, VectorScriptDocValues.DenseVectorScriptDocValues) from_class org.elasticsearch.index.query.ScoreScriptUtils + double dotProductSparse(Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$DotProductSparse + double cosineSimilaritySparse(Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilaritySparse } \ No newline at end of file diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml index f2ad60589f044..e5db535b69b80 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml @@ -1,8 +1,8 @@ setup: - skip: features: headers - version: " - 6.99.99" - reason: "dense_vector field was introduced in 7.0.0" + version: " - 7.0.99" + reason: "dense_vector functions were introduced in 7.1.0" - do: indices.create: diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml index ce9e0052d5022..dc7e55ac0df4e 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml @@ -1,8 +1,8 @@ setup: - skip: features: headers - version: " - 6.99.99" - reason: "dense_vector field was introduced in 7.0.0" + version: " - 7.0.99" + reason: "dense_vector functions were introduced in 7.1.0" - do: indices.create: @@ -122,3 +122,29 @@ setup: - match: {hits.hits.1._id: "2"} - match: {hits.hits.1._score: 0.0} +--- +"Dense vectors should error with sparse vector functions": +- do: + index: + index: test-index + id: 1 + body: + my_dense_vector: [10, 2, 0.15] + +- do: + indices.refresh: {} + +- do: + catch: bad_request + headers: + Content-Type: application/json + search: + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProductSparse(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: {"2": 0.5, "10" : 111.3} +- match: { error.root_cause.0.type: "script_exception" } diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml index 6af9d0c2ec28f..142a80291aebf 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml @@ -1,8 +1,8 @@ setup: - skip: features: headers - version: " - 6.99.99" - reason: "sparse_vector field was introduced in 7.0.0" + version: " - 7.0.99" + reason: "sparse_vector functions were introduced in 7.1.0" - do: indices.create: diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml index 5f6c5a3cffe9f..5b6cc9753f212 100644 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml @@ -1,8 +1,8 @@ setup: - skip: features: headers - version: " - 6.99.99" - reason: "sparse_vector field was introduced in 7.0.0" + version: " - 7.0.99" + reason: "sparse_vector functions were introduced in 7.1.0" - do: indices.create: @@ -172,3 +172,30 @@ setup: - lte: {hits.hits.1._score: 1.001} - gte: {hits.hits.2._score: 0.99} - lte: {hits.hits.2._score: 1.001} + +--- +"Sparse vectors should error with dense vector functions": +- do: + index: + index: test-index + id: 1 + body: + my_sparse_vector: {"100": -200.0, "30": 15.555} + +- do: + indices.refresh: {} + +- do: + catch: bad_request + headers: + Content-Type: application/json + search: + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProduct(params.query_vector, doc['my_sparse_vector'])" + params: + query_vector: [0.5, 111] +- match: { error.root_cause.0.type: "script_exception" } From 16412f85bbb96a2abe832f350ae0fd88d7f42780 Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Fri, 15 Feb 2019 06:59:33 -0500 Subject: [PATCH 8/8] Add unit tests for script vector functions --- .../index/mapper/VectorEncoderDecoder.java | 2 +- .../index/query/ScoreScriptUtils.java | 12 +-- .../index/query/VectorScriptDocValues.java | 8 +- .../mapper/VectorEncoderDecoderTests.java | 2 +- .../index/query/ScoreScriptUtilsTests.java | 82 +++++++++++++++++++ 5 files changed, 95 insertions(+), 11 deletions(-) create mode 100644 modules/mapper-extras/src/test/java/org/elasticsearch/index/query/ScoreScriptUtilsTests.java diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java index a35ca16b01427..fbf9955f46621 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java @@ -38,7 +38,7 @@ private VectorEncoderDecoder() { } * and may be over-allocated * @return BytesRef */ - static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) { + public static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) { // 1. Sort dims and values sortSparseDimsValues(dims, values, dimCount); byte[] buf = new byte[dimCount * (INT_BYTES + SHORT_BYTES)]; diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java index 7f9d353fd700f..93e80d2a653fb 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java @@ -55,7 +55,7 @@ public static double dotProduct(List queryVector, VectorScriptDocValues. */ public static final class CosineSimilarity { final double queryVectorMagnitude; - List queryVector; + final List queryVector; // calculate queryVectorMagnitude once per query execution public CosineSimilarity(List queryVector) { @@ -110,8 +110,8 @@ private static double intDotProduct(List v1, float[] v2){ * A user will call `dotProductSparse(params.queryVector, doc['my_vector'])` */ public static final class DotProductSparse { - double[] queryValues; - int[] queryDims; + final double[] queryValues; + final int[] queryDims; // prepare queryVector once per script execution // queryVector represents a map of dimensions to values @@ -152,9 +152,9 @@ public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues * A user will call `cosineSimilaritySparse(params.queryVector, doc['my_vector'])` */ public static final class CosineSimilaritySparse { - double[] queryValues; - int[] queryDims; - double queryVectorMagnitude; + final double[] queryValues; + final int[] queryDims; + final double queryVectorMagnitude; // prepare queryVector once per script execution public CosineSimilaritySparse(Map queryVector) { diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java index 59347af6f8f6c..603881d390718 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java @@ -28,7 +28,7 @@ /** * VectorScriptDocValues represents docValues for dense and sparse vector fields */ -public class VectorScriptDocValues extends ScriptDocValues { +public abstract class VectorScriptDocValues extends ScriptDocValues { private final BinaryDocValues in; private BytesRef value; @@ -61,13 +61,15 @@ public int size() { throw new UnsupportedOperationException("vector fields may only be used via vector functions in scripts"); } - public static final class DenseVectorScriptDocValues extends VectorScriptDocValues { + // not final, as it needs to be extended by Mockito for tests + public static class DenseVectorScriptDocValues extends VectorScriptDocValues { public DenseVectorScriptDocValues(BinaryDocValues in) { super(in); } } - public static final class SparseVectorScriptDocValues extends VectorScriptDocValues { + // not final, as it needs to be extended by Mockito for tests + public static class SparseVectorScriptDocValues extends VectorScriptDocValues { public SparseVectorScriptDocValues(BinaryDocValues in) { super(in); } diff --git a/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/VectorEncoderDecoderTests.java b/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/VectorEncoderDecoderTests.java index 67ab78261375e..9b8a741192c4f 100644 --- a/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/VectorEncoderDecoderTests.java +++ b/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/VectorEncoderDecoderTests.java @@ -83,7 +83,7 @@ public void testSparseVectorEncodingDecoding() { } // imitates the code in DenseVectorFieldMapper::parse - private BytesRef mockEncodeDenseVector(float[] dims) { + public static BytesRef mockEncodeDenseVector(float[] dims) { final short INT_BYTES = VectorEncoderDecoder.INT_BYTES; byte[] buf = new byte[INT_BYTES * dims.length]; int offset = 0; diff --git a/modules/mapper-extras/src/test/java/org/elasticsearch/index/query/ScoreScriptUtilsTests.java b/modules/mapper-extras/src/test/java/org/elasticsearch/index/query/ScoreScriptUtilsTests.java new file mode 100644 index 0000000000000..bcdf0387c3f71 --- /dev/null +++ b/modules/mapper-extras/src/test/java/org/elasticsearch/index/query/ScoreScriptUtilsTests.java @@ -0,0 +1,82 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.mapper.VectorEncoderDecoder; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.index.query.ScoreScriptUtils.CosineSimilarity; +import org.elasticsearch.index.query.ScoreScriptUtils.DotProductSparse; +import org.elasticsearch.index.query.ScoreScriptUtils.CosineSimilaritySparse; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.index.mapper.VectorEncoderDecoderTests.mockEncodeDenseVector; +import static org.elasticsearch.index.query.ScoreScriptUtils.dotProduct; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + + +public class ScoreScriptUtilsTests extends ESTestCase { + public void testDenseVectorFunctions() { + float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; + BytesRef encodedDocVector = mockEncodeDenseVector(docVector); + VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class); + when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + List queryVector = Arrays.asList(0.5, 111.3, -13.0, 14.8, -156.0); + + // test dotProduct + double result = dotProduct(queryVector, dvs); + assertEquals("dotProduct result is not equal to the expected value!", 65425.62, result, 0.1); + + // test cosineSimilarity + CosineSimilarity cosineSimilarity = new CosineSimilarity(queryVector); + double result2 = cosineSimilarity.cosineSimilarity(dvs); + assertEquals("cosineSimilarity result is not equal to the expected value!", 0.78, result2, 0.1); + } + + public void testSparseVectorFunctions() { + int[] docVectorDims = {2, 10, 50, 113, 4545}; + float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; + BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(docVectorDims, docVectorValues, docVectorDims.length); + VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); + when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + Map queryVector = new HashMap() {{ + put("2", 0.5); + put("10", 111.3); + put("50", -13.0); + put("113", 14.8); + put("4545", -156.0); + }}; + + // test dotProduct + DotProductSparse docProductSparse = new DotProductSparse(queryVector); + double result = docProductSparse.dotProductSparse(dvs); + assertEquals("dotProductSparse result is not equal to the expected value!", 65425.62, result, 0.1); + + // test cosineSimilarity + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.78, result2, 0.1); + } +}