From 974a449cfa35d152efe90352c9f81056f8558713 Mon Sep 17 00:00:00 2001 From: Winston Ewert Date: Mon, 24 Apr 2017 08:44:55 -0700 Subject: [PATCH 1/2] Fixes #24259 Corrects the ScriptedMetricAggregator so that the script can have access to scores during the map stage. --- .../metrics/scripted/ScriptedMetricAggregator.java | 2 +- .../scripted/ScriptedMetricAggregatorTests.java | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java index 3bfc057682a2a..cee7b3402f3e4 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java @@ -63,7 +63,7 @@ public boolean needsScores() { public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { final LeafSearchScript leafMapScript = mapScript.getLeafSearchScript(ctx); - return new LeafBucketCollectorBase(sub, mapScript) { + return new LeafBucketCollectorBase(sub, leafMapScript) { @Override public void collect(int doc, long bucket) throws IOException { assert bucket == 0 : bucket; diff --git a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java index d9eb76310d241..750a8bbceba95 100644 --- a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java +++ b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.script.MockScriptEngine; +import org.elasticsearch.script.ScoreAccessor; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptContextRegistry; import org.elasticsearch.script.ScriptEngineRegistry; @@ -67,17 +68,17 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { public static void initMockScripts() { SCRIPTS.put("initScript", params -> { Map agg = (Map) params.get("_agg"); - agg.put("collector", new ArrayList()); + agg.put("collector", new ArrayList()); return agg; }); SCRIPTS.put("mapScript", params -> { Map agg = (Map) params.get("_agg"); - ((List) agg.get("collector")).add(1); // just add 1 for each doc the script is run on + ((List) agg.get("collector")).add(((ScoreAccessor) params.get("_score")).doubleValue()); return agg; }); SCRIPTS.put("combineScript", params -> { Map agg = (Map) params.get("_agg"); - return ((List) agg.get("collector")).stream().mapToInt(Integer::intValue).sum(); + return ((List) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum(); }); } @@ -117,7 +118,7 @@ public void testScriptedMetricWithoutCombine() throws IOException { assertEquals(AGG_NAME, scriptedMetric.getName()); assertNotNull(scriptedMetric.aggregation()); Map agg = (Map) scriptedMetric.aggregation(); - assertEquals(numDocs, ((List) agg.get("collector")).size()); + assertEquals(numDocs, ((List) agg.get("collector")).size()); } } } @@ -139,7 +140,7 @@ public void testScriptedMetricWithCombine() throws IOException { ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder); assertEquals(AGG_NAME, scriptedMetric.getName()); assertNotNull(scriptedMetric.aggregation()); - assertEquals(numDocs, scriptedMetric.aggregation()); + assertEquals((double) numDocs, scriptedMetric.aggregation()); } } } From 982feae498e8c9366b2384814d0fc15a32c97ed9 Mon Sep 17 00:00:00 2001 From: Winston Ewert Date: Mon, 24 Apr 2017 10:43:14 -0700 Subject: [PATCH 2/2] Restored original tests. Added seperate test. As requested, I've restored the non-score dependant tests, and added the score dependent metric as a seperate test. --- .../ScriptedMetricAggregatorTests.java | 49 +++++++++++++++++-- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java index 750a8bbceba95..5dcae47da43b1 100644 --- a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java +++ b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java @@ -60,6 +60,11 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { private static final Script MAP_SCRIPT = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScript", Collections.emptyMap()); private static final Script COMBINE_SCRIPT = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScript", Collections.emptyMap()); + + private static final Script INIT_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptScore", Collections.emptyMap()); + private static final Script MAP_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptScore", Collections.emptyMap()); + private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore", + Collections.emptyMap()); private static final Map, Object>> SCRIPTS = new HashMap<>(); @@ -68,15 +73,30 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { public static void initMockScripts() { SCRIPTS.put("initScript", params -> { Map agg = (Map) params.get("_agg"); - agg.put("collector", new ArrayList()); + agg.put("collector", new ArrayList()); return agg; }); SCRIPTS.put("mapScript", params -> { Map agg = (Map) params.get("_agg"); - ((List) agg.get("collector")).add(((ScoreAccessor) params.get("_score")).doubleValue()); + ((List) agg.get("collector")).add(1); // just add 1 for each doc the script is run on return agg; }); SCRIPTS.put("combineScript", params -> { + Map agg = (Map) params.get("_agg"); + return ((List) agg.get("collector")).stream().mapToInt(Integer::intValue).sum(); + }); + + SCRIPTS.put("initScriptScore", params -> { + Map agg = (Map) params.get("_agg"); + agg.put("collector", new ArrayList()); + return agg; + }); + SCRIPTS.put("mapScriptScore", params -> { + Map agg = (Map) params.get("_agg"); + ((List) agg.get("collector")).add(((ScoreAccessor) params.get("_score")).doubleValue()); + return agg; + }); + SCRIPTS.put("combineScriptScore", params -> { Map agg = (Map) params.get("_agg"); return ((List) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum(); }); @@ -118,7 +138,7 @@ public void testScriptedMetricWithoutCombine() throws IOException { assertEquals(AGG_NAME, scriptedMetric.getName()); assertNotNull(scriptedMetric.aggregation()); Map agg = (Map) scriptedMetric.aggregation(); - assertEquals(numDocs, ((List) agg.get("collector")).size()); + assertEquals(numDocs, ((List) agg.get("collector")).size()); } } } @@ -140,6 +160,29 @@ public void testScriptedMetricWithCombine() throws IOException { ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder); assertEquals(AGG_NAME, scriptedMetric.getName()); assertNotNull(scriptedMetric.aggregation()); + assertEquals(numDocs, scriptedMetric.aggregation()); + } + } + } + + /** + * test that uses the score of the documents + */ + public void testScriptedMetricWithCombineAccessesScores() throws IOException { + try (Directory directory = newDirectory()) { + Integer numDocs = randomInt(100); + try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + for (int i = 0; i < numDocs; i++) { + indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i))); + } + } + try (IndexReader indexReader = DirectoryReader.open(directory)) { + ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME); + aggregationBuilder.initScript(INIT_SCRIPT_SCORE).mapScript(MAP_SCRIPT_SCORE).combineScript(COMBINE_SCRIPT_SCORE); + ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder); + assertEquals(AGG_NAME, scriptedMetric.getName()); + assertNotNull(scriptedMetric.aggregation()); + // all documents have score of 1.0 assertEquals((double) numDocs, scriptedMetric.aggregation()); } }