3333import org .elasticsearch .index .query .QueryShardContext ;
3434import org .elasticsearch .indices .breaker .CircuitBreakerService ;
3535import org .elasticsearch .script .MockScriptEngine ;
36+ import org .elasticsearch .script .ScoreAccessor ;
3637import org .elasticsearch .script .Script ;
3738import org .elasticsearch .script .ScriptContextRegistry ;
3839import org .elasticsearch .script .ScriptEngineRegistry ;
@@ -59,6 +60,13 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
5960 private static final Script MAP_SCRIPT = new Script (ScriptType .INLINE , MockScriptEngine .NAME , "mapScript" , Collections .emptyMap ());
6061 private static final Script COMBINE_SCRIPT = new Script (ScriptType .INLINE , MockScriptEngine .NAME , "combineScript" ,
6162 Collections .emptyMap ());
63+
64+ private static final Script INIT_SCRIPT_SCORE = new Script (ScriptType .INLINE , MockScriptEngine .NAME , "initScriptScore" ,
65+ Collections .emptyMap ());
66+ private static final Script MAP_SCRIPT_SCORE = new Script (ScriptType .INLINE , MockScriptEngine .NAME , "mapScriptScore" ,
67+ Collections .emptyMap ());
68+ private static final Script COMBINE_SCRIPT_SCORE = new Script (ScriptType .INLINE , MockScriptEngine .NAME , "combineScriptScore" ,
69+ Collections .emptyMap ());
6270 private static final Map <String , Function <Map <String , Object >, Object >> SCRIPTS = new HashMap <>();
6371
6472
@@ -79,6 +87,21 @@ public static void initMockScripts() {
7987 Map <String , Object > agg = (Map <String , Object >) params .get ("_agg" );
8088 return ((List <Integer >) agg .get ("collector" )).stream ().mapToInt (Integer ::intValue ).sum ();
8189 });
90+
91+ SCRIPTS .put ("initScriptScore" , params -> {
92+ Map <String , Object > agg = (Map <String , Object >) params .get ("_agg" );
93+ agg .put ("collector" , new ArrayList <Double >());
94+ return agg ;
95+ });
96+ SCRIPTS .put ("mapScriptScore" , params -> {
97+ Map <String , Object > agg = (Map <String , Object >) params .get ("_agg" );
98+ ((List <Double >) agg .get ("collector" )).add (((ScoreAccessor ) params .get ("_score" )).doubleValue ());
99+ return agg ;
100+ });
101+ SCRIPTS .put ("combineScriptScore" , params -> {
102+ Map <String , Object > agg = (Map <String , Object >) params .get ("_agg" );
103+ return ((List <Double >) agg .get ("collector" )).stream ().mapToDouble (Double ::doubleValue ).sum ();
104+ });
82105 }
83106
84107 @ SuppressWarnings ("unchecked" )
@@ -144,6 +167,29 @@ public void testScriptedMetricWithCombine() throws IOException {
144167 }
145168 }
146169
170+ /**
171+ * test that uses the score of the documents
172+ */
173+ public void testScriptedMetricWithCombineAccessesScores () throws IOException {
174+ try (Directory directory = newDirectory ()) {
175+ Integer numDocs = randomInt (100 );
176+ try (RandomIndexWriter indexWriter = new RandomIndexWriter (random (), directory )) {
177+ for (int i = 0 ; i < numDocs ; i ++) {
178+ indexWriter .addDocument (singleton (new SortedNumericDocValuesField ("number" , i )));
179+ }
180+ }
181+ try (IndexReader indexReader = DirectoryReader .open (directory )) {
182+ ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder (AGG_NAME );
183+ aggregationBuilder .initScript (INIT_SCRIPT_SCORE ).mapScript (MAP_SCRIPT_SCORE ).combineScript (COMBINE_SCRIPT_SCORE );
184+ ScriptedMetric scriptedMetric = search (newSearcher (indexReader , true , true ), new MatchAllDocsQuery (), aggregationBuilder );
185+ assertEquals (AGG_NAME , scriptedMetric .getName ());
186+ assertNotNull (scriptedMetric .aggregation ());
187+ // all documents have score of 1.0
188+ assertEquals ((double ) numDocs , scriptedMetric .aggregation ());
189+ }
190+ }
191+ }
192+
147193 /**
148194 * We cannot use Mockito for mocking QueryShardContext in this case because
149195 * script-related methods (e.g. QueryShardContext#getLazyExecutableScript)
0 commit comments