From 2710d5a5220f064082de6e9c553b805ea243adc2 Mon Sep 17 00:00:00 2001 From: Jonathan Little Date: Mon, 2 Apr 2018 13:49:09 -0700 Subject: [PATCH 1/9] Migrate scripted metric aggregation scripts to ScriptContext design #29328 --- .../painless/MetricAggScriptsTests.java | 113 ++++++++++++++ .../script/MetricAggScripts.java | 140 ++++++++++++++++++ .../elasticsearch/script/ScriptModule.java | 6 +- .../scripted/InternalScriptedMetric.java | 17 ++- .../ScriptedMetricAggregationBuilder.java | 24 +-- .../scripted/ScriptedMetricAggregator.java | 29 +++- .../ScriptedMetricAggregatorFactory.java | 28 ++-- .../metrics/ScriptedMetricIT.java | 79 +++++++++- .../ScriptedMetricAggregatorTests.java | 2 +- .../script/MockScriptEngine.java | 135 ++++++++++++++++- 10 files changed, 527 insertions(+), 46 deletions(-) create mode 100644 modules/lang-painless/src/test/java/org/elasticsearch/painless/MetricAggScriptsTests.java create mode 100644 server/src/main/java/org/elasticsearch/script/MetricAggScripts.java diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/MetricAggScriptsTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/MetricAggScriptsTests.java new file mode 100644 index 0000000000000..02d8250a664f8 --- /dev/null +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/MetricAggScriptsTests.java @@ -0,0 +1,113 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.painless; + +import org.elasticsearch.painless.spi.Whitelist; +import org.elasticsearch.script.MetricAggScripts; +import org.elasticsearch.script.ScriptContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class MetricAggScriptsTests extends ScriptTestCase { + @Override + protected Map, List> scriptContexts() { + Map, List> contexts = new HashMap<>(); + contexts.put(MetricAggScripts.InitScript.CONTEXT, Whitelist.BASE_WHITELISTS); + contexts.put(MetricAggScripts.MapScript.CONTEXT, Whitelist.BASE_WHITELISTS); + contexts.put(MetricAggScripts.CombineScript.CONTEXT, Whitelist.BASE_WHITELISTS); + contexts.put(MetricAggScripts.ReduceScript.CONTEXT, Whitelist.BASE_WHITELISTS); + return contexts; + } + + public void testInitBasic() { + MetricAggScripts.InitScript.Factory factory = scriptEngine.compile("test", + "agg.testField = params.initialVal", MetricAggScripts.InitScript.CONTEXT, Collections.emptyMap()); + + Map params = new HashMap<>(); + Map agg = new HashMap<>(); + + params.put("initialVal", 10); + + MetricAggScripts.InitScript script = factory.newInstance(params, agg); + script.execute(); + + assert(agg.containsKey("testField")); + assertEquals(10, agg.get("testField")); + } + + public void testMapBasic() { + MetricAggScripts.MapScript.Factory factory = scriptEngine.compile("test", + "agg.testField = 2*_score", MetricAggScripts.MapScript.CONTEXT, Collections.emptyMap()); + + Map params = new HashMap<>(); + Map agg = new HashMap<>(); + double _score = 0.5; + + MetricAggScripts.MapScript.LeafFactory leafFactory = factory.newFactory(params, agg, null); + MetricAggScripts.MapScript script = leafFactory.newInstance(null); + + script.execute(_score); + + assert(agg.containsKey("testField")); + assertEquals(1.0, agg.get("testField")); + } + + public void testCombineBasic() { + MetricAggScripts.CombineScript.Factory factory = scriptEngine.compile("test", + "agg.testField = params.initialVal; return agg.testField + params.inc", MetricAggScripts.CombineScript.CONTEXT, + Collections.emptyMap()); + + Map params = new HashMap<>(); + Map agg = new HashMap<>(); + + params.put("initialVal", 10); + params.put("inc", 2); + + MetricAggScripts.CombineScript script = factory.newInstance(params, agg); + Object res = script.execute(); + + assert(agg.containsKey("testField")); + assertEquals(10, agg.get("testField")); + assertEquals(12, res); + } + + public void testReduceBasic() { + MetricAggScripts.ReduceScript.Factory factory = scriptEngine.compile("test", + "aggs[0].testField + aggs[1].testField", MetricAggScripts.ReduceScript.CONTEXT, Collections.emptyMap()); + + Map params = new HashMap<>(); + List aggs = new ArrayList<>(); + + Map agg1 = new HashMap<>(), agg2 = new HashMap<>(); + agg1.put("testField", 1); + agg2.put("testField", 2); + + aggs.add(agg1); + aggs.add(agg2); + + MetricAggScripts.ReduceScript script = factory.newInstance(params, aggs); + Object res = script.execute(); + assertEquals(3, res); + } +} diff --git a/server/src/main/java/org/elasticsearch/script/MetricAggScripts.java b/server/src/main/java/org/elasticsearch/script/MetricAggScripts.java new file mode 100644 index 0000000000000..4d44b6f78e250 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/MetricAggScripts.java @@ -0,0 +1,140 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.script; + +import org.apache.lucene.index.LeafReaderContext; +import org.elasticsearch.index.fielddata.ScriptDocValues; +import org.elasticsearch.search.lookup.LeafSearchLookup; +import org.elasticsearch.search.lookup.SearchLookup; + +import java.util.List; +import java.util.Map; + +public class MetricAggScripts { + private abstract static class ParamsAndAggBase { + private final Map params; + private final Object agg; + + ParamsAndAggBase(Map params, Object agg) { + this.params = params; + this.agg = agg; + } + + public Map getParams() { + return params; + } + + public Object getAgg() { + return agg; + } + } + + public abstract static class InitScript extends ParamsAndAggBase { + public InitScript(Map params, Object agg) { + super(params, agg); + } + + public abstract void execute(); + + public interface Factory { + InitScript newInstance(Map params, Object agg); + } + + public static String[] PARAMETERS = {}; + public static ScriptContext CONTEXT = new ScriptContext<>("aggs_init", Factory.class); + } + + public abstract static class MapScript extends ParamsAndAggBase { + private final LeafSearchLookup leafLookup; + + public MapScript(Map params, Object agg, SearchLookup lookup, LeafReaderContext leafContext) { + super(params, agg); + + this.leafLookup = leafContext == null ? null : lookup.getLeafSearchLookup(leafContext); + } + + // Return the doc as a map (instead of LeafDocLookup) in order to abide by type whitelisting rules for + // Painless scripts. + public Map> getDoc() { + return leafLookup == null ? null : leafLookup.doc(); + } + + public void setDocument(int docId) { + if (leafLookup != null) { + leafLookup.setDocument(docId); + } + } + + public abstract void execute(double _score); + + public interface LeafFactory { + MapScript newInstance(LeafReaderContext ctx); + } + + public interface Factory { + LeafFactory newFactory(Map params, Object agg, SearchLookup lookup); + } + + public static String[] PARAMETERS = new String[] {"_score"}; + public static ScriptContext CONTEXT = new ScriptContext<>("aggs_map", Factory.class); + } + + public abstract static class CombineScript extends ParamsAndAggBase { + public CombineScript(Map params, Object agg) { + super(params, agg); + } + + public abstract Object execute(); + + public interface Factory { + CombineScript newInstance(Map params, Object agg); + } + + public static String[] PARAMETERS = {}; + public static ScriptContext CONTEXT = new ScriptContext<>("aggs_combine", Factory.class); + } + + public abstract static class ReduceScript { + private final Map params; + private final List aggs; + + public ReduceScript(Map params, List aggs) { + this.params = params; + this.aggs = aggs; + } + + public Map getParams() { + return params; + } + + public List getAggs() { + return aggs; + } + + public abstract Object execute(); + + public interface Factory { + ReduceScript newInstance(Map params, List aggs); + } + + public static String[] PARAMETERS = {}; + public static ScriptContext CONTEXT = new ScriptContext<>("aggs_reduce", Factory.class); + } +} diff --git a/server/src/main/java/org/elasticsearch/script/ScriptModule.java b/server/src/main/java/org/elasticsearch/script/ScriptModule.java index 727651be6a565..903dcf5831a28 100644 --- a/server/src/main/java/org/elasticsearch/script/ScriptModule.java +++ b/server/src/main/java/org/elasticsearch/script/ScriptModule.java @@ -48,7 +48,11 @@ public class ScriptModule { FilterScript.CONTEXT, SimilarityScript.CONTEXT, SimilarityWeightScript.CONTEXT, - TemplateScript.CONTEXT + TemplateScript.CONTEXT, + MetricAggScripts.InitScript.CONTEXT, + MetricAggScripts.MapScript.CONTEXT, + MetricAggScripts.CombineScript.CONTEXT, + MetricAggScripts.ReduceScript.CONTEXT ).collect(Collectors.toMap(c -> c.name, Function.identity())); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java index 6f9a6fe5d9774..b151e96a9dafe 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java @@ -22,7 +22,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.script.ExecutableScript; +import org.elasticsearch.script.MetricAggScripts; import org.elasticsearch.script.Script; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; @@ -89,15 +89,16 @@ public InternalAggregation doReduce(List aggregations, Redu InternalScriptedMetric firstAggregation = ((InternalScriptedMetric) aggregations.get(0)); List aggregation; if (firstAggregation.reduceScript != null && reduceContext.isFinalReduce()) { - Map vars = new HashMap<>(); - vars.put("_aggs", aggregationObjects); + Map params = new HashMap<>(); if (firstAggregation.reduceScript.getParams() != null) { - vars.putAll(firstAggregation.reduceScript.getParams()); + params.putAll(firstAggregation.reduceScript.getParams()); } - ExecutableScript.Factory factory = reduceContext.scriptService().compile( - firstAggregation.reduceScript, ExecutableScript.AGGS_CONTEXT); - ExecutableScript script = factory.newInstance(vars); - aggregation = Collections.singletonList(script.run()); + params.put("_aggs", aggregationObjects); + + MetricAggScripts.ReduceScript.Factory factory = reduceContext.scriptService().compile( + firstAggregation.reduceScript, MetricAggScripts.ReduceScript.CONTEXT); + MetricAggScripts.ReduceScript script = factory.newInstance(params, aggregationObjects); + aggregation = Collections.singletonList(script.execute()); } else if (reduceContext.isFinalReduce()) { aggregation = Collections.singletonList(aggregationObjects); } else { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java index c11c68f9b2524..7963ba212b7e9 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java @@ -26,12 +26,10 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryShardContext; -import org.elasticsearch.script.ExecutableScript; +import org.elasticsearch.script.MetricAggScripts; import org.elasticsearch.script.Script; -import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder; -import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactories.Builder; import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.internal.SearchContext; @@ -203,30 +201,32 @@ protected ScriptedMetricAggregatorFactory doBuild(SearchContext context, Aggrega // Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have // access to them for the scripts it's given precompiled. - ExecutableScript.Factory executableInitScript; + MetricAggScripts.InitScript.Factory compiledInitScript; Map initScriptParams; if (initScript != null) { - executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT); + compiledInitScript = queryShardContext.getScriptService().compile(initScript, MetricAggScripts.InitScript.CONTEXT); initScriptParams = initScript.getParams(); } else { - executableInitScript = p -> null; + compiledInitScript = (p, a) -> null; initScriptParams = Collections.emptyMap(); } - SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT); + MetricAggScripts.MapScript.Factory compiledMapScript = queryShardContext.getScriptService().compile(mapScript, + MetricAggScripts.MapScript.CONTEXT); Map mapScriptParams = mapScript.getParams(); - ExecutableScript.Factory executableCombineScript; + MetricAggScripts.CombineScript.Factory compiledCombineScript; Map combineScriptParams; if (combineScript != null) { - executableCombineScript = queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT); + compiledCombineScript = queryShardContext.getScriptService().compile(combineScript, + MetricAggScripts.CombineScript.CONTEXT); combineScriptParams = combineScript.getParams(); } else { - executableCombineScript = p -> null; + compiledCombineScript = (p, a) -> null; combineScriptParams = Collections.emptyMap(); } - return new ScriptedMetricAggregatorFactory(name, searchMapScript, mapScriptParams, executableInitScript, initScriptParams, - executableCombineScript, combineScriptParams, reduceScript, + return new ScriptedMetricAggregatorFactory(name, compiledMapScript, mapScriptParams, compiledInitScript, + initScriptParams, compiledCombineScript, combineScriptParams, reduceScript, params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java index 04ef595690a33..ad89365af9f37 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java @@ -20,10 +20,10 @@ package org.elasticsearch.search.aggregations.metrics.scripted; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Scorer; import org.elasticsearch.common.util.CollectionUtils; -import org.elasticsearch.script.ExecutableScript; +import org.elasticsearch.script.MetricAggScripts; import org.elasticsearch.script.Script; -import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.LeafBucketCollector; @@ -38,12 +38,12 @@ public class ScriptedMetricAggregator extends MetricsAggregator { - private final SearchScript.LeafFactory mapScript; - private final ExecutableScript combineScript; + private final MetricAggScripts.MapScript.LeafFactory mapScript; + private final MetricAggScripts.CombineScript combineScript; private final Script reduceScript; private Map params; - protected ScriptedMetricAggregator(String name, SearchScript.LeafFactory mapScript, ExecutableScript combineScript, + protected ScriptedMetricAggregator(String name, MetricAggScripts.MapScript.LeafFactory mapScript, MetricAggScripts.CombineScript combineScript, Script reduceScript, Map params, SearchContext context, Aggregator parent, List pipelineAggregators, Map metaData) throws IOException { @@ -62,13 +62,26 @@ public boolean needsScores() { @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { - final SearchScript leafMapScript = mapScript.newInstance(ctx); + final MetricAggScripts.MapScript leafMapScript = mapScript.newInstance(ctx); return new LeafBucketCollectorBase(sub, leafMapScript) { + private Scorer scorer; + + @Override + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + @Override public void collect(int doc, long bucket) throws IOException { assert bucket == 0 : bucket; + + double _score = 0.0; + if (scorer != null) { + _score = scorer.score(); + } + leafMapScript.setDocument(doc); - leafMapScript.run(); + leafMapScript.execute(_score); } }; } @@ -77,7 +90,7 @@ public void collect(int doc, long bucket) throws IOException { public InternalAggregation buildAggregation(long owningBucketOrdinal) { Object aggregation; if (combineScript != null) { - aggregation = combineScript.run(); + aggregation = combineScript.execute(); CollectionUtils.ensureNoSelfReferences(aggregation); } else { aggregation = params.get("_agg"); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java index 0bc6a614e541f..fca427358b270 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java @@ -19,9 +19,8 @@ package org.elasticsearch.search.aggregations.metrics.scripted; -import org.elasticsearch.script.ExecutableScript; +import org.elasticsearch.script.MetricAggScripts; import org.elasticsearch.script.Script; -import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.SearchParseException; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorFactories; @@ -38,19 +37,19 @@ public class ScriptedMetricAggregatorFactory extends AggregatorFactory { - private final SearchScript.Factory mapScript; + private final MetricAggScripts.MapScript.Factory mapScript; private final Map mapScriptParams; - private final ExecutableScript.Factory combineScript; + private final MetricAggScripts.CombineScript.Factory combineScript; private final Map combineScriptParams; private final Script reduceScript; private final Map aggParams; private final SearchLookup lookup; - private final ExecutableScript.Factory initScript; + private final MetricAggScripts.InitScript.Factory initScript; private final Map initScriptParams; - public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, Map mapScriptParams, - ExecutableScript.Factory initScript, Map initScriptParams, - ExecutableScript.Factory combineScript, Map combineScriptParams, + public ScriptedMetricAggregatorFactory(String name, MetricAggScripts.MapScript.Factory mapScript, Map mapScriptParams, + MetricAggScripts.InitScript.Factory initScript, Map initScriptParams, + MetricAggScripts.CombineScript.Factory combineScript, Map combineScriptParams, Script reduceScript, Map aggParams, SearchLookup lookup, SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactories, Map metaData) throws IOException { @@ -82,13 +81,18 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu aggParams.put("_agg", new HashMap()); } - final ExecutableScript initScript = this.initScript.newInstance(mergeParams(aggParams, initScriptParams)); - final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(mergeParams(aggParams, mapScriptParams), lookup); - final ExecutableScript combineScript = this.combineScript.newInstance(mergeParams(aggParams, combineScriptParams)); + Object agg = aggParams.get("_agg"); + + final MetricAggScripts.InitScript initScript = this.initScript.newInstance( + mergeParams(aggParams, initScriptParams), agg); + final MetricAggScripts.MapScript.LeafFactory mapScript = this.mapScript.newFactory( + mergeParams(aggParams, mapScriptParams), agg, lookup); + final MetricAggScripts.CombineScript combineScript = this.combineScript.newInstance( + mergeParams(aggParams, combineScriptParams), agg); final Script reduceScript = deepCopyScript(this.reduceScript, context); if (initScript != null) { - initScript.run(); + initScript.execute(); } return new ScriptedMetricAggregator(name, mapScript, combineScript, reduceScript, aggParams, context, parent, diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java index 9db5b237a858c..e2b39b751a163 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java @@ -68,6 +68,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.notNullValue; @@ -194,12 +195,53 @@ protected Map, Object>> pluginScripts() { return newAggregation; }); + scripts.put("agg.items = new ArrayList()", vars -> + aggContextScript(vars, agg -> ((HashMap) agg).put("items", new ArrayList()))); + + scripts.put("agg.items.add(1)", vars -> + aggContextScript(vars, agg -> { + HashMap aggMap = (HashMap) agg; + List items = (List) aggMap.get("items"); + items.add(1); + })); + + scripts.put("sum context agg values", vars -> { + int sum = 0; + HashMap agg = (HashMap) vars.get("agg"); + List items = (List) agg.get("items"); + + for (Object x : items) { + sum += (Integer)x; + } + + return sum; + }); + + scripts.put("sum context aggs of agg values", vars -> { + Integer sum = 0; + + List aggs = (List) vars.get("aggs"); + for (Object agg : (List) aggs) { + sum += ((Number) agg).intValue(); + } + + return sum; + }); + return scripts; } - @SuppressWarnings("unchecked") static Object aggScript(Map vars, Consumer fn) { - T agg = (T) vars.get("_agg"); + return aggScript(vars, fn, "_agg"); + } + + static Object aggContextScript(Map vars, Consumer fn) { + return aggScript(vars, fn, "agg"); + } + + @SuppressWarnings("unchecked") + private static Object aggScript(Map vars, Consumer fn, String aggVarName) { + T agg = (T) vars.get(aggVarName); fn.accept(agg); return agg; } @@ -1016,4 +1058,37 @@ public void testConflictingAggAndScriptParams() { SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, builder::get); assertThat(ex.getCause().getMessage(), containsString("Parameter name \"param1\" used in both aggregation and script parameters")); } + + public void testAggFromContext() { + Script initScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "agg.items = new ArrayList()", Collections.emptyMap()); + Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "agg.items.add(1)", Collections.emptyMap()); + Script combineScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "sum context agg values", Collections.emptyMap()); + Script reduceScript = + new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "sum context aggs of agg values", + Collections.emptyMap()); + + SearchResponse response = client() + .prepareSearch("idx") + .setQuery(matchAllQuery()) + .addAggregation( + scriptedMetric("scripted") + .initScript(initScript) + .mapScript(mapScript) + .combineScript(combineScript) + .reduceScript(reduceScript)) + .get(); + + Aggregation aggregation = response.getAggregations().get("scripted"); + assertThat(aggregation, notNullValue()); + assertThat(aggregation, instanceOf(ScriptedMetric.class)); + + ScriptedMetric scriptedMetricAggregation = (ScriptedMetric) aggregation; + assertThat(scriptedMetricAggregation.getName(), equalTo("scripted")); + assertThat(scriptedMetricAggregation.aggregation(), notNullValue()); + + assertThat(scriptedMetricAggregation.aggregation(), instanceOf(Integer.class)); + Integer aggResult = (Integer) scriptedMetricAggregation.aggregation(); + long totalAgg = aggResult.longValue(); + assertThat(totalAgg, equalTo(numDocs)); + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java index 0989b1ce6a3fa..c8318fe6d1ec9 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java @@ -100,7 +100,7 @@ public static void initMockScripts() { }); SCRIPTS.put("mapScriptScore", params -> { Map agg = (Map) params.get("_agg"); - ((List) agg.get("collector")).add(((ScoreAccessor) params.get("_score")).doubleValue()); + ((List) agg.get("collector")).add(((Number) params.get("_score")).doubleValue()); return agg; }); SCRIPTS.put("combineScriptScore", params -> { diff --git a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java index da3757d77b46e..cae9028cac461 100644 --- a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java +++ b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java @@ -21,20 +21,20 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Scorer; +import org.elasticsearch.index.fielddata.ScriptDocValues; import org.elasticsearch.index.similarity.ScriptedSimilarity.Doc; import org.elasticsearch.index.similarity.ScriptedSimilarity.Field; import org.elasticsearch.index.similarity.ScriptedSimilarity.Query; import org.elasticsearch.index.similarity.ScriptedSimilarity.Term; -import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.search.lookup.LeafSearchLookup; import org.elasticsearch.search.lookup.SearchLookup; import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.function.Function; -import java.util.function.Predicate; import static java.util.Collections.emptyMap; @@ -109,6 +109,18 @@ public String execute() { } else if (context.instanceClazz.equals(SimilarityWeightScript.class)) { SimilarityWeightScript.Factory factory = mockCompiled::createSimilarityWeightScript; return context.factoryClazz.cast(factory); + } else if (context.instanceClazz.equals(MetricAggScripts.InitScript.class)) { + MetricAggScripts.InitScript.Factory factory = mockCompiled::createMetricAggInitScript; + return context.factoryClazz.cast(factory); + } else if (context.instanceClazz.equals(MetricAggScripts.MapScript.class)) { + MetricAggScripts.MapScript.Factory factory = mockCompiled::createMetricAggMapScript; + return context.factoryClazz.cast(factory); + } else if (context.instanceClazz.equals(MetricAggScripts.CombineScript.class)) { + MetricAggScripts.CombineScript.Factory factory = mockCompiled::createMetricAggCombineScript; + return context.factoryClazz.cast(factory); + } else if (context.instanceClazz.equals(MetricAggScripts.ReduceScript.class)) { + MetricAggScripts.ReduceScript.Factory factory = mockCompiled::createMetricAggReduceScript; + return context.factoryClazz.cast(factory); } throw new IllegalArgumentException("mock script engine does not know how to handle context [" + context.name + "]"); } @@ -169,6 +181,23 @@ public SimilarityScript createSimilarityScript() { public SimilarityWeightScript createSimilarityWeightScript() { return new MockSimilarityWeightScript(script != null ? script : ctx -> 42d); } + + public MetricAggScripts.InitScript createMetricAggInitScript(Map params, Object agg) { + return new MockMetricAggInitScript(params, agg, script != null ? script : ctx -> 42d); + } + + public MetricAggScripts.MapScript.LeafFactory createMetricAggMapScript(Map params, Object agg, + SearchLookup lookup) { + return new MockMetricAggMapScript(params, agg, lookup, script != null ? script : ctx -> 42d); + } + + public MetricAggScripts.CombineScript createMetricAggCombineScript(Map params, Object agg) { + return new MockMetricAggCombineScript(params, agg, script != null ? script : ctx -> 42d); + } + + public MetricAggScripts.ReduceScript createMetricAggReduceScript(Map params, List aggs) { + return new MockMetricAggReduceScript(params, aggs, script != null ? script : ctx -> 42d); + } } public class MockExecutableScript implements ExecutableScript { @@ -323,6 +352,108 @@ public double execute(Query query, Field field, Term term) throws IOException { } } + public class MockMetricAggInitScript extends MetricAggScripts.InitScript { + private final Function, Object> script; + + MockMetricAggInitScript(Map params, Object agg, + Function, Object> script) { + super(params, agg); + this.script = script; + } + + public void execute() { + Map map = new HashMap<>(); + + if (getParams() != null) { + map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key + map.put("params", getParams()); + } + + map.put("agg", getAgg()); + script.apply(map); + } + } + + public class MockMetricAggMapScript implements MetricAggScripts.MapScript.LeafFactory { + private final Map params; + private final Object agg; + private final SearchLookup lookup; + private final Function, Object> script; + + MockMetricAggMapScript(Map params, Object agg, SearchLookup lookup, + Function, Object> script) { + this.params = params; + this.agg = agg; + this.lookup = lookup; + this.script = script; + } + + @Override + public MetricAggScripts.MapScript newInstance(LeafReaderContext context) { + return new MetricAggScripts.MapScript(params, agg, lookup, context) { + @Override + public void execute(double _score) { + Map map = new HashMap<>(); + + if (getParams() != null) { + map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key + map.put("params", getParams()); + } + + map.put("agg", getAgg()); + map.put("doc", getDoc()); + map.put("_score", _score); + + script.apply(map); + } + }; + } + } + + public class MockMetricAggCombineScript extends MetricAggScripts.CombineScript { + private final Function, Object> script; + + MockMetricAggCombineScript(Map params, Object agg, + Function, Object> script) { + super(params, agg); + this.script = script; + } + + public Object execute() { + Map map = new HashMap<>(); + + if (getParams() != null) { + map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key + map.put("params", getParams()); + } + + map.put("agg", getAgg()); + return script.apply(map); + } + } + + public class MockMetricAggReduceScript extends MetricAggScripts.ReduceScript { + private final Function, Object> script; + + MockMetricAggReduceScript(Map params, List aggs, + Function, Object> script) { + super(params, aggs); + this.script = script; + } + + public Object execute() { + Map map = new HashMap<>(); + + if (getParams() != null) { + map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key + map.put("params", getParams()); + } + + map.put("aggs", getAggs()); + return script.apply(map); + } + } + public static Script mockInlineScript(final String script) { return new Script(ScriptType.INLINE, "mock", script, emptyMap()); } From 2beee0739570a9a63673c749691bbddc4865fa6e Mon Sep 17 00:00:00 2001 From: Jonathan Little Date: Tue, 8 May 2018 22:50:04 -0700 Subject: [PATCH 2/9] Rename new script context container class and add clarifying comments to remaining references to params._agg(s) --- ...va => ScriptedMetricAggContextsTests.java} | 38 +++++++++--------- .../elasticsearch/script/ScriptModule.java | 8 ++-- ...ts.java => ScriptedMetricAggContexts.java} | 2 +- .../scripted/InternalScriptedMetric.java | 10 +++-- .../ScriptedMetricAggregationBuilder.java | 14 +++---- .../scripted/ScriptedMetricAggregator.java | 22 +++++------ .../ScriptedMetricAggregatorFactory.java | 30 +++++++------- .../script/MockScriptEngine.java | 39 +++++++++---------- 8 files changed, 84 insertions(+), 79 deletions(-) rename modules/lang-painless/src/test/java/org/elasticsearch/painless/{MetricAggScriptsTests.java => ScriptedMetricAggContextsTests.java} (60%) rename server/src/main/java/org/elasticsearch/script/{MetricAggScripts.java => ScriptedMetricAggContexts.java} (99%) diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/MetricAggScriptsTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptedMetricAggContextsTests.java similarity index 60% rename from modules/lang-painless/src/test/java/org/elasticsearch/painless/MetricAggScriptsTests.java rename to modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptedMetricAggContextsTests.java index 02d8250a664f8..48d187df2bbd2 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/MetricAggScriptsTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptedMetricAggContextsTests.java @@ -20,7 +20,7 @@ package org.elasticsearch.painless; import org.elasticsearch.painless.spi.Whitelist; -import org.elasticsearch.script.MetricAggScripts; +import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.script.ScriptContext; import java.util.ArrayList; @@ -29,27 +29,27 @@ import java.util.List; import java.util.Map; -public class MetricAggScriptsTests extends ScriptTestCase { +public class ScriptedMetricAggContextsTests extends ScriptTestCase { @Override protected Map, List> scriptContexts() { Map, List> contexts = new HashMap<>(); - contexts.put(MetricAggScripts.InitScript.CONTEXT, Whitelist.BASE_WHITELISTS); - contexts.put(MetricAggScripts.MapScript.CONTEXT, Whitelist.BASE_WHITELISTS); - contexts.put(MetricAggScripts.CombineScript.CONTEXT, Whitelist.BASE_WHITELISTS); - contexts.put(MetricAggScripts.ReduceScript.CONTEXT, Whitelist.BASE_WHITELISTS); + contexts.put(ScriptedMetricAggContexts.InitScript.CONTEXT, Whitelist.BASE_WHITELISTS); + contexts.put(ScriptedMetricAggContexts.MapScript.CONTEXT, Whitelist.BASE_WHITELISTS); + contexts.put(ScriptedMetricAggContexts.CombineScript.CONTEXT, Whitelist.BASE_WHITELISTS); + contexts.put(ScriptedMetricAggContexts.ReduceScript.CONTEXT, Whitelist.BASE_WHITELISTS); return contexts; } public void testInitBasic() { - MetricAggScripts.InitScript.Factory factory = scriptEngine.compile("test", - "agg.testField = params.initialVal", MetricAggScripts.InitScript.CONTEXT, Collections.emptyMap()); + ScriptedMetricAggContexts.InitScript.Factory factory = scriptEngine.compile("test", + "agg.testField = params.initialVal", ScriptedMetricAggContexts.InitScript.CONTEXT, Collections.emptyMap()); Map params = new HashMap<>(); Map agg = new HashMap<>(); params.put("initialVal", 10); - MetricAggScripts.InitScript script = factory.newInstance(params, agg); + ScriptedMetricAggContexts.InitScript script = factory.newInstance(params, agg); script.execute(); assert(agg.containsKey("testField")); @@ -57,15 +57,15 @@ public void testInitBasic() { } public void testMapBasic() { - MetricAggScripts.MapScript.Factory factory = scriptEngine.compile("test", - "agg.testField = 2*_score", MetricAggScripts.MapScript.CONTEXT, Collections.emptyMap()); + ScriptedMetricAggContexts.MapScript.Factory factory = scriptEngine.compile("test", + "agg.testField = 2*_score", ScriptedMetricAggContexts.MapScript.CONTEXT, Collections.emptyMap()); Map params = new HashMap<>(); Map agg = new HashMap<>(); double _score = 0.5; - MetricAggScripts.MapScript.LeafFactory leafFactory = factory.newFactory(params, agg, null); - MetricAggScripts.MapScript script = leafFactory.newInstance(null); + ScriptedMetricAggContexts.MapScript.LeafFactory leafFactory = factory.newFactory(params, agg, null); + ScriptedMetricAggContexts.MapScript script = leafFactory.newInstance(null); script.execute(_score); @@ -74,8 +74,8 @@ public void testMapBasic() { } public void testCombineBasic() { - MetricAggScripts.CombineScript.Factory factory = scriptEngine.compile("test", - "agg.testField = params.initialVal; return agg.testField + params.inc", MetricAggScripts.CombineScript.CONTEXT, + ScriptedMetricAggContexts.CombineScript.Factory factory = scriptEngine.compile("test", + "agg.testField = params.initialVal; return agg.testField + params.inc", ScriptedMetricAggContexts.CombineScript.CONTEXT, Collections.emptyMap()); Map params = new HashMap<>(); @@ -84,7 +84,7 @@ public void testCombineBasic() { params.put("initialVal", 10); params.put("inc", 2); - MetricAggScripts.CombineScript script = factory.newInstance(params, agg); + ScriptedMetricAggContexts.CombineScript script = factory.newInstance(params, agg); Object res = script.execute(); assert(agg.containsKey("testField")); @@ -93,8 +93,8 @@ public void testCombineBasic() { } public void testReduceBasic() { - MetricAggScripts.ReduceScript.Factory factory = scriptEngine.compile("test", - "aggs[0].testField + aggs[1].testField", MetricAggScripts.ReduceScript.CONTEXT, Collections.emptyMap()); + ScriptedMetricAggContexts.ReduceScript.Factory factory = scriptEngine.compile("test", + "aggs[0].testField + aggs[1].testField", ScriptedMetricAggContexts.ReduceScript.CONTEXT, Collections.emptyMap()); Map params = new HashMap<>(); List aggs = new ArrayList<>(); @@ -106,7 +106,7 @@ public void testReduceBasic() { aggs.add(agg1); aggs.add(agg2); - MetricAggScripts.ReduceScript script = factory.newInstance(params, aggs); + ScriptedMetricAggContexts.ReduceScript script = factory.newInstance(params, aggs); Object res = script.execute(); assertEquals(3, res); } diff --git a/server/src/main/java/org/elasticsearch/script/ScriptModule.java b/server/src/main/java/org/elasticsearch/script/ScriptModule.java index 903dcf5831a28..b1cd4689df912 100644 --- a/server/src/main/java/org/elasticsearch/script/ScriptModule.java +++ b/server/src/main/java/org/elasticsearch/script/ScriptModule.java @@ -49,10 +49,10 @@ public class ScriptModule { SimilarityScript.CONTEXT, SimilarityWeightScript.CONTEXT, TemplateScript.CONTEXT, - MetricAggScripts.InitScript.CONTEXT, - MetricAggScripts.MapScript.CONTEXT, - MetricAggScripts.CombineScript.CONTEXT, - MetricAggScripts.ReduceScript.CONTEXT + ScriptedMetricAggContexts.InitScript.CONTEXT, + ScriptedMetricAggContexts.MapScript.CONTEXT, + ScriptedMetricAggContexts.CombineScript.CONTEXT, + ScriptedMetricAggContexts.ReduceScript.CONTEXT ).collect(Collectors.toMap(c -> c.name, Function.identity())); } diff --git a/server/src/main/java/org/elasticsearch/script/MetricAggScripts.java b/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java similarity index 99% rename from server/src/main/java/org/elasticsearch/script/MetricAggScripts.java rename to server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java index 4d44b6f78e250..910d320b8b03a 100644 --- a/server/src/main/java/org/elasticsearch/script/MetricAggScripts.java +++ b/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java @@ -27,7 +27,7 @@ import java.util.List; import java.util.Map; -public class MetricAggScripts { +public class ScriptedMetricAggContexts { private abstract static class ParamsAndAggBase { private final Map params; private final Object agg; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java index b151e96a9dafe..8dc9ca956124d 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java @@ -22,7 +22,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.script.MetricAggScripts; +import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.script.Script; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; @@ -93,11 +93,13 @@ public InternalAggregation doReduce(List aggregations, Redu if (firstAggregation.reduceScript.getParams() != null) { params.putAll(firstAggregation.reduceScript.getParams()); } + + // Add _aggs to params map for backwards compatibility (redundant with a context variable on the ReduceScript created below). params.put("_aggs", aggregationObjects); - MetricAggScripts.ReduceScript.Factory factory = reduceContext.scriptService().compile( - firstAggregation.reduceScript, MetricAggScripts.ReduceScript.CONTEXT); - MetricAggScripts.ReduceScript script = factory.newInstance(params, aggregationObjects); + ScriptedMetricAggContexts.ReduceScript.Factory factory = reduceContext.scriptService().compile( + firstAggregation.reduceScript, ScriptedMetricAggContexts.ReduceScript.CONTEXT); + ScriptedMetricAggContexts.ReduceScript script = factory.newInstance(params, aggregationObjects); aggregation = Collections.singletonList(script.execute()); } else if (reduceContext.isFinalReduce()) { aggregation = Collections.singletonList(aggregationObjects); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java index ae82c8b8b9f77..8b6d834184d73 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java @@ -26,7 +26,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryShardContext; -import org.elasticsearch.script.MetricAggScripts; +import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.script.Script; import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder; @@ -201,25 +201,25 @@ protected ScriptedMetricAggregatorFactory doBuild(SearchContext context, Aggrega // Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have // access to them for the scripts it's given precompiled. - MetricAggScripts.InitScript.Factory compiledInitScript; + ScriptedMetricAggContexts.InitScript.Factory compiledInitScript; Map initScriptParams; if (initScript != null) { - compiledInitScript = queryShardContext.getScriptService().compile(initScript, MetricAggScripts.InitScript.CONTEXT); + compiledInitScript = queryShardContext.getScriptService().compile(initScript, ScriptedMetricAggContexts.InitScript.CONTEXT); initScriptParams = initScript.getParams(); } else { compiledInitScript = (p, a) -> null; initScriptParams = Collections.emptyMap(); } - MetricAggScripts.MapScript.Factory compiledMapScript = queryShardContext.getScriptService().compile(mapScript, - MetricAggScripts.MapScript.CONTEXT); + ScriptedMetricAggContexts.MapScript.Factory compiledMapScript = queryShardContext.getScriptService().compile(mapScript, + ScriptedMetricAggContexts.MapScript.CONTEXT); Map mapScriptParams = mapScript.getParams(); - MetricAggScripts.CombineScript.Factory compiledCombineScript; + ScriptedMetricAggContexts.CombineScript.Factory compiledCombineScript; Map combineScriptParams; if (combineScript != null) { compiledCombineScript = queryShardContext.getScriptService().compile(combineScript, - MetricAggScripts.CombineScript.CONTEXT); + ScriptedMetricAggContexts.CombineScript.CONTEXT); combineScriptParams = combineScript.getParams(); } else { compiledCombineScript = (p, a) -> null; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java index ad89365af9f37..c7b7990c8d237 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java @@ -22,7 +22,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Scorer; import org.elasticsearch.common.util.CollectionUtils; -import org.elasticsearch.script.MetricAggScripts; +import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.script.Script; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.InternalAggregation; @@ -38,17 +38,17 @@ public class ScriptedMetricAggregator extends MetricsAggregator { - private final MetricAggScripts.MapScript.LeafFactory mapScript; - private final MetricAggScripts.CombineScript combineScript; + private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript; + private final ScriptedMetricAggContexts.CombineScript combineScript; private final Script reduceScript; - private Map params; + private Object agg; - protected ScriptedMetricAggregator(String name, MetricAggScripts.MapScript.LeafFactory mapScript, MetricAggScripts.CombineScript combineScript, - Script reduceScript, - Map params, SearchContext context, Aggregator parent, List pipelineAggregators, Map metaData) - throws IOException { + protected ScriptedMetricAggregator(String name, ScriptedMetricAggContexts.MapScript.LeafFactory mapScript, ScriptedMetricAggContexts.CombineScript combineScript, + Script reduceScript, Object agg, SearchContext context, Aggregator parent, + List pipelineAggregators, Map metaData) + throws IOException { super(name, context, parent, pipelineAggregators, metaData); - this.params = params; + this.agg = agg; this.mapScript = mapScript; this.combineScript = combineScript; this.reduceScript = reduceScript; @@ -62,7 +62,7 @@ public boolean needsScores() { @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { - final MetricAggScripts.MapScript leafMapScript = mapScript.newInstance(ctx); + final ScriptedMetricAggContexts.MapScript leafMapScript = mapScript.newInstance(ctx); return new LeafBucketCollectorBase(sub, leafMapScript) { private Scorer scorer; @@ -93,7 +93,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) { aggregation = combineScript.execute(); CollectionUtils.ensureNoSelfReferences(aggregation); } else { - aggregation = params.get("_agg"); + aggregation = agg; } return new InternalScriptedMetric(name, aggregation, reduceScript, pipelineAggregators(), metaData()); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java index fca427358b270..57dab2e7b5a0a 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java @@ -19,7 +19,7 @@ package org.elasticsearch.search.aggregations.metrics.scripted; -import org.elasticsearch.script.MetricAggScripts; +import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.script.Script; import org.elasticsearch.search.SearchParseException; import org.elasticsearch.search.aggregations.Aggregator; @@ -37,20 +37,21 @@ public class ScriptedMetricAggregatorFactory extends AggregatorFactory { - private final MetricAggScripts.MapScript.Factory mapScript; + private final ScriptedMetricAggContexts.MapScript.Factory mapScript; private final Map mapScriptParams; - private final MetricAggScripts.CombineScript.Factory combineScript; + private final ScriptedMetricAggContexts.CombineScript.Factory combineScript; private final Map combineScriptParams; private final Script reduceScript; private final Map aggParams; private final SearchLookup lookup; - private final MetricAggScripts.InitScript.Factory initScript; + private final ScriptedMetricAggContexts.InitScript.Factory initScript; private final Map initScriptParams; - public ScriptedMetricAggregatorFactory(String name, MetricAggScripts.MapScript.Factory mapScript, Map mapScriptParams, - MetricAggScripts.InitScript.Factory initScript, Map initScriptParams, - MetricAggScripts.CombineScript.Factory combineScript, Map combineScriptParams, - Script reduceScript, Map aggParams, + public ScriptedMetricAggregatorFactory(String name, + ScriptedMetricAggContexts.MapScript.Factory mapScript, Map mapScriptParams, + ScriptedMetricAggContexts.InitScript.Factory initScript, Map initScriptParams, + ScriptedMetricAggContexts.CombineScript.Factory combineScript, + Map combineScriptParams, Script reduceScript, Map aggParams, SearchLookup lookup, SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactories, Map metaData) throws IOException { super(name, context, parent, subFactories, metaData); @@ -77,17 +78,20 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu } else { aggParams = new HashMap<>(); } + + // Add _agg to params map for backwards compatibility (redundant with context variables on the scripts created below). + // When this is removed, agg (as passed to ScriptedMetricAggregator) can be changed to Map, since + // it won't be possible to completely replace it with another type as is possible when it's an entry in params. if (aggParams.containsKey("_agg") == false) { aggParams.put("_agg", new HashMap()); } - Object agg = aggParams.get("_agg"); - final MetricAggScripts.InitScript initScript = this.initScript.newInstance( + final ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance( mergeParams(aggParams, initScriptParams), agg); - final MetricAggScripts.MapScript.LeafFactory mapScript = this.mapScript.newFactory( + final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript = this.mapScript.newFactory( mergeParams(aggParams, mapScriptParams), agg, lookup); - final MetricAggScripts.CombineScript combineScript = this.combineScript.newInstance( + final ScriptedMetricAggContexts.CombineScript combineScript = this.combineScript.newInstance( mergeParams(aggParams, combineScriptParams), agg); final Script reduceScript = deepCopyScript(this.reduceScript, context); @@ -95,7 +99,7 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu initScript.execute(); } return new ScriptedMetricAggregator(name, mapScript, - combineScript, reduceScript, aggParams, context, parent, + combineScript, reduceScript, agg, context, parent, pipelineAggregators, metaData); } diff --git a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java index cae9028cac461..c03b5a6cc68fd 100644 --- a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java +++ b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java @@ -21,7 +21,6 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Scorer; -import org.elasticsearch.index.fielddata.ScriptDocValues; import org.elasticsearch.index.similarity.ScriptedSimilarity.Doc; import org.elasticsearch.index.similarity.ScriptedSimilarity.Field; import org.elasticsearch.index.similarity.ScriptedSimilarity.Query; @@ -109,17 +108,17 @@ public String execute() { } else if (context.instanceClazz.equals(SimilarityWeightScript.class)) { SimilarityWeightScript.Factory factory = mockCompiled::createSimilarityWeightScript; return context.factoryClazz.cast(factory); - } else if (context.instanceClazz.equals(MetricAggScripts.InitScript.class)) { - MetricAggScripts.InitScript.Factory factory = mockCompiled::createMetricAggInitScript; + } else if (context.instanceClazz.equals(ScriptedMetricAggContexts.InitScript.class)) { + ScriptedMetricAggContexts.InitScript.Factory factory = mockCompiled::createMetricAggInitScript; return context.factoryClazz.cast(factory); - } else if (context.instanceClazz.equals(MetricAggScripts.MapScript.class)) { - MetricAggScripts.MapScript.Factory factory = mockCompiled::createMetricAggMapScript; + } else if (context.instanceClazz.equals(ScriptedMetricAggContexts.MapScript.class)) { + ScriptedMetricAggContexts.MapScript.Factory factory = mockCompiled::createMetricAggMapScript; return context.factoryClazz.cast(factory); - } else if (context.instanceClazz.equals(MetricAggScripts.CombineScript.class)) { - MetricAggScripts.CombineScript.Factory factory = mockCompiled::createMetricAggCombineScript; + } else if (context.instanceClazz.equals(ScriptedMetricAggContexts.CombineScript.class)) { + ScriptedMetricAggContexts.CombineScript.Factory factory = mockCompiled::createMetricAggCombineScript; return context.factoryClazz.cast(factory); - } else if (context.instanceClazz.equals(MetricAggScripts.ReduceScript.class)) { - MetricAggScripts.ReduceScript.Factory factory = mockCompiled::createMetricAggReduceScript; + } else if (context.instanceClazz.equals(ScriptedMetricAggContexts.ReduceScript.class)) { + ScriptedMetricAggContexts.ReduceScript.Factory factory = mockCompiled::createMetricAggReduceScript; return context.factoryClazz.cast(factory); } throw new IllegalArgumentException("mock script engine does not know how to handle context [" + context.name + "]"); @@ -182,20 +181,20 @@ public SimilarityWeightScript createSimilarityWeightScript() { return new MockSimilarityWeightScript(script != null ? script : ctx -> 42d); } - public MetricAggScripts.InitScript createMetricAggInitScript(Map params, Object agg) { + public ScriptedMetricAggContexts.InitScript createMetricAggInitScript(Map params, Object agg) { return new MockMetricAggInitScript(params, agg, script != null ? script : ctx -> 42d); } - public MetricAggScripts.MapScript.LeafFactory createMetricAggMapScript(Map params, Object agg, - SearchLookup lookup) { + public ScriptedMetricAggContexts.MapScript.LeafFactory createMetricAggMapScript(Map params, Object agg, + SearchLookup lookup) { return new MockMetricAggMapScript(params, agg, lookup, script != null ? script : ctx -> 42d); } - public MetricAggScripts.CombineScript createMetricAggCombineScript(Map params, Object agg) { + public ScriptedMetricAggContexts.CombineScript createMetricAggCombineScript(Map params, Object agg) { return new MockMetricAggCombineScript(params, agg, script != null ? script : ctx -> 42d); } - public MetricAggScripts.ReduceScript createMetricAggReduceScript(Map params, List aggs) { + public ScriptedMetricAggContexts.ReduceScript createMetricAggReduceScript(Map params, List aggs) { return new MockMetricAggReduceScript(params, aggs, script != null ? script : ctx -> 42d); } } @@ -352,7 +351,7 @@ public double execute(Query query, Field field, Term term) throws IOException { } } - public class MockMetricAggInitScript extends MetricAggScripts.InitScript { + public class MockMetricAggInitScript extends ScriptedMetricAggContexts.InitScript { private final Function, Object> script; MockMetricAggInitScript(Map params, Object agg, @@ -374,7 +373,7 @@ public void execute() { } } - public class MockMetricAggMapScript implements MetricAggScripts.MapScript.LeafFactory { + public class MockMetricAggMapScript implements ScriptedMetricAggContexts.MapScript.LeafFactory { private final Map params; private final Object agg; private final SearchLookup lookup; @@ -389,8 +388,8 @@ public class MockMetricAggMapScript implements MetricAggScripts.MapScript.LeafFa } @Override - public MetricAggScripts.MapScript newInstance(LeafReaderContext context) { - return new MetricAggScripts.MapScript(params, agg, lookup, context) { + public ScriptedMetricAggContexts.MapScript newInstance(LeafReaderContext context) { + return new ScriptedMetricAggContexts.MapScript(params, agg, lookup, context) { @Override public void execute(double _score) { Map map = new HashMap<>(); @@ -410,7 +409,7 @@ public void execute(double _score) { } } - public class MockMetricAggCombineScript extends MetricAggScripts.CombineScript { + public class MockMetricAggCombineScript extends ScriptedMetricAggContexts.CombineScript { private final Function, Object> script; MockMetricAggCombineScript(Map params, Object agg, @@ -432,7 +431,7 @@ public Object execute() { } } - public class MockMetricAggReduceScript extends MetricAggScripts.ReduceScript { + public class MockMetricAggReduceScript extends ScriptedMetricAggContexts.ReduceScript { private final Function, Object> script; MockMetricAggReduceScript(Map params, List aggs, From 3d3c914b22fda37447264c4d963a9f5941749d95 Mon Sep 17 00:00:00 2001 From: Jonathan Little Date: Sat, 19 May 2018 21:00:05 -0700 Subject: [PATCH 3/9] Misc cleanup: make mock metric agg script inner classes static --- .../java/org/elasticsearch/script/MockScriptEngine.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java index f9786e0acd3d5..8a29a957c8e42 100644 --- a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java +++ b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java @@ -360,7 +360,7 @@ public double execute(Query query, Field field, Term term) throws IOException { } } - public class MockMetricAggInitScript extends ScriptedMetricAggContexts.InitScript { + public static class MockMetricAggInitScript extends ScriptedMetricAggContexts.InitScript { private final Function, Object> script; MockMetricAggInitScript(Map params, Object agg, @@ -382,7 +382,7 @@ public void execute() { } } - public class MockMetricAggMapScript implements ScriptedMetricAggContexts.MapScript.LeafFactory { + public static class MockMetricAggMapScript implements ScriptedMetricAggContexts.MapScript.LeafFactory { private final Map params; private final Object agg; private final SearchLookup lookup; @@ -418,7 +418,7 @@ public void execute(double _score) { } } - public class MockMetricAggCombineScript extends ScriptedMetricAggContexts.CombineScript { + public static class MockMetricAggCombineScript extends ScriptedMetricAggContexts.CombineScript { private final Function, Object> script; MockMetricAggCombineScript(Map params, Object agg, @@ -440,7 +440,7 @@ public Object execute() { } } - public class MockMetricAggReduceScript extends ScriptedMetricAggContexts.ReduceScript { + public static class MockMetricAggReduceScript extends ScriptedMetricAggContexts.ReduceScript { private final Function, Object> script; MockMetricAggReduceScript(Map params, List aggs, From c16e7000eb3be82729321d0f66ecedd0d48499fa Mon Sep 17 00:00:00 2001 From: Jonathan Little Date: Sat, 19 May 2018 22:06:52 -0700 Subject: [PATCH 4/9] Move _score to an accessor rather than an arg for scripted metric agg scripts This causes the score to be evaluated only when it's used. --- .../ScriptedMetricAggContextsTests.java | 17 +++++++++++-- .../script/ScriptedMetricAggContexts.java | 25 +++++++++++++++++-- .../scripted/ScriptedMetricAggregator.java | 11 ++------ .../script/MockScriptEngine.java | 4 +-- 4 files changed, 42 insertions(+), 15 deletions(-) diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptedMetricAggContextsTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptedMetricAggContextsTests.java index 48d187df2bbd2..49fe2d10f9322 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptedMetricAggContextsTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptedMetricAggContextsTests.java @@ -19,6 +19,8 @@ package org.elasticsearch.painless; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; import org.elasticsearch.painless.spi.Whitelist; import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.script.ScriptContext; @@ -62,12 +64,23 @@ public void testMapBasic() { Map params = new HashMap<>(); Map agg = new HashMap<>(); - double _score = 0.5; + + Scorer scorer = new Scorer(null) { + @Override + public int docID() { return 0; } + + @Override + public float score() { return 0.5f; } + + @Override + public DocIdSetIterator iterator() { return null; } + }; ScriptedMetricAggContexts.MapScript.LeafFactory leafFactory = factory.newFactory(params, agg, null); ScriptedMetricAggContexts.MapScript script = leafFactory.newInstance(null); - script.execute(_score); + script.setScorer(scorer); + script.execute(); assert(agg.containsKey("testField")); assertEquals(1.0, agg.get("testField")); diff --git a/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java b/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java index 910d320b8b03a..ba45abcc97173 100644 --- a/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java +++ b/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java @@ -20,10 +20,13 @@ package org.elasticsearch.script; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Scorer; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.index.fielddata.ScriptDocValues; import org.elasticsearch.search.lookup.LeafSearchLookup; import org.elasticsearch.search.lookup.SearchLookup; +import java.io.IOException; import java.util.List; import java.util.Map; @@ -63,6 +66,7 @@ public interface Factory { public abstract static class MapScript extends ParamsAndAggBase { private final LeafSearchLookup leafLookup; + private Scorer scorer; public MapScript(Map params, Object agg, SearchLookup lookup, LeafReaderContext leafContext) { super(params, agg); @@ -82,7 +86,24 @@ public void setDocument(int docId) { } } - public abstract void execute(double _score); + public void setScorer(Scorer scorer) { + this.scorer = scorer; + } + + // get_score() is named this way so that it's picked up by Painless as '_score' + public double get_score() { + if (scorer == null) { + return 0.0; + } + + try { + return scorer.score(); + } catch (IOException e) { + throw new ElasticsearchException("Couldn't look up score", e); + } + } + + public abstract void execute(); public interface LeafFactory { MapScript newInstance(LeafReaderContext ctx); @@ -92,7 +113,7 @@ public interface Factory { LeafFactory newFactory(Map params, Object agg, SearchLookup lookup); } - public static String[] PARAMETERS = new String[] {"_score"}; + public static String[] PARAMETERS = new String[] {}; public static ScriptContext CONTEXT = new ScriptContext<>("aggs_map", Factory.class); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java index c7b7990c8d237..2a5cf0ecd8320 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java @@ -64,24 +64,17 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { final ScriptedMetricAggContexts.MapScript leafMapScript = mapScript.newInstance(ctx); return new LeafBucketCollectorBase(sub, leafMapScript) { - private Scorer scorer; - @Override public void setScorer(Scorer scorer) throws IOException { - this.scorer = scorer; + leafMapScript.setScorer(scorer); } @Override public void collect(int doc, long bucket) throws IOException { assert bucket == 0 : bucket; - double _score = 0.0; - if (scorer != null) { - _score = scorer.score(); - } - leafMapScript.setDocument(doc); - leafMapScript.execute(_score); + leafMapScript.execute(); } }; } diff --git a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java index 8a29a957c8e42..3b90cf655d48b 100644 --- a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java +++ b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java @@ -400,7 +400,7 @@ public static class MockMetricAggMapScript implements ScriptedMetricAggContexts. public ScriptedMetricAggContexts.MapScript newInstance(LeafReaderContext context) { return new ScriptedMetricAggContexts.MapScript(params, agg, lookup, context) { @Override - public void execute(double _score) { + public void execute() { Map map = new HashMap<>(); if (getParams() != null) { @@ -410,7 +410,7 @@ public void execute(double _score) { map.put("agg", getAgg()); map.put("doc", getDoc()); - map.put("_score", _score); + map.put("_score", get_score()); script.apply(map); } From d28c019a084cfef3525e23a4a9584cf9db8fe296 Mon Sep 17 00:00:00 2001 From: Jonathan Little Date: Mon, 28 May 2018 20:17:17 -0700 Subject: [PATCH 5/9] Documentation changes for params._agg -> agg --- .../scripted-metric-aggregation.asciidoc | 18 +++--- .../scripted-metric-aggregation.asciidoc | 64 ++++++++----------- 2 files changed, 34 insertions(+), 48 deletions(-) diff --git a/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc b/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc index b23a683b05610..0c4cb0f540eb7 100644 --- a/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc +++ b/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc @@ -13,8 +13,8 @@ Here is an example on how to create the aggregation request: -------------------------------------------------- ScriptedMetricAggregationBuilder aggregation = AggregationBuilders .scriptedMetric("agg") - .initScript(new Script("params._agg.heights = []")) - .mapScript(new Script("params._agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")); + .initScript(new Script("agg.heights = []")) + .mapScript(new Script("agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")); -------------------------------------------------- You can also specify a `combine` script which will be executed on each shard: @@ -23,9 +23,9 @@ You can also specify a `combine` script which will be executed on each shard: -------------------------------------------------- ScriptedMetricAggregationBuilder aggregation = AggregationBuilders .scriptedMetric("agg") - .initScript(new Script("params._agg.heights = []")) - .mapScript(new Script("params._agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) - .combineScript(new Script("double heights_sum = 0.0; for (t in params._agg.heights) { heights_sum += t } return heights_sum")); + .initScript(new Script("agg.heights = []")) + .mapScript(new Script("agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) + .combineScript(new Script("double heights_sum = 0.0; for (t in agg.heights) { heights_sum += t } return heights_sum")); -------------------------------------------------- You can also specify a `reduce` script which will be executed on the node which gets the request: @@ -34,10 +34,10 @@ You can also specify a `reduce` script which will be executed on the node which -------------------------------------------------- ScriptedMetricAggregationBuilder aggregation = AggregationBuilders .scriptedMetric("agg") - .initScript(new Script("params._agg.heights = []")) - .mapScript(new Script("params._agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) - .combineScript(new Script("double heights_sum = 0.0; for (t in params._agg.heights) { heights_sum += t } return heights_sum")) - .reduceScript(new Script("double heights_sum = 0.0; for (a in params._aggs) { heights_sum += a } return heights_sum")); + .initScript(new Script("agg.heights = []")) + .mapScript(new Script("agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) + .combineScript(new Script("double heights_sum = 0.0; for (t in agg.heights) { heights_sum += t } return heights_sum")) + .reduceScript(new Script("double heights_sum = 0.0; for (a in aggs) { heights_sum += a } return heights_sum")); -------------------------------------------------- diff --git a/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc b/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc index daa86969e4556..47257aa11a77b 100644 --- a/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc +++ b/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc @@ -17,10 +17,10 @@ POST ledger/_search?size=0 "aggs": { "profit": { "scripted_metric": { - "init_script" : "params._agg.transactions = []", - "map_script" : "params._agg.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)", <1> - "combine_script" : "double profit = 0; for (t in params._agg.transactions) { profit += t } return profit", - "reduce_script" : "double profit = 0; for (a in params._aggs) { profit += a } return profit" + "init_script" : "agg.transactions = []", + "map_script" : "agg.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)", <1> + "combine_script" : "double profit = 0; for (t in agg.transactions) { profit += t } return profit", + "reduce_script" : "double profit = 0; for (a in aggs) { profit += a } return profit" } } } @@ -69,8 +69,7 @@ POST ledger/_search?size=0 "id": "my_combine_script" }, "params": { - "field": "amount", <1> - "_agg": {} <2> + "field": "amount" <1> }, "reduce_script" : { "id": "my_reduce_script" @@ -84,8 +83,7 @@ POST ledger/_search?size=0 // TEST[setup:ledger,stored_scripted_metric_script] <1> script parameters for `init`, `map` and `combine` scripts must be specified -in a global `params` object so that it can be share between the scripts. -<2> if you specify script parameters then you must specify `"_agg": {}`. +in a global `params` object so that it can be shared between the scripts. //// Verify this response as well but in a hidden block. @@ -110,7 +108,7 @@ For more details on specifying scripts see < Date: Mon, 28 May 2018 20:33:08 -0700 Subject: [PATCH 6/9] Migration doc addition for scripted metric aggs _agg object change --- .../migration/migrate_7_0/aggregations.asciidoc | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/reference/migration/migrate_7_0/aggregations.asciidoc b/docs/reference/migration/migrate_7_0/aggregations.asciidoc index 5241ba4ccc76c..b768748b1641e 100644 --- a/docs/reference/migration/migrate_7_0/aggregations.asciidoc +++ b/docs/reference/migration/migrate_7_0/aggregations.asciidoc @@ -9,4 +9,12 @@ These `execution_hint` are removed and should be replaced by `global_ordinals`. The dynamic cluster setting named `search.max_buckets` now defaults to 10,000 (instead of unlimited in the previous version). -Requests that try to return more than the limit will fail with an exception. \ No newline at end of file +Requests that try to return more than the limit will fail with an exception. + +==== Replaced `params._agg` with `agg` context variable in scripted metric aggregations + +The object used to share aggregation state between the scripts in a Scripted Metric +Aggregation is now a variable called `agg` available in the script context, rather than +being provided via the `params` object as `params._agg`. + +The old `params._agg` variable is still available as well. \ No newline at end of file From 793e47b514b6d58d9c3e0c9b84031b0b7f4df1ac Mon Sep 17 00:00:00 2001 From: Jonathan Little Date: Thu, 31 May 2018 22:57:53 -0700 Subject: [PATCH 7/9] Rename "agg" Scripted Metric Aggregation script context variable to "state" --- .../scripted-metric-aggregation.asciidoc | 18 +++--- .../scripted-metric-aggregation.asciidoc | 36 ++++++------ .../migrate_7_0/aggregations.asciidoc | 4 +- .../ScriptedMetricAggContextsTests.java | 46 +++++++-------- .../script/ScriptedMetricAggContexts.java | 40 ++++++------- .../scripted/ScriptedMetricAggregator.java | 8 +-- .../ScriptedMetricAggregatorFactory.java | 12 ++-- .../metrics/ScriptedMetricIT.java | 45 ++++++++------- .../script/MockScriptEngine.java | 56 +++++++++---------- 9 files changed, 132 insertions(+), 133 deletions(-) diff --git a/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc b/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc index 0c4cb0f540eb7..5b68fa7be451f 100644 --- a/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc +++ b/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc @@ -13,8 +13,8 @@ Here is an example on how to create the aggregation request: -------------------------------------------------- ScriptedMetricAggregationBuilder aggregation = AggregationBuilders .scriptedMetric("agg") - .initScript(new Script("agg.heights = []")) - .mapScript(new Script("agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")); + .initScript(new Script("state.heights = []")) + .mapScript(new Script("state.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")); -------------------------------------------------- You can also specify a `combine` script which will be executed on each shard: @@ -23,9 +23,9 @@ You can also specify a `combine` script which will be executed on each shard: -------------------------------------------------- ScriptedMetricAggregationBuilder aggregation = AggregationBuilders .scriptedMetric("agg") - .initScript(new Script("agg.heights = []")) - .mapScript(new Script("agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) - .combineScript(new Script("double heights_sum = 0.0; for (t in agg.heights) { heights_sum += t } return heights_sum")); + .initScript(new Script("state.heights = []")) + .mapScript(new Script("state.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) + .combineScript(new Script("double heights_sum = 0.0; for (t in state.heights) { heights_sum += t } return heights_sum")); -------------------------------------------------- You can also specify a `reduce` script which will be executed on the node which gets the request: @@ -34,10 +34,10 @@ You can also specify a `reduce` script which will be executed on the node which -------------------------------------------------- ScriptedMetricAggregationBuilder aggregation = AggregationBuilders .scriptedMetric("agg") - .initScript(new Script("agg.heights = []")) - .mapScript(new Script("agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) - .combineScript(new Script("double heights_sum = 0.0; for (t in agg.heights) { heights_sum += t } return heights_sum")) - .reduceScript(new Script("double heights_sum = 0.0; for (a in aggs) { heights_sum += a } return heights_sum")); + .initScript(new Script("state.heights = []")) + .mapScript(new Script("state.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) + .combineScript(new Script("double heights_sum = 0.0; for (t in state.heights) { heights_sum += t } return heights_sum")) + .reduceScript(new Script("double heights_sum = 0.0; for (a in states) { heights_sum += a } return heights_sum")); -------------------------------------------------- diff --git a/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc b/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc index 47257aa11a77b..2a84779a271b7 100644 --- a/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc +++ b/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc @@ -17,10 +17,10 @@ POST ledger/_search?size=0 "aggs": { "profit": { "scripted_metric": { - "init_script" : "agg.transactions = []", - "map_script" : "agg.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)", <1> - "combine_script" : "double profit = 0; for (t in agg.transactions) { profit += t } return profit", - "reduce_script" : "double profit = 0; for (a in aggs) { profit += a } return profit" + "init_script" : "state.transactions = []", + "map_script" : "state.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)", <1> + "combine_script" : "double profit = 0; for (t in state.transactions) { profit += t } return profit", + "reduce_script" : "double profit = 0; for (a in states) { profit += a } return profit" } } } @@ -108,7 +108,7 @@ For more details on specifying scripts see <, List> scriptContexts() { public void testInitBasic() { ScriptedMetricAggContexts.InitScript.Factory factory = scriptEngine.compile("test", - "agg.testField = params.initialVal", ScriptedMetricAggContexts.InitScript.CONTEXT, Collections.emptyMap()); + "state.testField = params.initialVal", ScriptedMetricAggContexts.InitScript.CONTEXT, Collections.emptyMap()); Map params = new HashMap<>(); - Map agg = new HashMap<>(); + Map state = new HashMap<>(); params.put("initialVal", 10); - ScriptedMetricAggContexts.InitScript script = factory.newInstance(params, agg); + ScriptedMetricAggContexts.InitScript script = factory.newInstance(params, state); script.execute(); - assert(agg.containsKey("testField")); - assertEquals(10, agg.get("testField")); + assert(state.containsKey("testField")); + assertEquals(10, state.get("testField")); } public void testMapBasic() { ScriptedMetricAggContexts.MapScript.Factory factory = scriptEngine.compile("test", - "agg.testField = 2*_score", ScriptedMetricAggContexts.MapScript.CONTEXT, Collections.emptyMap()); + "state.testField = 2*_score", ScriptedMetricAggContexts.MapScript.CONTEXT, Collections.emptyMap()); Map params = new HashMap<>(); - Map agg = new HashMap<>(); + Map state = new HashMap<>(); Scorer scorer = new Scorer(null) { @Override @@ -76,50 +76,50 @@ public void testMapBasic() { public DocIdSetIterator iterator() { return null; } }; - ScriptedMetricAggContexts.MapScript.LeafFactory leafFactory = factory.newFactory(params, agg, null); + ScriptedMetricAggContexts.MapScript.LeafFactory leafFactory = factory.newFactory(params, state, null); ScriptedMetricAggContexts.MapScript script = leafFactory.newInstance(null); script.setScorer(scorer); script.execute(); - assert(agg.containsKey("testField")); - assertEquals(1.0, agg.get("testField")); + assert(state.containsKey("testField")); + assertEquals(1.0, state.get("testField")); } public void testCombineBasic() { ScriptedMetricAggContexts.CombineScript.Factory factory = scriptEngine.compile("test", - "agg.testField = params.initialVal; return agg.testField + params.inc", ScriptedMetricAggContexts.CombineScript.CONTEXT, + "state.testField = params.initialVal; return state.testField + params.inc", ScriptedMetricAggContexts.CombineScript.CONTEXT, Collections.emptyMap()); Map params = new HashMap<>(); - Map agg = new HashMap<>(); + Map state = new HashMap<>(); params.put("initialVal", 10); params.put("inc", 2); - ScriptedMetricAggContexts.CombineScript script = factory.newInstance(params, agg); + ScriptedMetricAggContexts.CombineScript script = factory.newInstance(params, state); Object res = script.execute(); - assert(agg.containsKey("testField")); - assertEquals(10, agg.get("testField")); + assert(state.containsKey("testField")); + assertEquals(10, state.get("testField")); assertEquals(12, res); } public void testReduceBasic() { ScriptedMetricAggContexts.ReduceScript.Factory factory = scriptEngine.compile("test", - "aggs[0].testField + aggs[1].testField", ScriptedMetricAggContexts.ReduceScript.CONTEXT, Collections.emptyMap()); + "states[0].testField + states[1].testField", ScriptedMetricAggContexts.ReduceScript.CONTEXT, Collections.emptyMap()); Map params = new HashMap<>(); - List aggs = new ArrayList<>(); + List states = new ArrayList<>(); - Map agg1 = new HashMap<>(), agg2 = new HashMap<>(); - agg1.put("testField", 1); - agg2.put("testField", 2); + Map state1 = new HashMap<>(), state2 = new HashMap<>(); + state1.put("testField", 1); + state2.put("testField", 2); - aggs.add(agg1); - aggs.add(agg2); + states.add(state1); + states.add(state2); - ScriptedMetricAggContexts.ReduceScript script = factory.newInstance(params, aggs); + ScriptedMetricAggContexts.ReduceScript script = factory.newInstance(params, states); Object res = script.execute(); assertEquals(3, res); } diff --git a/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java b/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java index ba45abcc97173..a76a36d82ec78 100644 --- a/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java +++ b/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java @@ -33,31 +33,31 @@ public class ScriptedMetricAggContexts { private abstract static class ParamsAndAggBase { private final Map params; - private final Object agg; + private final Object state; - ParamsAndAggBase(Map params, Object agg) { + ParamsAndAggBase(Map params, Object state) { this.params = params; - this.agg = agg; + this.state = state; } public Map getParams() { return params; } - public Object getAgg() { - return agg; + public Object getState() { + return state; } } public abstract static class InitScript extends ParamsAndAggBase { - public InitScript(Map params, Object agg) { - super(params, agg); + public InitScript(Map params, Object state) { + super(params, state); } public abstract void execute(); public interface Factory { - InitScript newInstance(Map params, Object agg); + InitScript newInstance(Map params, Object state); } public static String[] PARAMETERS = {}; @@ -68,8 +68,8 @@ public abstract static class MapScript extends ParamsAndAggBase { private final LeafSearchLookup leafLookup; private Scorer scorer; - public MapScript(Map params, Object agg, SearchLookup lookup, LeafReaderContext leafContext) { - super(params, agg); + public MapScript(Map params, Object state, SearchLookup lookup, LeafReaderContext leafContext) { + super(params, state); this.leafLookup = leafContext == null ? null : lookup.getLeafSearchLookup(leafContext); } @@ -110,7 +110,7 @@ public interface LeafFactory { } public interface Factory { - LeafFactory newFactory(Map params, Object agg, SearchLookup lookup); + LeafFactory newFactory(Map params, Object state, SearchLookup lookup); } public static String[] PARAMETERS = new String[] {}; @@ -118,14 +118,14 @@ public interface Factory { } public abstract static class CombineScript extends ParamsAndAggBase { - public CombineScript(Map params, Object agg) { - super(params, agg); + public CombineScript(Map params, Object state) { + super(params, state); } public abstract Object execute(); public interface Factory { - CombineScript newInstance(Map params, Object agg); + CombineScript newInstance(Map params, Object state); } public static String[] PARAMETERS = {}; @@ -134,25 +134,25 @@ public interface Factory { public abstract static class ReduceScript { private final Map params; - private final List aggs; + private final List states; - public ReduceScript(Map params, List aggs) { + public ReduceScript(Map params, List states) { this.params = params; - this.aggs = aggs; + this.states = states; } public Map getParams() { return params; } - public List getAggs() { - return aggs; + public List getStates() { + return states; } public abstract Object execute(); public interface Factory { - ReduceScript newInstance(Map params, List aggs); + ReduceScript newInstance(Map params, List states); } public static String[] PARAMETERS = {}; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java index 2a5cf0ecd8320..a9bf3b0f8701a 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java @@ -41,14 +41,14 @@ public class ScriptedMetricAggregator extends MetricsAggregator { private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript; private final ScriptedMetricAggContexts.CombineScript combineScript; private final Script reduceScript; - private Object agg; + private Object aggState; protected ScriptedMetricAggregator(String name, ScriptedMetricAggContexts.MapScript.LeafFactory mapScript, ScriptedMetricAggContexts.CombineScript combineScript, - Script reduceScript, Object agg, SearchContext context, Aggregator parent, + Script reduceScript, Object aggState, SearchContext context, Aggregator parent, List pipelineAggregators, Map metaData) throws IOException { super(name, context, parent, pipelineAggregators, metaData); - this.agg = agg; + this.aggState = aggState; this.mapScript = mapScript; this.combineScript = combineScript; this.reduceScript = reduceScript; @@ -86,7 +86,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) { aggregation = combineScript.execute(); CollectionUtils.ensureNoSelfReferences(aggregation); } else { - aggregation = agg; + aggregation = aggState; } return new InternalScriptedMetric(name, aggregation, reduceScript, pipelineAggregators(), metaData()); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java index 57dab2e7b5a0a..5c669340d6951 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java @@ -80,26 +80,26 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu } // Add _agg to params map for backwards compatibility (redundant with context variables on the scripts created below). - // When this is removed, agg (as passed to ScriptedMetricAggregator) can be changed to Map, since + // When this is removed, aggState (as passed to ScriptedMetricAggregator) can be changed to Map, since // it won't be possible to completely replace it with another type as is possible when it's an entry in params. if (aggParams.containsKey("_agg") == false) { aggParams.put("_agg", new HashMap()); } - Object agg = aggParams.get("_agg"); + Object aggState = aggParams.get("_agg"); final ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance( - mergeParams(aggParams, initScriptParams), agg); + mergeParams(aggParams, initScriptParams), aggState); final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript = this.mapScript.newFactory( - mergeParams(aggParams, mapScriptParams), agg, lookup); + mergeParams(aggParams, mapScriptParams), aggState, lookup); final ScriptedMetricAggContexts.CombineScript combineScript = this.combineScript.newInstance( - mergeParams(aggParams, combineScriptParams), agg); + mergeParams(aggParams, combineScriptParams), aggState); final Script reduceScript = deepCopyScript(this.reduceScript, context); if (initScript != null) { initScript.execute(); } return new ScriptedMetricAggregator(name, mapScript, - combineScript, reduceScript, agg, context, parent, + combineScript, reduceScript, aggState, context, parent, pipelineAggregators, metaData); } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java index e2b39b751a163..05985e52965f9 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java @@ -28,7 +28,6 @@ import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.script.MockScriptEngine; import org.elasticsearch.script.MockScriptPlugin; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptType; @@ -195,20 +194,20 @@ protected Map, Object>> pluginScripts() { return newAggregation; }); - scripts.put("agg.items = new ArrayList()", vars -> - aggContextScript(vars, agg -> ((HashMap) agg).put("items", new ArrayList()))); + scripts.put("state.items = new ArrayList()", vars -> + aggContextScript(vars, state -> ((HashMap) state).put("items", new ArrayList()))); - scripts.put("agg.items.add(1)", vars -> - aggContextScript(vars, agg -> { - HashMap aggMap = (HashMap) agg; - List items = (List) aggMap.get("items"); + scripts.put("state.items.add(1)", vars -> + aggContextScript(vars, state -> { + HashMap stateMap = (HashMap) state; + List items = (List) stateMap.get("items"); items.add(1); })); - scripts.put("sum context agg values", vars -> { + scripts.put("sum context state values", vars -> { int sum = 0; - HashMap agg = (HashMap) vars.get("agg"); - List items = (List) agg.get("items"); + HashMap state = (HashMap) vars.get("state"); + List items = (List) state.get("items"); for (Object x : items) { sum += (Integer)x; @@ -217,12 +216,12 @@ protected Map, Object>> pluginScripts() { return sum; }); - scripts.put("sum context aggs of agg values", vars -> { + scripts.put("sum context states", vars -> { Integer sum = 0; - List aggs = (List) vars.get("aggs"); - for (Object agg : (List) aggs) { - sum += ((Number) agg).intValue(); + List states = (List) vars.get("states"); + for (Object state : states) { + sum += ((Number) state).intValue(); } return sum; @@ -236,14 +235,14 @@ static Object aggScript(Map vars, Consumer fn) { } static Object aggContextScript(Map vars, Consumer fn) { - return aggScript(vars, fn, "agg"); + return aggScript(vars, fn, "state"); } @SuppressWarnings("unchecked") - private static Object aggScript(Map vars, Consumer fn, String aggVarName) { - T agg = (T) vars.get(aggVarName); - fn.accept(agg); - return agg; + private static Object aggScript(Map vars, Consumer fn, String stateVarName) { + T aggState = (T) vars.get(stateVarName); + fn.accept(aggState); + return aggState; } } @@ -1060,11 +1059,11 @@ public void testConflictingAggAndScriptParams() { } public void testAggFromContext() { - Script initScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "agg.items = new ArrayList()", Collections.emptyMap()); - Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "agg.items.add(1)", Collections.emptyMap()); - Script combineScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "sum context agg values", Collections.emptyMap()); + Script initScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "state.items = new ArrayList()", Collections.emptyMap()); + Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "state.items.add(1)", Collections.emptyMap()); + Script combineScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "sum context state values", Collections.emptyMap()); Script reduceScript = - new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "sum context aggs of agg values", + new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "sum context states", Collections.emptyMap()); SearchResponse response = client() diff --git a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java index d79c19d87db9b..e608bd13d2559 100644 --- a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java +++ b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java @@ -193,21 +193,21 @@ public MovingFunctionScript createMovingFunctionScript() { return new MockMovingFunctionScript(); } - public ScriptedMetricAggContexts.InitScript createMetricAggInitScript(Map params, Object agg) { - return new MockMetricAggInitScript(params, agg, script != null ? script : ctx -> 42d); + public ScriptedMetricAggContexts.InitScript createMetricAggInitScript(Map params, Object state) { + return new MockMetricAggInitScript(params, state, script != null ? script : ctx -> 42d); } - public ScriptedMetricAggContexts.MapScript.LeafFactory createMetricAggMapScript(Map params, Object agg, + public ScriptedMetricAggContexts.MapScript.LeafFactory createMetricAggMapScript(Map params, Object state, SearchLookup lookup) { - return new MockMetricAggMapScript(params, agg, lookup, script != null ? script : ctx -> 42d); + return new MockMetricAggMapScript(params, state, lookup, script != null ? script : ctx -> 42d); } - public ScriptedMetricAggContexts.CombineScript createMetricAggCombineScript(Map params, Object agg) { - return new MockMetricAggCombineScript(params, agg, script != null ? script : ctx -> 42d); + public ScriptedMetricAggContexts.CombineScript createMetricAggCombineScript(Map params, Object state) { + return new MockMetricAggCombineScript(params, state, script != null ? script : ctx -> 42d); } - public ScriptedMetricAggContexts.ReduceScript createMetricAggReduceScript(Map params, List aggs) { - return new MockMetricAggReduceScript(params, aggs, script != null ? script : ctx -> 42d); + public ScriptedMetricAggContexts.ReduceScript createMetricAggReduceScript(Map params, List states) { + return new MockMetricAggReduceScript(params, states, script != null ? script : ctx -> 42d); } } @@ -366,9 +366,9 @@ public double execute(Query query, Field field, Term term) throws IOException { public static class MockMetricAggInitScript extends ScriptedMetricAggContexts.InitScript { private final Function, Object> script; - MockMetricAggInitScript(Map params, Object agg, + MockMetricAggInitScript(Map params, Object state, Function, Object> script) { - super(params, agg); + super(params, state); this.script = script; } @@ -380,28 +380,28 @@ public void execute() { map.put("params", getParams()); } - map.put("agg", getAgg()); + map.put("state", getState()); script.apply(map); } } public static class MockMetricAggMapScript implements ScriptedMetricAggContexts.MapScript.LeafFactory { private final Map params; - private final Object agg; + private final Object state; private final SearchLookup lookup; private final Function, Object> script; - MockMetricAggMapScript(Map params, Object agg, SearchLookup lookup, + MockMetricAggMapScript(Map params, Object state, SearchLookup lookup, Function, Object> script) { this.params = params; - this.agg = agg; + this.state = state; this.lookup = lookup; this.script = script; } @Override public ScriptedMetricAggContexts.MapScript newInstance(LeafReaderContext context) { - return new ScriptedMetricAggContexts.MapScript(params, agg, lookup, context) { + return new ScriptedMetricAggContexts.MapScript(params, state, lookup, context) { @Override public void execute() { Map map = new HashMap<>(); @@ -411,7 +411,7 @@ public void execute() { map.put("params", getParams()); } - map.put("agg", getAgg()); + map.put("state", getState()); map.put("doc", getDoc()); map.put("_score", get_score()); @@ -424,9 +424,9 @@ public void execute() { public static class MockMetricAggCombineScript extends ScriptedMetricAggContexts.CombineScript { private final Function, Object> script; - MockMetricAggCombineScript(Map params, Object agg, + MockMetricAggCombineScript(Map params, Object state, Function, Object> script) { - super(params, agg); + super(params, state); this.script = script; } @@ -438,7 +438,7 @@ public Object execute() { map.put("params", getParams()); } - map.put("agg", getAgg()); + map.put("state", getState()); return script.apply(map); } } @@ -446,9 +446,9 @@ public Object execute() { public static class MockMetricAggReduceScript extends ScriptedMetricAggContexts.ReduceScript { private final Function, Object> script; - MockMetricAggReduceScript(Map params, List aggs, + MockMetricAggReduceScript(Map params, List states, Function, Object> script) { - super(params, aggs); + super(params, states); this.script = script; } @@ -460,7 +460,7 @@ public Object execute() { map.put("params", getParams()); } - map.put("aggs", getAggs()); + map.put("states", getStates()); return script.apply(map); } } @@ -475,15 +475,15 @@ public double execute(Map params, double[] values) { return MovingFunctions.unweightedAvg(values); } } - + public class MockScoreScript implements ScoreScript.Factory { - + private final Function, Object> scripts; - + MockScoreScript(Function, Object> scripts) { this.scripts = scripts; } - + @Override public ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup) { return new ScoreScript.LeafFactory() { @@ -491,7 +491,7 @@ public ScoreScript.LeafFactory newFactory(Map params, SearchLook public boolean needs_score() { return true; } - + @Override public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { Scorer[] scorerHolder = new Scorer[1]; @@ -505,7 +505,7 @@ public double execute() { } return ((Number) scripts.apply(vars)).doubleValue(); } - + @Override public void setScorer(Scorer scorer) { scorerHolder[0] = scorer; From 46526a14340693012c9667c841fddc82255af91c Mon Sep 17 00:00:00 2001 From: Jonathan Little Date: Sun, 10 Jun 2018 11:41:03 -0700 Subject: [PATCH 8/9] Rename a private base class from ...Agg to ...State that I missed in my last commit --- .../script/ScriptedMetricAggContexts.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java b/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java index a76a36d82ec78..774dc95d39977 100644 --- a/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java +++ b/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java @@ -31,11 +31,11 @@ import java.util.Map; public class ScriptedMetricAggContexts { - private abstract static class ParamsAndAggBase { + private abstract static class ParamsAndStateBase { private final Map params; private final Object state; - ParamsAndAggBase(Map params, Object state) { + ParamsAndStateBase(Map params, Object state) { this.params = params; this.state = state; } @@ -49,7 +49,7 @@ public Object getState() { } } - public abstract static class InitScript extends ParamsAndAggBase { + public abstract static class InitScript extends ParamsAndStateBase { public InitScript(Map params, Object state) { super(params, state); } @@ -64,7 +64,7 @@ public interface Factory { public static ScriptContext CONTEXT = new ScriptContext<>("aggs_init", Factory.class); } - public abstract static class MapScript extends ParamsAndAggBase { + public abstract static class MapScript extends ParamsAndStateBase { private final LeafSearchLookup leafLookup; private Scorer scorer; @@ -117,7 +117,7 @@ public interface Factory { public static ScriptContext CONTEXT = new ScriptContext<>("aggs_map", Factory.class); } - public abstract static class CombineScript extends ParamsAndAggBase { + public abstract static class CombineScript extends ParamsAndStateBase { public CombineScript(Map params, Object state) { super(params, state); } From c28eece69907214733d5d0f46e2762833d42014c Mon Sep 17 00:00:00 2001 From: Jonathan Little Date: Fri, 22 Jun 2018 20:29:34 -0700 Subject: [PATCH 9/9] Clean up imports after merge --- .../search/aggregations/metrics/ScriptedMetricIT.java | 1 - .../metrics/scripted/ScriptedMetricAggregatorTests.java | 1 - 2 files changed, 2 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java index 05985e52965f9..13e1489795996 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java @@ -67,7 +67,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.notNullValue; diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java index fb115fb580090..b2a949ceeee1a 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java @@ -31,7 +31,6 @@ import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.script.MockScriptEngine; -import org.elasticsearch.script.ScoreAccessor; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptEngine; import org.elasticsearch.script.ScriptModule;