Skip to content

Commit 0133879

Browse files
winstonewertcolings86
authored andcommitted
Allow scripted metric agg to access _score (#24295)
* Fixes #24259 Corrects the ScriptedMetricAggregator so that the script can have access to scores during the map stage. * Restored original tests. Added seperate test. As requested, I've restored the non-score dependant tests, and added the score dependent metric as a seperate test.
1 parent fb21caf commit 0133879

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public boolean needsScores() {
6363
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
6464
final LeafBucketCollector sub) throws IOException {
6565
final LeafSearchScript leafMapScript = mapScript.getLeafSearchScript(ctx);
66-
return new LeafBucketCollectorBase(sub, mapScript) {
66+
return new LeafBucketCollectorBase(sub, leafMapScript) {
6767
@Override
6868
public void collect(int doc, long bucket) throws IOException {
6969
assert bucket == 0 : bucket;

core/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.elasticsearch.index.query.QueryShardContext;
3434
import org.elasticsearch.indices.breaker.CircuitBreakerService;
3535
import org.elasticsearch.script.MockScriptEngine;
36+
import org.elasticsearch.script.ScoreAccessor;
3637
import org.elasticsearch.script.Script;
3738
import org.elasticsearch.script.ScriptContextRegistry;
3839
import org.elasticsearch.script.ScriptEngineRegistry;
@@ -59,6 +60,13 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
5960
private static final Script MAP_SCRIPT = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScript", Collections.emptyMap());
6061
private static final Script COMBINE_SCRIPT = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScript",
6162
Collections.emptyMap());
63+
64+
private static final Script INIT_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptScore",
65+
Collections.emptyMap());
66+
private static final Script MAP_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptScore",
67+
Collections.emptyMap());
68+
private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore",
69+
Collections.emptyMap());
6270
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();
6371

6472

@@ -79,6 +87,21 @@ public static void initMockScripts() {
7987
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
8088
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).sum();
8189
});
90+
91+
SCRIPTS.put("initScriptScore", params -> {
92+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
93+
agg.put("collector", new ArrayList<Double>());
94+
return agg;
95+
});
96+
SCRIPTS.put("mapScriptScore", params -> {
97+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
98+
((List<Double>) agg.get("collector")).add(((ScoreAccessor) params.get("_score")).doubleValue());
99+
return agg;
100+
});
101+
SCRIPTS.put("combineScriptScore", params -> {
102+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
103+
return ((List<Double>) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum();
104+
});
82105
}
83106

84107
@SuppressWarnings("unchecked")
@@ -144,6 +167,29 @@ public void testScriptedMetricWithCombine() throws IOException {
144167
}
145168
}
146169

170+
/**
171+
* test that uses the score of the documents
172+
*/
173+
public void testScriptedMetricWithCombineAccessesScores() throws IOException {
174+
try (Directory directory = newDirectory()) {
175+
Integer numDocs = randomInt(100);
176+
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
177+
for (int i = 0; i < numDocs; i++) {
178+
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
179+
}
180+
}
181+
try (IndexReader indexReader = DirectoryReader.open(directory)) {
182+
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
183+
aggregationBuilder.initScript(INIT_SCRIPT_SCORE).mapScript(MAP_SCRIPT_SCORE).combineScript(COMBINE_SCRIPT_SCORE);
184+
ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder);
185+
assertEquals(AGG_NAME, scriptedMetric.getName());
186+
assertNotNull(scriptedMetric.aggregation());
187+
// all documents have score of 1.0
188+
assertEquals((double) numDocs, scriptedMetric.aggregation());
189+
}
190+
}
191+
}
192+
147193
/**
148194
* We cannot use Mockito for mocking QueryShardContext in this case because
149195
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)

0 commit comments

Comments
 (0)