diff --git a/docs/reference/query-dsl/script-score-query.asciidoc b/docs/reference/query-dsl/script-score-query.asciidoc
new file mode 100644
index 0000000000000..5e8bf560140a8
--- /dev/null
+++ b/docs/reference/query-dsl/script-score-query.asciidoc
@@ -0,0 +1,303 @@
+[[query-dsl-script-score-query]]
+=== Script Score Query
+
+experimental[]
+
+The `script_score` query allows you to modify the score of documents that are
+retrieved by a query. This can be useful if, for example, a score
+function is computationally expensive and it is sufficient to compute
+the score on a filtered set of documents.
+
+To use `script_score`, you have to define a query and a script -
+a function used to compute a new score for each document returned
+by the query. For more information on scripting see
+<<modules-scripting, Scripting>>.
+
+
+Here is an example of using `script_score` to assign each matched document
+a score equal to the number of likes divided by 10:
+
+[source,js]
+--------------------------------------------------
+GET /_search
+{
+    "query" : {
+        "script_score" : {
+            "query" : {
+                "match": { "message": "elasticsearch" }
+            },
+            "script" : {
+                "source" : "doc['likes'].value / 10 "
+            }
+        }
+    }
+}
+--------------------------------------------------
+// CONSOLE
+// TEST[setup:twitter]
+
+==== Accessing the score of a document within a script
+
+Within a script, you can access the `_score` variable, which represents
+the current relevance score of a document.
+
+
+==== Predefined functions within a Painless script
+You can use any of the available Painless functions in your script.
+Besides these functions, there are a number of predefined functions
+that can help you with scoring. We suggest using them instead of
+rewriting equivalent functions of your own, as these predefined
+functions use internal mechanisms to be as efficient as possible.
+
+===== rational
+latexmath:[rational(value,k) = value/(k + value)]
+
+[source,js]
+--------------------------------------------------
+"script" : {
+    "source" : "rational(doc['likes'].value, 1)"
+}
+--------------------------------------------------
+// NOTCONSOLE
+
+===== sigmoid
+latexmath:[sigmoid(value, k, a) = value^a/(k^a + value^a)]
+
+[source,js]
+--------------------------------------------------
+"script" : {
+    "source" : "sigmoid(doc['likes'].value, 2, 1)"
+}
+--------------------------------------------------
+// NOTCONSOLE
+
+
+[[random-functions]]
+===== Random functions
+There are two predefined ways to produce random values:
+
+1. `randomNotReproducible()` uses the `java.util.Random` class
+to generate a random value of the type `long`.
+The generated values are not reproducible between request invocations.
+
+  [source,js]
+  --------------------------------------------------
+  "script" : {
+      "source" : "randomNotReproducible()"
+  }
+  --------------------------------------------------
+  // NOTCONSOLE
+
+
+2. `randomReproducible(String seedValue, int seed)` produces
+reproducible random values of type `long`. This function requires
+more computational time and memory than the non-reproducible version.
+
+A good candidate for the `seedValue` is a document field whose values
+are unique across documents and are already pre-calculated and preloaded
+in memory. For example, the document's `_seq_no` field is a good candidate,
+as documents on the same shard have unique values for the `_seq_no` field.
+
+  [source,js]
+  --------------------------------------------------
+  "script" : {
+      "source" : "randomReproducible(Long.toString(doc['_seq_no'].value), 100)"
+  }
+  --------------------------------------------------
+  // NOTCONSOLE
+
+
+A drawback of using `_seq_no` is that generated values change if
+documents are updated. Another drawback is that the values are not
+absolutely unique, as documents from different shards with the same
+sequence numbers generate the same random values.
+
+If you need random values to be distinct across different shards,
+you can use a field with unique values across shards,
+such as `_id`, but be aware of the memory usage, as all
+these unique values need to be loaded into memory.
+
+  [source,js]
+  --------------------------------------------------
+  "script" : {
+      "source" : "randomReproducible(doc['_id'].value, 100)"
+  }
+  --------------------------------------------------
+  // NOTCONSOLE
+
+
+[[decay-functions]]
+===== Decay functions for numeric fields
+You can read more about decay functions in the
+<<query-dsl-function-score-query, Function Score Query>> documentation.
+
+* `double decayNumericLinear(double origin, double scale, double offset, double decay, double docValue)`
+* `double decayNumericExp(double origin, double scale, double offset, double decay, double docValue)`
+* `double decayNumericGauss(double origin, double scale, double offset, double decay, double docValue)`
+
+[source,js]
+--------------------------------------------------
+"script" : {
+    "source" : "decayNumericLinear(params.origin, params.scale, params.offset, params.decay, doc['dval'].value)",
+    "params": { <1>
+        "origin": 20,
+        "scale": 10,
+        "decay" : 0.5,
+        "offset" : 0
+    }
+}
+--------------------------------------------------
+// NOTCONSOLE
+<1> Using `params` allows a script to be compiled only once, even if the parameter values change between requests.
+
+
+===== Decay functions for geo fields
+
+* `double decayGeoLinear(String originStr, String scaleStr, String offsetStr, double decay, GeoPoint docValue)`
+
+* `double decayGeoExp(String originStr, String scaleStr, String offsetStr, double decay, GeoPoint docValue)`
+
+* `double decayGeoGauss(String originStr, String scaleStr, String offsetStr, double decay, GeoPoint docValue)`
+
+[source,js]
+--------------------------------------------------
+"script" : {
+    "source" : "decayGeoExp(params.origin, params.scale, params.offset, params.decay, doc['location'].value)",
+    "params": {
+        "origin": "40, -70.12",
+        "scale": "200km",
+        "offset": "0km",
+        "decay" : 0.2
+    }
+}
+--------------------------------------------------
+// NOTCONSOLE
+
+
+===== Decay functions for date fields
+
+* `double decayDateLinear(String originStr, String scaleStr, String offsetStr, double decay, JodaCompatibleZonedDateTime docValueDate)`
+
+* `double decayDateExp(String originStr, String scaleStr, String offsetStr, double decay, JodaCompatibleZonedDateTime docValueDate)`
+
+* `double decayDateGauss(String originStr, String scaleStr, String offsetStr, double decay, JodaCompatibleZonedDateTime docValueDate)`
+
+[source,js]
+--------------------------------------------------
+"script" : {
+    "source" : "decayDateGauss(params.origin, params.scale, params.offset, params.decay, doc['date'].value)",
+    "params": {
+        "origin": "2008-01-01T01:00:00Z",
+        "scale": "1h",
+        "offset" : "0",
+        "decay" : 0.5
+    }
+}
+--------------------------------------------------
+// NOTCONSOLE
+
+NOTE: Decay functions on dates are limited to dates in the default format
+and default time zone. Calculations with `now` are also not supported.
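The snippets above show only the `script` object. For reference, here is a sketch of how one of these decay functions could be embedded in a complete `script_score` request; the index name `my_index` and the `date` field are illustrative assumptions rather than part of the examples above:

[source,js]
--------------------------------------------------
GET /my_index/_search
{
    "query" : {
        "script_score" : {
            "query" : {
                "match_all": {}
            },
            "script" : {
                "source" : "decayDateGauss(params.origin, params.scale, params.offset, params.decay, doc['date'].value)",
                "params": {
                    "origin": "2008-01-01T01:00:00Z",
                    "scale": "1h",
                    "offset" : "0",
                    "decay" : 0.5
                }
            }
        }
    }
}
--------------------------------------------------
// NOTCONSOLE

As in the earlier examples, passing the origin, scale, offset, and decay through `params` lets the script be compiled once and reused while the parameter values vary between requests.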
+
+
+==== Faster alternatives
+The Script Score Query calculates the score for every hit (matching document).
+There are faster alternative query types that can efficiently skip
+non-competitive hits:
+
+* If you want to boost documents on some static fields, use the
+  <<query-dsl-rank-feature-query, Rank Feature Query>>.
+
+
+==== Transition from Function Score Query
+We are deprecating the <<query-dsl-function-score-query, Function Score Query>>,
+and the Script Score Query will be a substitute for it.
+
+Here we describe how the Function Score Query's functions can be
+equivalently implemented in a Script Score Query:
+
+===== `script_score`
+The script you used in the `script_score` function of the Function Score
+Query can be copied into the Script Score Query without changes.
+
+===== `weight`
+The `weight` function can be implemented in the Script Score Query through
+the following script:
+
+[source,js]
+--------------------------------------------------
+"script" : {
+    "source" : "params.weight * _score",
+    "params": {
+        "weight": 2
+    }
+}
+--------------------------------------------------
+// NOTCONSOLE
+
+===== `random_score`
+
+Use the `randomReproducible` and `randomNotReproducible` functions
+as described in <<random-functions>>.
+
+
+===== `field_value_factor`
+The `field_value_factor` function can be easily implemented through a script:
+
+[source,js]
+--------------------------------------------------
+"script" : {
+    "source" : "Math.log10(doc['field'].value * params.factor)",
+    "params" : {
+        "factor" : 5
+    }
+}
+--------------------------------------------------
+// NOTCONSOLE
+
+
+To check whether a document is missing a value, you can use
+`doc['field'].size() == 0`. For example, this script will use
+the value `1` if a document doesn't have the field `field`:
+
+[source,js]
+--------------------------------------------------
+"script" : {
+    "source" : "Math.log10((doc['field'].size() == 0 ? 1 : doc['field'].value) * params.factor)",
+    "params" : {
+        "factor" : 5
+    }
+}
+--------------------------------------------------
+// NOTCONSOLE
+
+This table lists how `field_value_factor` modifiers can be implemented
+through a script:
+
+[cols="<,<",options="header",]
+|=======================================================================
+| Modifier | Implementation in Script Score
+
+| `none` | -
+| `log` | `Math.log10(doc['f'].value)`
+| `log1p` | `Math.log10(doc['f'].value + 1)`
+| `log2p` | `Math.log10(doc['f'].value + 2)`
+| `ln` | `Math.log(doc['f'].value)`
+| `ln1p` | `Math.log(doc['f'].value + 1)`
+| `ln2p` | `Math.log(doc['f'].value + 2)`
+| `square` | `Math.pow(doc['f'].value, 2)`
+| `sqrt` | `Math.sqrt(doc['f'].value)`
+| `reciprocal` | `1.0 / doc['f'].value`
+|=======================================================================
+
+
+===== `decay functions`
+The Script Score Query has equivalent <<decay-functions, decay functions>>
+that can be used in scripts.
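To illustrate the transition end to end, here is a sketch of a complete `script_score` request that replaces a `field_value_factor`-style function and also sets `min_score`, which the `ScriptScoreQueryBuilder` added in this change accepts (documents scoring below it are dropped via the same `MinScoreScorer` wrapping used by `ScriptScoreQuery`). The `likes` field follows the first example on this page and the factor and threshold values are arbitrary:

[source,js]
--------------------------------------------------
GET /_search
{
    "query" : {
        "script_score" : {
            "query" : {
                "match": { "message": "elasticsearch" }
            },
            "script" : {
                "source" : "Math.log10(doc['likes'].value * params.factor)",
                "params" : {
                    "factor" : 5
                }
            },
            "min_score" : 2
        }
    }
}
--------------------------------------------------
// NOTCONSOLE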
+ + + diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessPlugin.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessPlugin.java index 1773b3445c429..2a4a1e368ada3 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessPlugin.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessPlugin.java @@ -38,6 +38,7 @@ import org.elasticsearch.plugins.ScriptPlugin; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; import org.elasticsearch.search.aggregations.pipeline.MovingFunctionScript; @@ -72,6 +73,11 @@ public final class PainlessPlugin extends Plugin implements ScriptPlugin, Extens movFn.add(WhitelistLoader.loadFromResourceFiles(Whitelist.class, "org.elasticsearch.aggs.movfn.txt")); map.put(MovingFunctionScript.CONTEXT, movFn); + // Functions used for scoring docs + List scoreFn = new ArrayList<>(Whitelist.BASE_WHITELISTS); + scoreFn.add(WhitelistLoader.loadFromResourceFiles(Whitelist.class, "org.elasticsearch.score.txt")); + map.put(ScoreScript.CONTEXT, scoreFn); + whitelists = map; } diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/spi/org.elasticsearch.score.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/spi/org.elasticsearch.score.txt new file mode 100644 index 0000000000000..3aa32eff9c7a2 --- /dev/null +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/spi/org.elasticsearch.score.txt @@ -0,0 +1,38 @@ +# +# Licensed to Elasticsearch under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +# This file contains a whitelist for functions to be used in Score context + +static_import { + double rational(double, double) from_class org.elasticsearch.script.ScoreScriptUtils + double sigmoid(double, double, double) from_class org.elasticsearch.script.ScoreScriptUtils + double randomReproducible(String, int) from_class org.elasticsearch.script.ScoreScriptUtils + double randomNotReproducible() bound_to org.elasticsearch.script.ScoreScriptUtils$RandomNotReproducible + double decayGeoLinear(String, String, String, double, GeoPoint) bound_to org.elasticsearch.script.ScoreScriptUtils$DecayGeoLinear + double decayGeoExp(String, String, String, double, GeoPoint) bound_to org.elasticsearch.script.ScoreScriptUtils$DecayGeoExp + double decayGeoGauss(String, String, String, double, GeoPoint) bound_to org.elasticsearch.script.ScoreScriptUtils$DecayGeoGauss + double decayNumericLinear(double, double, double, double, double) bound_to org.elasticsearch.script.ScoreScriptUtils$DecayNumericLinear + double decayNumericExp(double, double, double, double, double) bound_to org.elasticsearch.script.ScoreScriptUtils$DecayNumericExp + double decayNumericGauss(double, double, double, double, double) bound_to org.elasticsearch.script.ScoreScriptUtils$DecayNumericGauss + double decayDateLinear(String, String, String, double, JodaCompatibleZonedDateTime) bound_to org.elasticsearch.script.ScoreScriptUtils$DecayDateLinear + double decayDateExp(String, String, String, double, JodaCompatibleZonedDateTime) bound_to org.elasticsearch.script.ScoreScriptUtils$DecayDateExp + double decayDateGauss(String, String, String, double, JodaCompatibleZonedDateTime) bound_to org.elasticsearch.script.ScoreScriptUtils$DecayDateGauss +} + + diff --git a/modules/lang-painless/src/test/resources/rest-api-spec/test/painless/80_script_score.yml b/modules/lang-painless/src/test/resources/rest-api-spec/test/painless/80_script_score.yml new file mode 100644 index 0000000000000..d6f52c517d62f --- /dev/null +++ b/modules/lang-painless/src/test/resources/rest-api-spec/test/painless/80_script_score.yml @@ -0,0 +1,484 @@ +# Integration tests for ScriptScoreQuery using Painless + +setup: +- skip: + version: " - 6.99.99" + reason: "script score query was introduced in 7.0.0" + +--- +"Random functions": + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 2 + mappings: + _doc: + properties: + f1: + type: keyword + - do: + index: + index: test + type: _doc + id: 1 + body: {"f1": "v1"} + - do: + index: + index: test + type: _doc + id: 2 + body: {"f1": "v2"} + - do: + index: + index: test + type: _doc + id: 3 + body: {"f1": "v3"} + + - do: + indices.refresh: {} + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "randomReproducible(Long.toString(doc['_seq_no'].value), 100)" + - match: { hits.total: 3 } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "randomNotReproducible()" + - match: { hits.total: 3 } + +--- +"Decay geo functions": + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 1 + mappings: + _doc: + properties: + text-location: + type: keyword + location: + type: geo_point + - do: + index: + index: test + type: _doc + id: 1 + body: { "text-location": "location1", "location" : {"lat" : 40.24, "lon" : -70.24} } + - do: + index: + index: test + type: _doc + id: 2 + body: { "text-location": "location2", "location" : {"lat" : 40.12, "lon" : -70.12} } + - do: + 
indices.refresh: {} + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayGeoLinear(params.origin, params.scale, params.offset, params.decay, doc['location'].value)" + params: + origin: "40, -70" + scale: "200km" + offset: "0km" + decay: 0.5 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayGeoExp(params.origin, params.scale, params.offset, params.decay, doc['location'].value)" + params: + origin: "40, -70" + scale: "200km" + offset: "0km" + decay: 0.5 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayGeoGauss(params.origin, params.scale, params.offset, params.decay, doc['location'].value)" + params: + origin: "40, -70" + scale: "200km" + offset: "0km" + decay: 0.5 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + +--- +"Decay date functions": + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 1 + mappings: + _doc: + properties: + date: + type: date + - do: + index: + index: test + type: _doc + id: 1 + body: { "date": "2018-01-01T02:00:00Z"} + - do: + index: + index: test + type: _doc + id: 2 + body: { "date": "2018-01-01T01:00:00Z" } + - do: + indices.refresh: {} + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayDateLinear(params.origin, params.scale, params.offset, params.decay, doc['date'].value)" + params: + origin: "2018-01-01T00:00:00Z" + scale: "1h" + offset: "0" + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayDateExp(params.origin, params.scale, params.offset, params.decay, doc['date'].value)" + params: + origin: "2018-01-01T00:00:00Z" + scale: "1h" + offset: "0" + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayDateGauss(params.origin, params.scale, params.offset, params.decay, doc['date'].value)" + params: + origin: "2018-01-01T00:00:00Z" + scale: "1h" + offset: "0" + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + +--- +"Decay numeric functions": + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 1 + mappings: + _doc: + properties: + ival: + type: integer + lval: + type: long + fval: + type: float + dval: + type: double + + - do: + index: + index: test + type: _doc + id: 1 + body: { "ival" : 40, "lval" : 40, "fval": 40.0, "dval": 40.0} + + # for this document, the smallest value in the array is chosen, which will be the closest to the origin + - do: + index: + index: test + type: _doc + id: 2 + body: { "ival" : [50, 40, 20], "lval" : [50, 40, 20], "fval" : [50.0, 40.0, 20.0], "dval" : [50.0, 40.0, 20.0] } + - do: + indices.refresh: {} + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: 
"decayNumericLinear(params.origin, params.scale, params.offset, params.decay, doc['ival'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericLinear(params.origin, params.scale, params.offset, params.decay, doc['lval'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericLinear(params.origin, params.scale, params.offset, params.decay, doc['fval'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericLinear(params.origin, params.scale, params.offset, params.decay, doc['dval'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericExp(params.origin, params.scale, params.offset, params.decay, doc['ival'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericExp(params.origin, params.scale, params.offset, params.decay, doc['lval'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericExp(params.origin, params.scale, params.offset, params.decay, doc['fval'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericExp(params.origin, params.scale, params.offset, params.decay, doc['dval'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericGauss(params.origin, params.scale, params.offset, params.decay, doc['ival'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericGauss(params.origin, params.scale, params.offset, params.decay, doc['lval'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { 
hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericGauss(params.origin, params.scale, params.offset, params.decay, doc['fval'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } + + - do: + search: + index: test + body: + query: + script_score: + query: {match_all: {} } + script: + source: "decayNumericGauss(params.origin, params.scale, params.offset, params.decay, doc['dval'].value)" + params: + origin: 20 + scale: 10 + offset: 0 + decay: 0.9 + - match: { hits.total: 2 } + - match: { hits.hits.0._id : "2" } + - match: { hits.hits.1._id : "1" } diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java index 6c55ef2e9343d..3893572aa4494 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java @@ -82,16 +82,14 @@ public Explanation explainScore(int docId, Explanation subQueryScore) throws IOE exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore); } else { double score = score(docId, subQueryScore.getValue().floatValue()); + // info about params already included in sScript String explanation = "script score function, computed with script:\"" + sScript + "\""; - if (sScript.getParams() != null) { - explanation += " and parameters: \n" + sScript.getParams().toString(); - } Explanation scoreExp = Explanation.match( - subQueryScore.getValue(), "_score: ", - subQueryScore); + subQueryScore.getValue(), "_score: ", + subQueryScore); return Explanation.match( - (float) score, explanation, - scoreExp); + (float) score, explanation, + scoreExp); } return exp; } diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java new file mode 100644 index 0000000000000..481a7f666e913 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java @@ -0,0 +1,167 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.common.lucene.search.function; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.elasticsearch.ElasticsearchException; + + +import java.io.IOException; +import java.util.Objects; +import java.util.Set; + +/** + * A query that uses a script to compute documents' scores. + */ +public class ScriptScoreQuery extends Query { + final Query subQuery; + final ScriptScoreFunction function; + private final Float minScore; + + public ScriptScoreQuery(Query subQuery, ScriptScoreFunction function, Float minScore) { + this.subQuery = subQuery; + this.function = function; + this.minScore = minScore; + } + + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newQ = subQuery.rewrite(reader); + ScriptScoreFunction newFunction = (ScriptScoreFunction) function.rewrite(reader); + if ((newQ != subQuery) || (newFunction != function)) { + return new ScriptScoreQuery(newQ, newFunction, minScore); + } + return super.rewrite(reader); + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + if (scoreMode == ScoreMode.COMPLETE_NO_SCORES && minScore == null) { + return subQuery.createWeight(searcher, scoreMode, boost); + } + ScoreMode subQueryScoreMode = function.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES; + Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, boost); + + return new Weight(this){ + @Override + public void extractTerms(Set terms) { + subQueryWeight.extractTerms(terms); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + Scorer subQueryScorer = subQueryWeight.scorer(context); + if (subQueryScorer == null) { + return null; + } + final LeafScoreFunction leafFunction = function.getLeafScoreFunction(context); + Scorer scriptScorer = new Scorer(this) { + @Override + public float score() throws IOException { + int docId = docID(); + float subQueryScore = subQueryScoreMode == ScoreMode.COMPLETE ? subQueryScorer.score() : 0f; + float score = (float) leafFunction.score(docId, subQueryScore); + if (score == Float.NEGATIVE_INFINITY || Float.isNaN(score)) { + throw new ElasticsearchException( + "script score query returned an invalid score: " + score + " for doc: " + docId); + } + return score; + } + @Override + public int docID() { + return subQueryScorer.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return subQueryScorer.iterator(); + } + + @Override + public float getMaxScore(int upTo) { + return Float.MAX_VALUE; // TODO: what would be a good upper bound? 
+ } + }; + + if (minScore != null) { + scriptScorer = new MinScoreScorer(this, scriptScorer, minScore); + } + return scriptScorer; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + Explanation queryExplanation = subQueryWeight.explain(context, doc); + if (queryExplanation.isMatch() == false) { + return queryExplanation; + } + Explanation explanation = function.getLeafScoreFunction(context).explainScore(doc, queryExplanation); + if (minScore != null && minScore > explanation.getValue().floatValue()) { + explanation = Explanation.noMatch("Score value is too low, expected at least " + minScore + + " but got " + explanation.getValue(), explanation); + } + return explanation; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + // If minScore is not null, then matches depend on statistics of the top-level reader. + return minScore == null; + } + }; + } + + + @Override + public String toString(String field) { + StringBuilder sb = new StringBuilder(); + sb.append("script score (").append(subQuery.toString(field)).append(", function: "); + sb.append("{" + (function == null ? "" : function.toString()) + "}"); + return sb.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (sameClassAs(o) == false) { + return false; + } + ScriptScoreQuery other = (ScriptScoreQuery) o; + return Objects.equals(this.subQuery, other.subQuery) && + Objects.equals(this.minScore, other.minScore) && + Objects.equals(this.function, other.function); + } + + @Override + public int hashCode() { + return Objects.hash(classHash(), subQuery, minScore, function); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/query/QueryBuilders.java b/server/src/main/java/org/elasticsearch/index/query/QueryBuilders.java index ee2172358af16..f5cf2d5da66be 100644 --- a/server/src/main/java/org/elasticsearch/index/query/QueryBuilders.java +++ b/server/src/main/java/org/elasticsearch/index/query/QueryBuilders.java @@ -27,6 +27,8 @@ import org.elasticsearch.index.query.MoreLikeThisQueryBuilder.Item; import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder; +import org.elasticsearch.index.query.functionscore.ScriptScoreFunctionBuilder; +import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.elasticsearch.indices.TermsLookup; import org.elasticsearch.script.Script; @@ -435,6 +437,17 @@ public static FunctionScoreQueryBuilder functionScoreQuery(QueryBuilder queryBui return (new FunctionScoreQueryBuilder(queryBuilder, function)); } + /** + * A query that allows to define a custom scoring function through script. + * + * @param queryBuilder The query to custom score + * @param function The script score function builder used to custom score + */ + public static ScriptScoreQueryBuilder scriptScoreQuery(QueryBuilder queryBuilder, ScriptScoreFunctionBuilder function) { + return new ScriptScoreQueryBuilder(queryBuilder, function); + } + + /** * A more like this query that finds documents that are "like" the provided texts or documents * which is checked against the fields the query is constructed with. 
diff --git a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScoreFunctionBuilder.java b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScoreFunctionBuilder.java index 6cfe7d177da79..bd6acd9f09ff8 100644 --- a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScoreFunctionBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScoreFunctionBuilder.java @@ -102,7 +102,7 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params) protected abstract void doXContent(XContentBuilder builder, Params params) throws IOException; @Override - public final String getWriteableName() { + public String getWriteableName() { return getName(); } @@ -116,8 +116,7 @@ public final boolean equals(Object obj) { } @SuppressWarnings("unchecked") FB other = (FB) obj; - return Objects.equals(weight, other.getWeight()) && - doEquals(other); + return Objects.equals(weight, other.getWeight()) && doEquals(other); } /** diff --git a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreQueryBuilder.java new file mode 100644 index 0000000000000..fb53f1c9560cc --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreQueryBuilder.java @@ -0,0 +1,187 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.index.query.functionscore; + +import org.apache.lucene.search.Query; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.script.Script; +import org.elasticsearch.common.lucene.search.function.ScriptScoreFunction; +import org.elasticsearch.common.lucene.search.function.ScriptScoreQuery; +import org.elasticsearch.index.query.InnerHitContextBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.QueryShardContext; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * A query that computes a document score based on the provided script + */ +public class ScriptScoreQueryBuilder extends AbstractQueryBuilder { + + public static final String NAME = "script_score"; + public static final ParseField QUERY_FIELD = new ParseField("query"); + public static final ParseField SCRIPT_FIELD = new ParseField("script"); + public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); + + private static ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, false, + args -> { + ScriptScoreFunctionBuilder ssFunctionBuilder = new ScriptScoreFunctionBuilder((Script) args[1]); + ScriptScoreQueryBuilder ssQueryBuilder = new ScriptScoreQueryBuilder((QueryBuilder) args[0], ssFunctionBuilder); + if (args[2] != null) ssQueryBuilder.setMinScore((Float) args[2]); + if (args[3] != null) ssQueryBuilder.boost((Float) args[3]); + if (args[4] != null) ssQueryBuilder.queryName((String) args[4]); + return ssQueryBuilder; + }); + + static { + PARSER.declareObject(constructorArg(), (p,c) -> parseInnerQueryBuilder(p), QUERY_FIELD); + PARSER.declareObject(constructorArg(), (p,c) -> Script.parse(p), SCRIPT_FIELD); + PARSER.declareFloat(optionalConstructorArg(), MIN_SCORE_FIELD); + PARSER.declareFloat(optionalConstructorArg(), AbstractQueryBuilder.BOOST_FIELD); + PARSER.declareString(optionalConstructorArg(), AbstractQueryBuilder.NAME_FIELD); + } + + public static ScriptScoreQueryBuilder fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final QueryBuilder query; + private Float minScore = null; + private final ScriptScoreFunctionBuilder scriptScoreFunctionBuilder; + + + /** + * Creates a script_score query that executes the provided script function on documents that match a query. + * + * @param query the query that defines which documents the script_score query will be executed on. 
+ * @param scriptScoreFunctionBuilder defines script function + */ + public ScriptScoreQueryBuilder(QueryBuilder query, ScriptScoreFunctionBuilder scriptScoreFunctionBuilder) { + // require the supply of the query, even the explicit supply of "match_all" query + if (query == null) { + throw new IllegalArgumentException("script_score: query must not be null"); + } + if (scriptScoreFunctionBuilder == null) { + throw new IllegalArgumentException("script_score: script must not be null"); + } + this.query = query; + this.scriptScoreFunctionBuilder = scriptScoreFunctionBuilder; + } + + /** + * Read from a stream. + */ + public ScriptScoreQueryBuilder(StreamInput in) throws IOException { + super(in); + query = in.readNamedWriteable(QueryBuilder.class); + scriptScoreFunctionBuilder = in.readNamedWriteable(ScriptScoreFunctionBuilder.class); + minScore = in.readOptionalFloat(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(query); + out.writeNamedWriteable(scriptScoreFunctionBuilder); + out.writeOptionalFloat(minScore); + } + + /** + * Returns the query builder that defines which documents the script_score query will be executed on. + */ + public QueryBuilder query() { + return this.query; + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.field(QUERY_FIELD.getPreferredName()); + query.toXContent(builder, params); + builder.field(SCRIPT_FIELD.getPreferredName(), scriptScoreFunctionBuilder.getScript()); + if (minScore != null) { + builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); + } + printBoostAndQueryName(builder); + builder.endObject(); + } + + public ScriptScoreQueryBuilder setMinScore(float minScore) { + this.minScore = minScore; + return this; + } + + public Float getMinScore() { + return this.minScore; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + protected boolean doEquals(ScriptScoreQueryBuilder other) { + return Objects.equals(this.query, other.query) && + Objects.equals(this.scriptScoreFunctionBuilder, other.scriptScoreFunctionBuilder) && + Objects.equals(this.minScore, other.minScore) ; + } + + @Override + protected int doHashCode() { + return Objects.hash(this.query, this.scriptScoreFunctionBuilder, this.minScore); + } + + @Override + protected Query doToQuery(QueryShardContext context) throws IOException { + ScriptScoreFunction function = (ScriptScoreFunction) scriptScoreFunctionBuilder.toFunction(context); + Query query = this.query.toQuery(context); + return new ScriptScoreQuery(query, function, minScore); + } + + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + QueryBuilder newQuery = this.query.rewrite(queryRewriteContext); + if (newQuery != query) { + ScriptScoreQueryBuilder newQueryBuilder = new ScriptScoreQueryBuilder(newQuery, scriptScoreFunctionBuilder); + newQueryBuilder.setMinScore(minScore); + return newQueryBuilder; + } + return this; + } + + @Override + protected void extractInnerHitBuilders(Map innerHits) { + InnerHitContextBuilder.extractInnerHits(query(), innerHits); + } + +} diff --git a/server/src/main/java/org/elasticsearch/script/ScoreScriptUtils.java b/server/src/main/java/org/elasticsearch/script/ScoreScriptUtils.java new file mode 100644 index 0000000000000..892d921091e37 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/ScoreScriptUtils.java @@ -0,0 +1,276 @@ +/* + * 
Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.script; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.StringHelper; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.geo.GeoDistance; +import org.elasticsearch.common.geo.GeoPoint; +import org.elasticsearch.common.geo.GeoUtils; +import org.elasticsearch.common.joda.JodaDateMathParser; +import org.elasticsearch.common.unit.DistanceUnit; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.index.mapper.DateFieldMapper; + +import java.time.ZoneId; +import java.util.Random; + +/** + * ScoringScriptImpl can be used as {@link ScoreScript} + * to run a previously compiled Painless script. + */ +public final class ScoreScriptUtils { + + /****** STATIC FUNCTIONS that can be used by users for score calculations **/ + + public static double rational(double value, double k) { + return value/ (k + value); + } + + /** + * Calculate a sigmoid of value + * with scaling parameters k and a + */ + public static double sigmoid(double value, double k, double a){ + return Math.pow(value,a) / (Math.pow(k,a) + Math.pow(value,a)); + } + + + // reproducible random + public static double randomReproducible(String seedValue, int seed) { + int hash = StringHelper.murmurhash3_x86_32(new BytesRef(seedValue), seed); + return (hash & 0x00FFFFFF) / (float)(1 << 24); // only use the lower 24 bits to construct a float from 0.0-1.0 + } + + // not reproducible random + public static final class RandomNotReproducible { + private final Random rnd; + + public RandomNotReproducible() { + this.rnd = Randomness.get(); + } + + public double randomNotReproducible() { + return rnd.nextDouble(); + } + } + + + // **** Decay functions on geo field + public static final class DecayGeoLinear { + // cached variables calculated once per script execution + double originLat; + double originLon; + double offset; + double scaling; + + public DecayGeoLinear(String originStr, String scaleStr, String offsetStr, double decay) { + GeoPoint origin = GeoUtils.parseGeoPoint(originStr, false); + double scale = DistanceUnit.DEFAULT.parse(scaleStr, DistanceUnit.DEFAULT); + this.originLat = origin.lat(); + this.originLon = origin.lon(); + this.offset = DistanceUnit.DEFAULT.parse(offsetStr, DistanceUnit.DEFAULT); + this.scaling = scale / (1.0 - decay); + } + + public double decayGeoLinear(GeoPoint docValue) { + double distance = GeoDistance.ARC.calculate(originLat, originLon, docValue.lat(), docValue.lon(), DistanceUnit.METERS); + distance = Math.max(0.0d, distance - offset); + return Math.max(0.0, (scaling - distance) / scaling); + } + } + + public static final class DecayGeoExp { + double originLat; + double originLon; + double offset; + double scaling; + + public 
DecayGeoExp(String originStr, String scaleStr, String offsetStr, double decay) { + GeoPoint origin = GeoUtils.parseGeoPoint(originStr, false); + double scale = DistanceUnit.DEFAULT.parse(scaleStr, DistanceUnit.DEFAULT); + this.originLat = origin.lat(); + this.originLon = origin.lon(); + this.offset = DistanceUnit.DEFAULT.parse(offsetStr, DistanceUnit.DEFAULT); + this.scaling = Math.log(decay) / scale; + } + + public double decayGeoExp(GeoPoint docValue) { + double distance = GeoDistance.ARC.calculate(originLat, originLon, docValue.lat(), docValue.lon(), DistanceUnit.METERS); + distance = Math.max(0.0d, distance - offset); + return Math.exp(scaling * distance); + } + } + + public static final class DecayGeoGauss { + double originLat; + double originLon; + double offset; + double scaling; + + public DecayGeoGauss(String originStr, String scaleStr, String offsetStr, double decay) { + GeoPoint origin = GeoUtils.parseGeoPoint(originStr, false); + double scale = DistanceUnit.DEFAULT.parse(scaleStr, DistanceUnit.DEFAULT); + this.originLat = origin.lat(); + this.originLon = origin.lon(); + this.offset = DistanceUnit.DEFAULT.parse(offsetStr, DistanceUnit.DEFAULT); + this.scaling = 0.5 * Math.pow(scale, 2.0) / Math.log(decay);; + } + + public double decayGeoGauss(GeoPoint docValue) { + double distance = GeoDistance.ARC.calculate(originLat, originLon, docValue.lat(), docValue.lon(), DistanceUnit.METERS); + distance = Math.max(0.0d, distance - offset); + return Math.exp(0.5 * Math.pow(distance, 2.0) / scaling); + } + } + + // **** Decay functions on numeric field + + public static final class DecayNumericLinear { + double origin; + double offset; + double scaling; + + public DecayNumericLinear(double origin, double scale, double offset, double decay) { + this.origin = origin; + this.offset = offset; + this.scaling = scale / (1.0 - decay); + } + + public double decayNumericLinear(double docValue) { + double distance = Math.max(0.0d, Math.abs(docValue - origin) - offset); + return Math.max(0.0, (scaling - distance) / scaling); + } + } + + public static final class DecayNumericExp { + double origin; + double offset; + double scaling; + + public DecayNumericExp(double origin, double scale, double offset, double decay) { + this.origin = origin; + this.offset = offset; + this.scaling = Math.log(decay) / scale; + } + + public double decayNumericExp(double docValue) { + double distance = Math.max(0.0d, Math.abs(docValue - origin) - offset); + return Math.exp(scaling * distance); + } + } + + public static final class DecayNumericGauss { + double origin; + double offset; + double scaling; + + public DecayNumericGauss(double origin, double scale, double offset, double decay) { + this.origin = origin; + this.offset = offset; + this.scaling = 0.5 * Math.pow(scale, 2.0) / Math.log(decay); + } + + public double decayNumericGauss(double docValue) { + double distance = Math.max(0.0d, Math.abs(docValue - origin) - offset); + return Math.exp(0.5 * Math.pow(distance, 2.0) / scaling); + } + } + + // **** Decay functions on date field + + /** + * Limitations: since script functions don't have access to DateFieldMapper, + * decay functions on dates are limited to dates in the default format and default time zone, + * Also, using calculations with now are not allowed. 
+ * + */ + private static final ZoneId defaultZoneId = ZoneId.of("UTC"); + private static final JodaDateMathParser dateParser = new JodaDateMathParser(DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER); + + public static final class DecayDateLinear { + long origin; + long offset; + double scaling; + + public DecayDateLinear(String originStr, String scaleStr, String offsetStr, double decay) { + this.origin = dateParser.parse(originStr, null, false, defaultZoneId); + long scale = TimeValue.parseTimeValue(scaleStr, TimeValue.timeValueHours(24), getClass().getSimpleName() + ".scale") + .getMillis(); + this.offset = TimeValue.parseTimeValue(offsetStr, TimeValue.timeValueHours(24), getClass().getSimpleName() + ".offset") + .getMillis(); + this.scaling = scale / (1.0 - decay); + } + + public double decayDateLinear(JodaCompatibleZonedDateTime docValueDate) { + long docValue = docValueDate.toInstant().toEpochMilli(); + // as java.lang.Math#abs(long) is a forbidden API, have to use this comparison instead + long diff = (docValue >= origin) ? (docValue - origin) : (origin - docValue); + long distance = Math.max(0, diff - offset); + return Math.max(0.0, (scaling - distance) / scaling); + } + } + + public static final class DecayDateExp { + long origin; + long offset; + double scaling; + + public DecayDateExp(String originStr, String scaleStr, String offsetStr, double decay) { + this.origin = dateParser.parse(originStr, null, false, defaultZoneId); + long scale = TimeValue.parseTimeValue(scaleStr, TimeValue.timeValueHours(24), getClass().getSimpleName() + ".scale") + .getMillis(); + this.offset = TimeValue.parseTimeValue(offsetStr, TimeValue.timeValueHours(24), getClass().getSimpleName() + ".offset") + .getMillis(); + this.scaling = Math.log(decay) / scale; + } + + public double decayDateExp(JodaCompatibleZonedDateTime docValueDate) { + long docValue = docValueDate.toInstant().toEpochMilli(); + long diff = (docValue >= origin) ? (docValue - origin) : (origin - docValue); + long distance = Math.max(0, diff - offset); + return Math.exp(scaling * distance); + } + } + + + public static final class DecayDateGauss { + long origin; + long offset; + double scaling; + + public DecayDateGauss(String originStr, String scaleStr, String offsetStr, double decay) { + this.origin = dateParser.parse(originStr, null, false, defaultZoneId); + long scale = TimeValue.parseTimeValue(scaleStr, TimeValue.timeValueHours(24), getClass().getSimpleName() + ".scale") + .getMillis(); + this.offset = TimeValue.parseTimeValue(offsetStr, TimeValue.timeValueHours(24), getClass().getSimpleName() + ".offset") + .getMillis(); + this.scaling = 0.5 * Math.pow(scale, 2.0) / Math.log(decay); + } + + public double decayDateGauss(JodaCompatibleZonedDateTime docValueDate) { + long docValue = docValueDate.toInstant().toEpochMilli(); + long diff = (docValue >= origin) ? 
(docValue - origin) : (origin - docValue); + long distance = Math.max(0, diff - offset); + return Math.exp(0.5 * Math.pow(distance, 2.0) / scaling); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 539f2de529f23..66e97230636e8 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -81,6 +81,7 @@ import org.elasticsearch.index.query.functionscore.RandomScoreFunctionBuilder; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder; import org.elasticsearch.index.query.functionscore.ScriptScoreFunctionBuilder; +import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.elasticsearch.index.query.functionscore.WeightBuilder; import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SearchPlugin.AggregationSpec; @@ -634,8 +635,12 @@ private Map setupHighlighters(Settings settings, List plugins) { + // ScriptScoreFunctionBuilder has it own named writable because of a new script_score query + namedWriteables.add(new NamedWriteableRegistry.Entry( + ScriptScoreFunctionBuilder.class, ScriptScoreFunctionBuilder.NAME, ScriptScoreFunctionBuilder::new)); registerScoreFunction(new ScoreFunctionSpec<>(ScriptScoreFunctionBuilder.NAME, ScriptScoreFunctionBuilder::new, ScriptScoreFunctionBuilder::fromXContent)); + registerScoreFunction( new ScoreFunctionSpec<>(GaussDecayFunctionBuilder.NAME, GaussDecayFunctionBuilder::new, GaussDecayFunctionBuilder.PARSER)); registerScoreFunction(new ScoreFunctionSpec<>(LinearDecayFunctionBuilder.NAME, LinearDecayFunctionBuilder::new, @@ -786,6 +791,7 @@ private void registerQueryParsers(List plugins) { new QuerySpec<>(SpanMultiTermQueryBuilder.NAME, SpanMultiTermQueryBuilder::new, SpanMultiTermQueryBuilder::fromXContent)); registerQuery(new QuerySpec<>(FunctionScoreQueryBuilder.NAME, FunctionScoreQueryBuilder::new, FunctionScoreQueryBuilder::fromXContent)); + registerQuery(new QuerySpec<>(ScriptScoreQueryBuilder.NAME, ScriptScoreQueryBuilder::new, ScriptScoreQueryBuilder::fromXContent)); registerQuery( new QuerySpec<>(SimpleQueryStringBuilder.NAME, SimpleQueryStringBuilder::new, SimpleQueryStringBuilder::fromXContent)); registerQuery(new QuerySpec<>(TypeQueryBuilder.NAME, TypeQueryBuilder::new, TypeQueryBuilder::fromXContent)); diff --git a/server/src/test/java/org/elasticsearch/index/query/ScriptScoreQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/ScriptScoreQueryBuilderTests.java new file mode 100644 index 0000000000000..ef173883d0ac0 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/query/ScriptScoreQueryBuilderTests.java @@ -0,0 +1,95 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.search.Query; +import org.elasticsearch.common.lucene.search.function.ScriptScoreQuery; +import org.elasticsearch.index.query.functionscore.ScriptScoreFunctionBuilder; +import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder; +import org.elasticsearch.script.MockScriptEngine; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.test.AbstractQueryTestCase; + +import java.io.IOException; +import java.util.Collections; + +import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery; +import static org.hamcrest.CoreMatchers.instanceOf; + +public class ScriptScoreQueryBuilderTests extends AbstractQueryTestCase { + + @Override + protected ScriptScoreQueryBuilder doCreateTestQueryBuilder() { + String scriptStr = "1"; + Script script = new Script(ScriptType.INLINE, MockScriptEngine.NAME, scriptStr, Collections.emptyMap()); + ScriptScoreQueryBuilder queryBuilder = new ScriptScoreQueryBuilder( + RandomQueryBuilder.createQuery(random()), + new ScriptScoreFunctionBuilder(script) + ); + if (randomBoolean()) { + queryBuilder.setMinScore(randomFloat()); + } + return queryBuilder; + } + + @Override + protected void doAssertLuceneQuery(ScriptScoreQueryBuilder queryBuilder, Query query, SearchContext context) throws IOException { + assertThat(query, instanceOf(ScriptScoreQuery.class)); + } + + public void testFromJson() throws IOException { + String json = + "{\n" + + " \"script_score\" : {\n" + + " \"query\" : { \"match_all\" : {} },\n" + + " \"script\" : {\n" + + " \"source\" : \"doc['field'].value\" \n" + + " },\n" + + " \"min_score\" : 2.0\n" + + " }\n" + + "}"; + + ScriptScoreQueryBuilder parsed = (ScriptScoreQueryBuilder) parseQuery(json); + assertEquals(json, 2, parsed.getMinScore(), 0.0001); + } + + public void testIllegalArguments() { + String scriptStr = "1"; + Script script = new Script(ScriptType.INLINE, MockScriptEngine.NAME, scriptStr, Collections.emptyMap()); + ScriptScoreFunctionBuilder functionBuilder = new ScriptScoreFunctionBuilder(script); + + expectThrows( + IllegalArgumentException.class, + () -> new ScriptScoreQueryBuilder(matchAllQuery(), null) + ); + + expectThrows( + IllegalArgumentException.class, + () -> new ScriptScoreQueryBuilder(null, functionBuilder) + ); + } + + @Override + protected boolean isCachable(ScriptScoreQueryBuilder queryBuilder) { + return false; + } +} diff --git a/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java b/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java index ca20e6ec4788d..cf5b3fc0fc13b 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java @@ -341,6 +341,7 @@ public List> getRescorers() { "range", "regexp", "script", + "script_score", "simple_query_string", "span_containing", "span_first", diff --git a/server/src/test/java/org/elasticsearch/search/query/ScriptScoreQueryIT.java b/server/src/test/java/org/elasticsearch/search/query/ScriptScoreQueryIT.java new file mode 100644 index 0000000000000..0e1d16e100afe --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/query/ScriptScoreQueryIT.java @@ -0,0 +1,105 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.query; + +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.index.fielddata.ScriptDocValues; +import org.elasticsearch.index.query.functionscore.ScriptScoreFunctionBuilder; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.script.MockScriptPlugin; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.test.ESIntegTestCase; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import static org.elasticsearch.index.query.QueryBuilders.matchQuery; +import static org.elasticsearch.index.query.QueryBuilders.scriptScoreQuery; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFirstHit; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertOrderedSearchHits; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSecondHit; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertThirdHit; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasScore; + +public class ScriptScoreQueryIT extends ESIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return Collections.singleton(CustomScriptPlugin.class); + } + + public static class CustomScriptPlugin extends MockScriptPlugin { + @Override + protected Map, Object>> pluginScripts() { + Map, Object>> scripts = new HashMap<>(); + scripts.put("doc['field2'].value * param1", vars -> { + Map doc = (Map) vars.get("doc"); + ScriptDocValues.Doubles field2Values = (ScriptDocValues.Doubles) doc.get("field2"); + Double param1 = (Double) vars.get("param1"); + return field2Values.getValue() * param1; + }); + return scripts; + } + } + + // test that script_score works as expected: + // 1) only matched docs retrieved + // 2) score is calculated based on a script with params + // 3) min score applied + public void testScriptScore() { + assertAcked( + prepareCreate("test-index").addMapping("_doc", "field1", "type=text", "field2", "type=double") + ); + int docCount = 10; + for (int i = 1; i <= docCount; i++) { + client().prepareIndex("test-index", "_doc", "" + i) + .setSource("field1", "text" + (i % 2), "field2", i ) + .get(); + } + refresh(); + + Map params = new HashMap<>(); + params.put("param1", 0.1); + Script script = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "doc['field2'].value * param1", params); + SearchResponse resp = client() + .prepareSearch("test-index") + 
.setQuery(scriptScoreQuery(matchQuery("field1", "text0"), new ScriptScoreFunctionBuilder(script))) + .get(); + assertNoFailures(resp); + assertOrderedSearchHits(resp, "10", "8", "6", "4", "2"); + assertFirstHit(resp, hasScore(1.0f)); + assertSecondHit(resp, hasScore(0.8f)); + assertThirdHit(resp, hasScore(0.6f)); + + // applying min score + resp = client() + .prepareSearch("test-index") + .setQuery(scriptScoreQuery(matchQuery("field1", "text0"), new ScriptScoreFunctionBuilder(script)).setMinScore(0.6f)) + .get(); + assertNoFailures(resp); + assertOrderedSearchHits(resp, "10", "8", "6"); + } +}