Skip to content

Commit 510848f

Browse files
committed
Aggregations Refactor: Refactor Scripted Metric Aggregation
1 parent 21556f9 commit 510848f

File tree

4 files changed

+190
-9
lines changed

4 files changed

+190
-9
lines changed

core/src/main/java/org/elasticsearch/script/Script.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,24 @@
3434

3535
import java.io.IOException;
3636
import java.util.Map;
37+
import java.util.function.Supplier;
3738

3839
/**
3940
* Script holds all the parameters necessary to compile or find in cache and then execute a script.
4041
*/
4142
public class Script implements ToXContent, Streamable {
4243

44+
/**
45+
* A {@link Supplier} implementation for use when reading a {@link Script}
46+
* using {@link StreamInput#readOptionalStreamable(Supplier)}
47+
*/
48+
public static final Supplier<Script> SUPPLIER = new Supplier<Script>() {
49+
50+
@Override
51+
public Script get() {
52+
return new Script();
53+
}
54+
};
4355
public static final ScriptType DEFAULT_TYPE = ScriptType.INLINE;
4456
private static final ScriptParser PARSER = new ScriptParser();
4557

@@ -74,7 +86,7 @@ protected Script(String script, String lang) {
7486

7587
/**
7688
* Constructor for Script.
77-
*
89+
*
7890
* @param script
7991
* The cache key of the script to be compiled/executed. For
8092
* inline scripts this is the actual script source code. For
@@ -112,7 +124,7 @@ public String getScript() {
112124

113125
/**
114126
* Method for getting the type.
115-
*
127+
*
116128
* @return The type of script -- inline, indexed, or file.
117129
*/
118130
public ScriptType getType() {
@@ -121,7 +133,7 @@ public ScriptType getType() {
121133

122134
/**
123135
* Method for getting language.
124-
*
136+
*
125137
* @return The language of the script to be compiled/executed.
126138
*/
127139
public String getLang() {
@@ -130,7 +142,7 @@ public String getLang() {
130142

131143
/**
132144
* Method for getting the parameters.
133-
*
145+
*
134146
* @return The map of parameters the script will be executed with.
135147
*/
136148
public Map<String, Object> getParams() {

core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
package org.elasticsearch.search.aggregations.metrics.scripted;
2121

2222
import org.apache.lucene.index.LeafReaderContext;
23+
import org.elasticsearch.common.io.stream.StreamInput;
24+
import org.elasticsearch.common.io.stream.StreamOutput;
25+
import org.elasticsearch.common.xcontent.XContentBuilder;
2326
import org.elasticsearch.script.ExecutableScript;
2427
import org.elasticsearch.script.LeafSearchScript;
2528
import org.elasticsearch.script.Script;
@@ -43,6 +46,7 @@
4346
import java.util.List;
4447
import java.util.Map;
4548
import java.util.Map.Entry;
49+
import java.util.Objects;
4650

4751
public class ScriptedMetricAggregator extends MetricsAggregator {
4852

@@ -113,13 +117,43 @@ public static class Factory extends AggregatorFactory {
113117
private Script reduceScript;
114118
private Map<String, Object> params;
115119

116-
public Factory(String name, Script initScript, Script mapScript, Script combineScript, Script reduceScript,
117-
Map<String, Object> params) {
120+
public Factory(String name) {
118121
super(name, InternalScriptedMetric.TYPE);
122+
}
123+
124+
/**
125+
* Set the <tt>init</tt> script.
126+
*/
127+
public void initScript(Script initScript) {
119128
this.initScript = initScript;
129+
}
130+
131+
/**
132+
* Set the <tt>map</tt> script.
133+
*/
134+
public void mapScript(Script mapScript) {
120135
this.mapScript = mapScript;
136+
}
137+
138+
/**
139+
* Set the <tt>combine</tt> script.
140+
*/
141+
public void combineScript(Script combineScript) {
121142
this.combineScript = combineScript;
143+
}
144+
145+
/**
146+
* Set the <tt>reduce</tt> script.
147+
*/
148+
public void reduceScript(Script reduceScript) {
122149
this.reduceScript = reduceScript;
150+
}
151+
152+
/**
153+
* Set parameters that will be available in the <tt>init</tt>,
154+
* <tt>map</tt> and <tt>combine</tt> phases.
155+
*/
156+
public void params(Map<String, Object> params) {
123157
this.params = params;
124158
}
125159

@@ -188,6 +222,73 @@ private static <T> T deepCopyParams(T original, SearchContext context) {
188222
return clone;
189223
}
190224

225+
@Override
226+
protected XContentBuilder internalXContent(XContentBuilder builder, Params builderParams) throws IOException {
227+
builder.startObject();
228+
if (initScript != null) {
229+
builder.field(ScriptedMetricParser.INIT_SCRIPT_FIELD.getPreferredName(), initScript);
230+
}
231+
232+
if (mapScript != null) {
233+
builder.field(ScriptedMetricParser.MAP_SCRIPT_FIELD.getPreferredName(), mapScript);
234+
}
235+
236+
if (combineScript != null) {
237+
builder.field(ScriptedMetricParser.COMBINE_SCRIPT_FIELD.getPreferredName(), combineScript);
238+
}
239+
240+
if (reduceScript != null) {
241+
builder.field(ScriptedMetricParser.REDUCE_SCRIPT_FIELD.getPreferredName(), reduceScript);
242+
}
243+
if (params != null) {
244+
builder.field(ScriptedMetricParser.PARAMS_FIELD.getPreferredName());
245+
builder.map(params);
246+
}
247+
builder.endObject();
248+
return builder;
249+
}
250+
251+
@Override
252+
protected AggregatorFactory doReadFrom(String name, StreamInput in) throws IOException {
253+
Factory factory = new Factory(name);
254+
factory.initScript = in.readOptionalStreamable(Script.SUPPLIER);
255+
factory.mapScript = in.readOptionalStreamable(Script.SUPPLIER);
256+
factory.combineScript = in.readOptionalStreamable(Script.SUPPLIER);
257+
factory.reduceScript = in.readOptionalStreamable(Script.SUPPLIER);
258+
if (in.readBoolean()) {
259+
factory.params = in.readMap();
260+
}
261+
return factory;
262+
}
263+
264+
@Override
265+
protected void doWriteTo(StreamOutput out) throws IOException {
266+
out.writeOptionalStreamable(initScript);
267+
out.writeOptionalStreamable(mapScript);
268+
out.writeOptionalStreamable(combineScript);
269+
out.writeOptionalStreamable(reduceScript);
270+
boolean hasParams = params != null;
271+
out.writeBoolean(hasParams);
272+
if (hasParams) {
273+
out.writeMap(params);
274+
}
275+
}
276+
277+
@Override
278+
protected int doHashCode() {
279+
return Objects.hash(initScript, mapScript, combineScript, reduceScript, params);
280+
}
281+
282+
@Override
283+
protected boolean doEquals(Object obj) {
284+
Factory other = (Factory) obj;
285+
return Objects.equals(initScript, other.initScript)
286+
&& Objects.equals(mapScript, other.mapScript)
287+
&& Objects.equals(combineScript, other.combineScript)
288+
&& Objects.equals(reduceScript, other.reduceScript)
289+
&& Objects.equals(params, other.params);
290+
}
291+
191292
}
192293

193294
}

core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricParser.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,19 @@ public AggregatorFactory parse(String aggregationName, XContentParser parser, Se
147147
if (mapScript == null) {
148148
throw new SearchParseException(context, "map_script field is required in [" + aggregationName + "].", parser.getTokenLocation());
149149
}
150-
return new ScriptedMetricAggregator.Factory(aggregationName, initScript, mapScript, combineScript, reduceScript, params);
150+
151+
ScriptedMetricAggregator.Factory factory = new ScriptedMetricAggregator.Factory(aggregationName);
152+
factory.initScript(initScript);
153+
factory.mapScript(mapScript);
154+
factory.combineScript(combineScript);
155+
factory.reduceScript(reduceScript);
156+
factory.params(params);
157+
return factory;
151158
}
152159

153-
// NORELEASE implement this method when refactoring this aggregation
154160
@Override
155161
public AggregatorFactory[] getFactoryPrototypes() {
156-
return null;
162+
return new AggregatorFactory[] { new ScriptedMetricAggregator.Factory(null) };
157163
}
158164

159165
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.search.aggregations.metrics;
21+
22+
import org.elasticsearch.script.Script;
23+
import org.elasticsearch.script.ScriptService.ScriptType;
24+
import org.elasticsearch.search.aggregations.BaseAggregationTestCase;
25+
import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetricAggregator;
26+
import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetricAggregator.Factory;
27+
28+
import java.util.HashMap;
29+
import java.util.Map;
30+
31+
public class ScriptedMetricTests extends BaseAggregationTestCase<ScriptedMetricAggregator.Factory> {
32+
33+
@Override
34+
protected Factory createTestAggregatorFactory() {
35+
Factory factory = new Factory(randomAsciiOfLengthBetween(1, 20));
36+
if (randomBoolean()) {
37+
factory.initScript(randomScript("initScript"));
38+
}
39+
factory.mapScript(randomScript("mapScript"));
40+
if (randomBoolean()) {
41+
factory.combineScript(randomScript("combineScript"));
42+
}
43+
if (randomBoolean()) {
44+
factory.reduceScript(randomScript("reduceScript"));
45+
}
46+
if (randomBoolean()) {
47+
Map<String, Object> params = new HashMap<String, Object>();
48+
params.put("foo", "bar");
49+
factory.params(params);
50+
}
51+
return factory;
52+
}
53+
54+
private Script randomScript(String script) {
55+
if (randomBoolean()) {
56+
return new Script(script);
57+
} else {
58+
return new Script(script, randomFrom(ScriptType.values()), randomFrom("my_lang", null), null);
59+
}
60+
}
61+
62+
}

0 commit comments

Comments
 (0)