3333import org .elasticsearch .index .query .QueryShardContext ;
3434import org .elasticsearch .indices .breaker .CircuitBreakerService ;
3535import org .elasticsearch .script .MockScriptEngine ;
36+ import org .elasticsearch .script .ScoreAccessor ;
3637import org .elasticsearch .script .Script ;
3738import org .elasticsearch .script .ScriptContextRegistry ;
3839import org .elasticsearch .script .ScriptEngineRegistry ;
@@ -59,6 +60,11 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
5960 private static final Script MAP_SCRIPT = new Script (ScriptType .INLINE , MockScriptEngine .NAME , "mapScript" , Collections .emptyMap ());
6061 private static final Script COMBINE_SCRIPT = new Script (ScriptType .INLINE , MockScriptEngine .NAME , "combineScript" ,
6162 Collections .emptyMap ());
63+
64+ private static final Script INIT_SCRIPT_SCORE = new Script (ScriptType .INLINE , MockScriptEngine .NAME , "initScriptScore" , Collections .emptyMap ());
65+ private static final Script MAP_SCRIPT_SCORE = new Script (ScriptType .INLINE , MockScriptEngine .NAME , "mapScriptScore" , Collections .emptyMap ());
66+ private static final Script COMBINE_SCRIPT_SCORE = new Script (ScriptType .INLINE , MockScriptEngine .NAME , "combineScriptScore" ,
67+ Collections .emptyMap ());
6268 private static final Map <String , Function <Map <String , Object >, Object >> SCRIPTS = new HashMap <>();
6369
6470
@@ -79,6 +85,21 @@ public static void initMockScripts() {
7985 Map <String , Object > agg = (Map <String , Object >) params .get ("_agg" );
8086 return ((List <Integer >) agg .get ("collector" )).stream ().mapToInt (Integer ::intValue ).sum ();
8187 });
88+
89+ SCRIPTS .put ("initScriptScore" , params -> {
90+ Map <String , Object > agg = (Map <String , Object >) params .get ("_agg" );
91+ agg .put ("collector" , new ArrayList <Double >());
92+ return agg ;
93+ });
94+ SCRIPTS .put ("mapScriptScore" , params -> {
95+ Map <String , Object > agg = (Map <String , Object >) params .get ("_agg" );
96+ ((List <Double >) agg .get ("collector" )).add (((ScoreAccessor ) params .get ("_score" )).doubleValue ());
97+ return agg ;
98+ });
99+ SCRIPTS .put ("combineScriptScore" , params -> {
100+ Map <String , Object > agg = (Map <String , Object >) params .get ("_agg" );
101+ return ((List <Double >) agg .get ("collector" )).stream ().mapToDouble (Double ::doubleValue ).sum ();
102+ });
82103 }
83104
84105 @ SuppressWarnings ("unchecked" )
@@ -144,6 +165,29 @@ public void testScriptedMetricWithCombine() throws IOException {
144165 }
145166 }
146167
168+ /**
169+ * test that uses the score of the documents
170+ */
171+ public void testScriptedMetricWithCombineAccessesScores () throws IOException {
172+ try (Directory directory = newDirectory ()) {
173+ Integer numDocs = randomInt (100 );
174+ try (RandomIndexWriter indexWriter = new RandomIndexWriter (random (), directory )) {
175+ for (int i = 0 ; i < numDocs ; i ++) {
176+ indexWriter .addDocument (singleton (new SortedNumericDocValuesField ("number" , i )));
177+ }
178+ }
179+ try (IndexReader indexReader = DirectoryReader .open (directory )) {
180+ ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder (AGG_NAME );
181+ aggregationBuilder .initScript (INIT_SCRIPT_SCORE ).mapScript (MAP_SCRIPT_SCORE ).combineScript (COMBINE_SCRIPT_SCORE );
182+ ScriptedMetric scriptedMetric = search (newSearcher (indexReader , true , true ), new MatchAllDocsQuery (), aggregationBuilder );
183+ assertEquals (AGG_NAME , scriptedMetric .getName ());
184+ assertNotNull (scriptedMetric .aggregation ());
185+ // all documents have score of 1.0
186+ assertEquals ((double ) numDocs , scriptedMetric .aggregation ());
187+ }
188+ }
189+ }
190+
147191 /**
148192 * We cannot use Mockito for mocking QueryShardContext in this case because
149193 * script-related methods (e.g. QueryShardContext#getLazyExecutableScript)
0 commit comments