diff --git a/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc b/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc index d355a495e0625..7916127264204 100644 --- a/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc +++ b/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc @@ -10,8 +10,8 @@ The following specialized API is available in the Score context. ==== Static Methods The following methods are directly callable without a class/instance qualifier. Note parameters denoted by a (*) are treated as read-only values. -* double cosineSimilarity(List *, VectorScriptDocValues.DenseVectorScriptDocValues) -* double cosineSimilaritySparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues) +* double cosineSimilarity(List *, String *) +* double cosineSimilaritySparse(Map *, String *) * double decayDateExp(String *, String *, String *, double *, JodaCompatibleZonedDateTime) * double decayDateGauss(String *, String *, String *, double *, JodaCompatibleZonedDateTime) * double decayDateLinear(String *, String *, String *, double *, JodaCompatibleZonedDateTime) @@ -21,8 +21,8 @@ The following methods are directly callable without a class/instance qualifier. * double decayNumericExp(double *, double *, double *, double *, double) * double decayNumericGauss(double *, double *, double *, double *, double) * double decayNumericLinear(double *, double *, double *, double *, double) -* double dotProduct(List, VectorScriptDocValues.DenseVectorScriptDocValues) -* double dotProductSparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues) +* double dotProduct(List, String *) +* double dotProductSparse(Map *, String *) * double randomScore(int *) * double randomScore(int *, String *) * double saturation(double, double) diff --git a/docs/reference/vectors/vector-functions.asciidoc b/docs/reference/vectors/vector-functions.asciidoc index 4a23703b7ae6c..f03f59ee6a47a 100644 --- a/docs/reference/vectors/vector-functions.asciidoc +++ b/docs/reference/vectors/vector-functions.asciidoc @@ -68,7 +68,7 @@ GET my_index/_search } }, "script": { - "source": "cosineSimilarity(params.query_vector, doc['my_dense_vector']) + 1.0", <2> + "source": "cosineSimilarity(params.query_vector, 'my_dense_vector') + 1.0", <2> "params": { "query_vector": [4, 3.4, -0.2] <3> } @@ -105,7 +105,7 @@ GET my_index/_search }, "script": { "source": """ - double value = dotProduct(params.query_vector, doc['my_dense_vector']); + double value = dotProduct(params.query_vector, 'my_dense_vector'); return sigmoid(1, Math.E, -value); <1> """, "params": { @@ -139,7 +139,7 @@ GET my_index/_search } }, "script": { - "source": "1 / (1 + l1norm(params.queryVector, doc['my_dense_vector']))", <1> + "source": "1 / (1 + l1norm(params.queryVector, 'my_dense_vector'))", <1> "params": { "queryVector": [4, 3.4, -0.2] } @@ -178,7 +178,7 @@ GET my_index/_search } }, "script": { - "source": "1 / (1 + l2norm(params.queryVector, doc['my_dense_vector']))", + "source": "1 / (1 + l2norm(params.queryVector, 'my_dense_vector'))", "params": { "queryVector": [4, 3.4, -0.2] } @@ -262,7 +262,7 @@ GET my_sparse_index/_search } }, "script": { - "source": "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector']) + 1.0", + "source": "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector') + 1.0", "params": { "query_vector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } @@ -294,7 +294,7 @@ GET my_sparse_index/_search }, "script": { "source": """ - double value = dotProductSparse(params.query_vector, doc['my_sparse_vector']); + double value = dotProductSparse(params.query_vector, 'my_sparse_vector'); return sigmoid(1, Math.E, -value); """, "params": { @@ -327,7 +327,7 @@ GET my_sparse_index/_search } }, "script": { - "source": "1 / (1 + l1normSparse(params.queryVector, doc['my_sparse_vector']))", + "source": "1 / (1 + l1normSparse(params.queryVector, 'my_sparse_vector'))", "params": { "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } @@ -358,7 +358,7 @@ GET my_sparse_index/_search } }, "script": { - "source": "1 / (1 + l2normSparse(params.queryVector, doc['my_sparse_vector']))", + "source": "1 / (1 + l2normSparse(params.queryVector, 'my_sparse_vector'))", "params": { "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } diff --git a/server/src/main/java/org/elasticsearch/script/ScoreScript.java b/server/src/main/java/org/elasticsearch/script/ScoreScript.java index faad66fc1479b..7c2c09d17afe7 100644 --- a/server/src/main/java/org/elasticsearch/script/ScoreScript.java +++ b/server/src/main/java/org/elasticsearch/script/ScoreScript.java @@ -109,7 +109,7 @@ public Map getParams() { } /** The doc lookup for the Lucene segment this script was created for. */ - public final Map> getDoc() { + public Map> getDoc() { return leafLookup.doc(); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml index 903b9dc3de3b0..383bc96ccb0d0 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml @@ -52,7 +52,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProduct(params.query_vector, doc['my_dense_vector'])" + source: "dotProduct(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] @@ -82,7 +82,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml index dbb274d077645..882d11566dfaa 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml @@ -53,7 +53,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l1norm(params.query_vector, doc['my_dense_vector'])" + source: "l1norm(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] @@ -83,7 +83,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l2norm(params.query_vector, doc['my_dense_vector'])" + source: "l2norm(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml index 98a68cab9ca0a..cfec55095ad9d 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml @@ -62,7 +62,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10, 10, 10] @@ -81,7 +81,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10.0, 10.0, 10.0] @@ -111,7 +111,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [1, 2, 3, 4] - match: { error.root_cause.0.type: "script_exception" } @@ -125,7 +125,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProduct(params.query_vector, doc['my_dense_vector'])" + source: "dotProduct(params.query_vector, 'my_dense_vector')" params: query_vector: [1, 2, 3, 4] - match: { error.root_cause.0.type: "script_exception" } @@ -161,7 +161,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10.0, 10.0, 10.0] - match: { error.root_cause.0.type: "script_exception" } @@ -177,7 +177,7 @@ setup: script_score: query: {match_all: {} } script: - source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10.0, 10.0, 10.0] @@ -209,7 +209,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProductSparse(params.query_vector, doc['my_dense_vector'])" + source: "dotProductSparse(params.query_vector, 'my_dense_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "3": 44} - match: { error.root_cause.0.type: "script_exception" } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml index e184fd0ce9333..406f9b9db4a2a 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml @@ -55,7 +55,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])" + source: "dotProductSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} @@ -87,7 +87,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml index 3a6ed9fd561e9..8a1ec0d3cdde3 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml @@ -55,7 +55,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l1normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} @@ -88,7 +88,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l2normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml index 90a28eeb1eeae..c49413097807c 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml @@ -61,7 +61,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10} @@ -83,7 +83,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10.0} @@ -127,7 +127,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10.0} - match: { error.root_cause.0.type: "script_exception" } @@ -145,7 +145,7 @@ setup: script_score: query: {match_all: {} } script: - source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10.0} @@ -194,7 +194,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555} @@ -230,7 +230,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProduct(params.query_vector, doc['my_sparse_vector'])" + source: "dotProduct(params.query_vector, 'my_sparse_vector')" params: query_vector: [0.5, 111] - match: { error.root_cause.0.type: "script_exception" } @@ -273,7 +273,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])" + source: "dotProductSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5": 5} @@ -304,7 +304,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5" : 5} @@ -334,7 +334,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l1normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5": 5} @@ -361,7 +361,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l2normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5": 5} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java index 91f2fc343b113..4d7b640376a27 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java @@ -9,12 +9,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.script.ScoreScript; import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper; import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; import java.util.Map; @@ -28,12 +30,13 @@ public class ScoreScriptUtils { // Also, constructors for some functions accept queryVector to calculate and cache queryVectorMagnitude only once // per script execution for all documents. - public static class DenseVectorFunction { + static class DenseVectorFunction { final ScoreScript scoreScript; final float[] queryVector; + final VectorScriptDocValues.DenseVectorScriptDocValues docValues; - public DenseVectorFunction(ScoreScript scoreScript, List queryVector) { - this(scoreScript, queryVector, false); + DenseVectorFunction(ScoreScript scoreScript, List queryVector, String fieldName) { + this(scoreScript, queryVector, fieldName, false); } /** @@ -43,10 +46,12 @@ public DenseVectorFunction(ScoreScript scoreScript, List queryVector) { * @param queryVector The query vector. * @param normalizeQuery Whether the provided query should be normalized to unit length. */ - public DenseVectorFunction(ScoreScript scoreScript, - List queryVector, - boolean normalizeQuery) { + DenseVectorFunction(ScoreScript scoreScript, + List queryVector, + String fieldName, + boolean normalizeQuery) { this.scoreScript = scoreScript; + this.docValues = (VectorScriptDocValues.DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName); this.queryVector = new float[queryVector.size()]; double queryMagnitude = 0.0; @@ -64,7 +69,14 @@ public DenseVectorFunction(ScoreScript scoreScript, } } - public void validateDocVector(BytesRef vector) { + BytesRef getEncodedVector() { + try { + docValues.setNextDocId(scoreScript._getDocId()); + } catch (IOException e) { + throw ExceptionsHelper.convertToElastic(e); + } + + BytesRef vector = docValues.getEncodedValue(); if (vector == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } @@ -74,19 +86,19 @@ public void validateDocVector(BytesRef vector) { throw new IllegalArgumentException("The query vector has a different number of dimensions [" + queryVector.length + "] than the document vectors [" + vectorLength + "]."); } + return vector; } } // Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors public static final class L1Norm extends DenseVectorFunction { - public L1Norm(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector); + public L1Norm(ScoreScript scoreScript, List queryVector, String fieldName) { + super(scoreScript, queryVector, fieldName); } - public double l1norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double l1norm() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double l1norm = 0; @@ -100,13 +112,12 @@ public double l1norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { // Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors public static final class L2Norm extends DenseVectorFunction { - public L2Norm(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector); + public L2Norm(ScoreScript scoreScript, List queryVector, String fieldName) { + super(scoreScript, queryVector, fieldName); } - public double l2norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double l2norm() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double l2norm = 0; @@ -121,13 +132,12 @@ public double l2norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { // Calculate a dot product between a query's dense vector and documents' dense vectors public static final class DotProduct extends DenseVectorFunction { - public DotProduct(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector); + public DotProduct(ScoreScript scoreScript, List queryVector, String fieldName) { + super(scoreScript, queryVector, fieldName); } - public double dotProduct(VectorScriptDocValues.DenseVectorScriptDocValues dvs){ - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double dotProduct() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double dotProduct = 0; @@ -141,14 +151,12 @@ public double dotProduct(VectorScriptDocValues.DenseVectorScriptDocValues dvs){ // Calculate cosine similarity between a query's dense vector and documents' dense vectors public static final class CosineSimilarity extends DenseVectorFunction { - public CosineSimilarity(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector, true); + public CosineSimilarity(ScoreScript scoreScript, List queryVector, String fieldName) { + super(scoreScript, queryVector, fieldName, true); } - public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double cosineSimilarity() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double dotProduct = 0.0; @@ -174,18 +182,21 @@ public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues // Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings. // Also, constructors for some functions accept queryVector to calculate and cache queryVectorMagnitude only once // per script execution for all documents. - - public static class SparseVectorFunction { - static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(SparseVectorFunction.class)); + static class SparseVectorFunction { + private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(SparseVectorFunction.class)); final ScoreScript scoreScript; final float[] queryValues; final int[] queryDims; + final VectorScriptDocValues.SparseVectorScriptDocValues docValues; + // prepare queryVector once per script execution // queryVector represents a map of dimensions to values - public SparseVectorFunction(ScoreScript scoreScript, Map queryVector) { + SparseVectorFunction(ScoreScript scoreScript, Map queryVector, String fieldName) { this.scoreScript = scoreScript; + this.docValues = (VectorScriptDocValues.SparseVectorScriptDocValues) scoreScript.getDoc().get(fieldName); + //break vector into two arrays dims and values int n = queryVector.size(); queryValues = new float[n]; @@ -206,22 +217,29 @@ public SparseVectorFunction(ScoreScript scoreScript, Map queryVe deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE); } - public void validateDocVector(BytesRef vector) { + BytesRef getEncodedVector() { + try { + docValues.setNextDocId(scoreScript._getDocId()); + } catch (IOException e) { + throw ExceptionsHelper.convertToElastic(e); + } + + BytesRef vector = docValues.getEncodedValue(); if (vector == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } + return vector; } } // Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors public static final class L1NormSparse extends SparseVectorFunction { - public L1NormSparse(ScoreScript scoreScript,Map queryVector) { - super(scoreScript, queryVector); + public L1NormSparse(ScoreScript scoreScript, Map queryVector, String fieldName) { + super(scoreScript, queryVector, fieldName); } - public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double l1normSparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); @@ -255,13 +273,13 @@ public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs // Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors public static final class L2NormSparse extends SparseVectorFunction { - public L2NormSparse(ScoreScript scoreScript, Map queryVector) { - super(scoreScript, queryVector); + + public L2NormSparse(ScoreScript scoreScript, Map queryVector, String fieldName) { + super(scoreScript, queryVector, fieldName); } - public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double l2normSparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); @@ -298,13 +316,12 @@ public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs // Calculate a dot product between a query's sparse vector and documents' sparse vectors public static final class DotProductSparse extends SparseVectorFunction { - public DotProductSparse(ScoreScript scoreScript, Map queryVector) { - super(scoreScript, queryVector); + public DotProductSparse(ScoreScript scoreScript, Map queryVector, String fieldName) { + super(scoreScript, queryVector, fieldName); } - public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double dotProductSparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); @@ -316,8 +333,8 @@ public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues public static final class CosineSimilaritySparse extends SparseVectorFunction { final double queryVectorMagnitude; - public CosineSimilaritySparse(ScoreScript scoreScript, Map queryVector) { - super(scoreScript, queryVector); + public CosineSimilaritySparse(ScoreScript scoreScript, Map queryVector, String fieldName) { + super(scoreScript, queryVector, fieldName); double dotProduct = 0; for (int i = 0; i< queryDims.length; i++) { dotProduct += queryValues[i] * queryValues[i]; @@ -325,9 +342,8 @@ public CosineSimilaritySparse(ScoreScript scoreScript, Map query this.queryVectorMagnitude = Math.sqrt(dotProduct); } - public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double cosineSimilaritySparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); diff --git a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt index 42d6e6d0b0f7a..33abe41fcce0f 100644 --- a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt +++ b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt @@ -13,12 +13,12 @@ class org.elasticsearch.script.ScoreScript @no_import { } static_import { - double l1norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm - double l2norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm - double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity - double dotProduct(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct - double l1normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse - double l2normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse - double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse - double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse + double l1norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm + double l2norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm + double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity + double dotProduct(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct + double l1normSparse(org.elasticsearch.script.ScoreScript, Map, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse + double l2normSparse(org.elasticsearch.script.ScoreScript, Map, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse + double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse + double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse } \ No newline at end of file diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java index bff87a5ac472c..2e101abf9ae6b 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -39,55 +40,57 @@ public void testDenseVectorFunctions() { } private void testDenseVectorFunctions(Version indexVersion) { + String fieldName = "vector"; float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; + List queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f); + BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion); VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(indexVersion); - - List queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(fieldName, dvs)); // test dotProduct - DotProduct dotProduct = new DotProduct(scoreScript, queryVector); - double result = dotProduct.dotProduct(dvs); + DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName); + double result = dotProduct.dotProduct(); assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001); // test cosineSimilarity - CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector); - double result2 = cosineSimilarity.cosineSimilarity(dvs); + CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, fieldName); + double result2 = cosineSimilarity.cosineSimilarity(); assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001); // test l1Norm - L1Norm l1norm = new L1Norm(scoreScript, queryVector); - double result3 = l1norm.l1norm(dvs); + L1Norm l1norm = new L1Norm(scoreScript, queryVector, fieldName); + double result3 = l1norm.l1norm(); assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001); // test l2norm - L2Norm l2norm = new L2Norm(scoreScript, queryVector); - double result4 = l2norm.l2norm(dvs); + L2Norm l2norm = new L2Norm(scoreScript, queryVector, fieldName); + double result4 = l2norm.l2norm(); assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001); // test dotProduct fails when queryVector has wrong number of dims List invalidQueryVector = Arrays.asList(0.5, 111.3); - DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> dotProduct2.dotProduct(dvs)); + DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, fieldName); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); // test cosineSimilarity fails when queryVector has wrong number of dims - CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector); - e = expectThrows(IllegalArgumentException.class, () -> cosineSimilarity2.cosineSimilarity(dvs)); + CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, fieldName); + e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); // test l1norm fails when queryVector has wrong number of dims - L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector); - e = expectThrows(IllegalArgumentException.class, () -> l1norm2.l1norm(dvs)); + L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, fieldName); + e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); // test l2norm fails when queryVector has wrong number of dims - L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector); - e = expectThrows(IllegalArgumentException.class, () -> l2norm2.l2norm(dvs)); + L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, fieldName); + e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } @@ -97,14 +100,18 @@ public void testSparseVectorFunctions() { } private void testSparseVectorFunctions(Version indexVersion) { + String fieldName = "vector"; int[] docVectorDims = {2, 10, 50, 113, 4545}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( indexVersion, docVectorDims, docVectorValues, docVectorDims.length); + VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(indexVersion); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(fieldName, dvs)); Map queryVector = new HashMap() {{ put("2", 0.5); @@ -115,23 +122,23 @@ private void testSparseVectorFunctions(Version indexVersion) { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); - double result = docProductSparse.dotProductSparse(dvs); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, fieldName); + double result = docProductSparse.dotProductSparse(); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); - double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, fieldName); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); - double result3 = l1Norm.l1normSparse(dvs); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, fieldName); + double result3 = l1Norm.l1normSparse(); assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); - double result4 = l2Norm.l2normSparse(dvs); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, fieldName); + double result4 = l2Norm.l2normSparse(); assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001); assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); @@ -139,14 +146,19 @@ private void testSparseVectorFunctions(Version indexVersion) { public void testSparseVectorMissingDimensions1() { // Document vector's biggest dimension > query vector's biggest dimension + String fieldName = "vector"; int[] docVectorDims = {2, 10, 50, 113, 4545, 4546}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f}; BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); + VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(fieldName, dvs)); + Map queryVector = new HashMap() {{ put("2", 0.5); put("10", 111.3); @@ -157,23 +169,23 @@ public void testSparseVectorMissingDimensions1() { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); - double result = docProductSparse.dotProductSparse(dvs); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, fieldName); + double result = docProductSparse.dotProductSparse(); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); - double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, fieldName); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); - double result3 = l1Norm.l1normSparse(dvs); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, fieldName); + double result3 = l1Norm.l1normSparse(); assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); - double result4 = l2Norm.l2normSparse(dvs); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, fieldName); + double result4 = l2Norm.l2normSparse(); assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001); assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); @@ -181,14 +193,19 @@ public void testSparseVectorMissingDimensions1() { public void testSparseVectorMissingDimensions2() { // Document vector's biggest dimension < query vector's biggest dimension + String fieldName = "vector"; int[] docVectorDims = {2, 10, 50, 113, 4545, 4546}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f}; BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); + VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(fieldName, dvs)); + Map queryVector = new HashMap() {{ put("2", 0.5); put("10", 111.3); @@ -199,23 +216,23 @@ public void testSparseVectorMissingDimensions2() { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); - double result = docProductSparse.dotProductSparse(dvs); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, fieldName); + double result = docProductSparse.dotProductSparse(); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); - double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, fieldName); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); - double result3 = l1Norm.l1normSparse(dvs); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, fieldName); + double result3 = l1Norm.l1normSparse(); assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); - double result4 = l2Norm.l2normSparse(dvs); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, fieldName); + double result4 = l2Norm.l2normSparse(); assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001); assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);