
Commit f933f80

First step towards incremental reduction of query responses (#23253)
Today all query results are buffered until we have received responses from all shards. This can hold on to a significant amount of memory if the number of shards is large. This commit adds a first step towards incrementally reducing aggregation results once a configurable (per search request) number of responses has been received. When enough query results have been received and buffered, all aggregation responses received so far are reduced and released to be GCed.
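To make the mechanism concrete, here is a minimal, self-contained sketch of the buffering idea in plain Java. The names (IncrementalReducer, batchSize, reduceFn) are invented for illustration and are not the classes this commit introduces; in the actual change the buffered objects are per-shard aggregation results and the fold is an aggregation reduce.

import java.util.ArrayList;
import java.util.List;
import java.util.function.BinaryOperator;

/**
 * Illustrative only: buffers at most batchSize partial results and folds them
 * into an accumulated value whenever the buffer fills, so already-consumed
 * responses become eligible for GC instead of being held until all shards
 * have responded.
 */
final class IncrementalReducer<T> {
    private final List<T> buffer = new ArrayList<>();
    private final BinaryOperator<T> reduceFn;
    private final int batchSize;
    private T accumulated; // result of all previous partial reductions

    IncrementalReducer(int batchSize, BinaryOperator<T> reduceFn) {
        this.batchSize = batchSize;
        this.reduceFn = reduceFn;
    }

    /** Called once per shard response. */
    synchronized void consume(T partialResult) {
        buffer.add(partialResult);
        if (buffer.size() >= batchSize) {
            reduceBuffer(); // frees the buffered responses for GC
        }
    }

    /** Final reduction once every shard has responded. */
    synchronized T reduce() {
        reduceBuffer();
        return accumulated;
    }

    private void reduceBuffer() {
        for (T partial : buffer) {
            accumulated = accumulated == null ? partial : reduceFn.apply(accumulated, partial);
        }
        buffer.clear();
    }
}

For example, new IncrementalReducer<>(512, Long::sum) never holds more than 512 buffered values plus the running total, no matter how many shards respond.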
1 parent 39ed76c commit f933f80

33 files changed: +588 / -182 lines changed

core/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 16 additions & 18 deletions
@@ -42,7 +42,6 @@
 
 import java.util.List;
 import java.util.Map;
-import java.util.StringJoiner;
 import java.util.concurrent.Executor;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
@@ -61,7 +60,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
      **/
     private final Function<String, Transport.Connection> nodeIdToConnection;
     private final SearchTask task;
-    private final AtomicArray<Result> results;
+    private final SearchPhaseResults<Result> results;
     private final long clusterStateVersion;
     private final Map<String, AliasFilter> aliasFilter;
     private final Map<String, Float> concreteIndexBoosts;
@@ -76,7 +75,7 @@ protected AbstractSearchAsyncAction(String name, Logger logger, SearchTransportS
                                         Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
                                         Executor executor, SearchRequest request,
                                         ActionListener<SearchResponse> listener, GroupShardsIterator shardsIts, long startTime,
-                                        long clusterStateVersion, SearchTask task) {
+                                        long clusterStateVersion, SearchTask task, SearchPhaseResults<Result> resultConsumer) {
         super(name, request, shardsIts, logger);
         this.startTime = startTime;
         this.logger = logger;
@@ -87,9 +86,9 @@ protected AbstractSearchAsyncAction(String name, Logger logger, SearchTransportS
         this.listener = listener;
         this.nodeIdToConnection = nodeIdToConnection;
         this.clusterStateVersion = clusterStateVersion;
-        results = new AtomicArray<>(shardsIts.size());
         this.concreteIndexBoosts = concreteIndexBoosts;
         this.aliasFilter = aliasFilter;
+        this.results = resultConsumer;
     }
 
     /**
@@ -105,7 +104,7 @@ private long buildTookInMillis() {
      * This is the main entry point for a search. This method starts the search execution of the initial phase.
      */
     public final void start() {
-        if (results.length() == 0) {
+        if (getNumShards() == 0) {
             //no search shards to search on, bail with empty response
             //(it happens with search across _all with no indices around and consistent with broadcast operations)
             listener.onResponse(new SearchResponse(InternalSearchResponse.empty(), null, 0, 0, buildTookInMillis(),
@@ -130,8 +129,8 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha
             onPhaseFailure(currentPhase, "all shards failed", null);
         } else {
             if (logger.isTraceEnabled()) {
-                final String resultsFrom = results.asList().stream()
-                    .map(r -> r.value.shardTarget().toString()).collect(Collectors.joining(","));
+                final String resultsFrom = results.getSuccessfulResults()
+                    .map(r -> r.shardTarget().toString()).collect(Collectors.joining(","));
                 logger.trace("[{}] Moving to next phase: [{}], based on results from: {} (cluster state version: {})",
                     currentPhase.getName(), nextPhase.getName(), resultsFrom, clusterStateVersion);
             }
@@ -178,7 +177,7 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
             synchronized (shardFailuresMutex) {
                 shardFailures = this.shardFailures.get(); // read again otherwise somebody else has created it?
                 if (shardFailures == null) { // still null so we are the first and create a new instance
-                    shardFailures = new AtomicArray<>(results.length());
+                    shardFailures = new AtomicArray<>(getNumShards());
                     this.shardFailures.set(shardFailures);
                 }
             }
@@ -194,7 +193,7 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
             }
         }
 
-        if (results.get(shardIndex) != null) {
+        if (results.hasResult(shardIndex)) {
             assert failure == null : "shard failed before but shouldn't: " + failure;
             successfulOps.decrementAndGet(); // if this shard was successful before (initial phase) we have to adjust the counter
         }
@@ -207,22 +206,22 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
      * @param exception the exception explaining or causing the phase failure
      */
     private void raisePhaseFailure(SearchPhaseExecutionException exception) {
-        for (AtomicArray.Entry<Result> entry : results.asList()) {
+        results.getSuccessfulResults().forEach((entry) -> {
             try {
-                Transport.Connection connection = nodeIdToConnection.apply(entry.value.shardTarget().getNodeId());
-                sendReleaseSearchContext(entry.value.id(), connection);
+                Transport.Connection connection = nodeIdToConnection.apply(entry.shardTarget().getNodeId());
+                sendReleaseSearchContext(entry.id(), connection);
             } catch (Exception inner) {
                 inner.addSuppressed(exception);
                 logger.trace("failed to release context", inner);
             }
-        }
+        });
         listener.onFailure(exception);
     }
 
     @Override
     public final void onShardSuccess(int shardIndex, Result result) {
         successfulOps.incrementAndGet();
-        results.set(shardIndex, result);
+        results.consumeResult(shardIndex, result);
         if (logger.isTraceEnabled()) {
             logger.trace("got first-phase result from {}", result != null ? result.shardTarget() : null);
         }
@@ -242,7 +241,7 @@ public final void onPhaseDone() {
 
     @Override
     public final int getNumShards() {
-        return results.length();
+        return results.getNumShards();
    }
 
     @Override
@@ -262,7 +261,7 @@ public final SearchRequest getRequest() {
 
     @Override
     public final SearchResponse buildSearchResponse(InternalSearchResponse internalSearchResponse, String scrollId) {
-        return new SearchResponse(internalSearchResponse, scrollId, results.length(), successfulOps.get(),
+        return new SearchResponse(internalSearchResponse, scrollId, getNumShards(), successfulOps.get(),
             buildTookInMillis(), buildShardFailures());
     }
 
@@ -310,6 +309,5 @@ public final ShardSearchTransportRequest buildShardSearchRequest(ShardIterator s
      * executed shard request
      * @param context the search context for the next phase
      */
-    protected abstract SearchPhase getNextPhase(AtomicArray<Result> results, SearchPhaseContext context);
-
+    protected abstract SearchPhase getNextPhase(SearchPhaseResults<Result> results, SearchPhaseContext context);
 }

core/src/main/java/org/elasticsearch/action/search/CountedCollector.java

Lines changed: 12 additions & 9 deletions
@@ -19,7 +19,6 @@
 package org.elasticsearch.action.search;
 
 import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
@@ -30,17 +29,13 @@
  * where the given index is used to set the result on the array.
  */
 final class CountedCollector<R extends SearchPhaseResult> {
-    private final AtomicArray<R> resultArray;
+    private final ResultConsumer<R> resultConsumer;
     private final CountDown counter;
     private final Runnable onFinish;
     private final SearchPhaseContext context;
 
-    CountedCollector(AtomicArray<R> resultArray, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
-        if (expectedOps > resultArray.length()) {
-            throw new IllegalStateException("unexpected number of operations. got: " + expectedOps + " but array size is: "
-                + resultArray.length());
-        }
-        this.resultArray = resultArray;
+    CountedCollector(ResultConsumer<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
+        this.resultConsumer = resultConsumer;
         this.counter = new CountDown(expectedOps);
         this.onFinish = onFinish;
         this.context = context;
@@ -63,7 +58,7 @@ void countDown() {
     void onResult(int index, R result, SearchShardTarget target) {
         try {
             result.shardTarget(target);
-            resultArray.set(index, result);
+            resultConsumer.consume(index, result);
         } finally {
             countDown();
         }
@@ -80,4 +75,12 @@ void onFailure(final int shardIndex, @Nullable SearchShardTarget shardTarget, Ex
             countDown();
         }
     }
+
+    /**
+     * A functional interface to plug in shard result consumers to this collector
+     */
+    @FunctionalInterface
+    public interface ResultConsumer<R extends SearchPhaseResult> {
+        void consume(int shardIndex, R result);
+    }
 }
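The new ResultConsumer hook is what allows callers to pass either a plain array setter or a consumer that processes results on the fly; the DfsQueryPhase and FetchSearchPhase hunks below do exactly that with queryResult::consumeResult and fetchResults::set. Here is a stand-alone sketch of the same pattern (JDK types only, names invented) showing why both kinds of method reference satisfy such an interface:

import java.util.concurrent.atomic.AtomicReferenceArray;

final class ResultConsumerDemo {

    /** Mirrors the shape of CountedCollector.ResultConsumer: any (index, result) sink fits. */
    @FunctionalInterface
    interface ResultConsumer<R> {
        void consume(int shardIndex, R result);
    }

    public static void main(String[] args) {
        AtomicReferenceArray<String> shardResults = new AtomicReferenceArray<>(2);

        // A plain array setter satisfies the interface (the fetchResults::set case)...
        ResultConsumer<String> arraySink = shardResults::set;
        arraySink.consume(0, "result from shard 0");

        // ...and so does any custom consumer, e.g. one that could reduce incrementally
        // (the queryResult::consumeResult case).
        ResultConsumer<String> processingSink =
            (shardIndex, result) -> System.out.println("consumed shard " + shardIndex + ": " + result);
        processingSink.consume(1, "result from shard 1");

        System.out.println("stored for shard 0: " + shardResults.get(0));
    }
}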

core/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

Lines changed: 7 additions & 5 deletions
@@ -40,18 +40,19 @@
  * @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
  */
 final class DfsQueryPhase extends SearchPhase {
-    private final AtomicArray<QuerySearchResultProvider> queryResult;
+    private final InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> queryResult;
     private final SearchPhaseController searchPhaseController;
     private final AtomicArray<DfsSearchResult> dfsSearchResults;
-    private final Function<AtomicArray<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory;
+    private final Function<InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory;
     private final SearchPhaseContext context;
     private final SearchTransportService searchTransportService;
 
     DfsQueryPhase(AtomicArray<DfsSearchResult> dfsSearchResults,
                   SearchPhaseController searchPhaseController,
-                  Function<AtomicArray<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory, SearchPhaseContext context) {
+                  Function<InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory,
+                  SearchPhaseContext context) {
         super("dfs_query");
-        this.queryResult = new AtomicArray<>(dfsSearchResults.length());
+        this.queryResult = searchPhaseController.newSearchPhaseResults(context.getRequest(), context.getNumShards());
         this.searchPhaseController = searchPhaseController;
         this.dfsSearchResults = dfsSearchResults;
         this.nextPhaseFactory = nextPhaseFactory;
@@ -64,7 +65,8 @@ public void run() throws IOException {
         // TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
         // to free up memory early
         final AggregatedDfs dfs = searchPhaseController.aggregateDfs(dfsSearchResults);
-        final CountedCollector<QuerySearchResultProvider> counter = new CountedCollector<>(queryResult, dfsSearchResults.asList().size(),
+        final CountedCollector<QuerySearchResultProvider> counter = new CountedCollector<>(queryResult::consumeResult,
+            dfsSearchResults.asList().size(),
             () -> {
                 context.executeNextPhase(this, nextPhaseFactory.apply(queryResult));
             }, context);

core/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java

Lines changed: 11 additions & 9 deletions
@@ -49,29 +49,31 @@ final class FetchSearchPhase extends SearchPhase {
     private final Function<SearchResponse, SearchPhase> nextPhaseFactory;
     private final SearchPhaseContext context;
     private final Logger logger;
+    private final InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> resultConsumer;
 
-    FetchSearchPhase(AtomicArray<QuerySearchResultProvider> queryResults,
+    FetchSearchPhase(InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> resultConsumer,
                      SearchPhaseController searchPhaseController,
                      SearchPhaseContext context) {
-        this(queryResults, searchPhaseController, context,
+        this(resultConsumer, searchPhaseController, context,
            (response) -> new ExpandSearchPhase(context, response, // collapse only happens if the request has inner hits
                 (finalResponse) -> sendResponsePhase(finalResponse, context)));
     }
 
-    FetchSearchPhase(AtomicArray<QuerySearchResultProvider> queryResults,
+    FetchSearchPhase(InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> resultConsumer,
                      SearchPhaseController searchPhaseController,
                      SearchPhaseContext context, Function<SearchResponse, SearchPhase> nextPhaseFactory) {
         super("fetch");
-        if (context.getNumShards() != queryResults.length()) {
+        if (context.getNumShards() != resultConsumer.getNumShards()) {
             throw new IllegalStateException("number of shards must match the length of the query results but doesn't:"
-                + context.getNumShards() + "!=" + queryResults.length());
+                + context.getNumShards() + "!=" + resultConsumer.getNumShards());
         }
-        this.fetchResults = new AtomicArray<>(queryResults.length());
+        this.fetchResults = new AtomicArray<>(resultConsumer.getNumShards());
         this.searchPhaseController = searchPhaseController;
-        this.queryResults = queryResults;
+        this.queryResults = resultConsumer.results;
         this.nextPhaseFactory = nextPhaseFactory;
         this.context = context;
         this.logger = context.getLogger();
+        this.resultConsumer = resultConsumer;
 
     }
 
@@ -99,7 +101,7 @@ private void innerRun() throws IOException {
         ScoreDoc[] sortedShardDocs = searchPhaseController.sortDocs(isScrollSearch, queryResults);
         String scrollId = isScrollSearch ? TransportSearchHelper.buildScrollId(queryResults) : null;
         List<AtomicArray.Entry<QuerySearchResultProvider>> queryResultsAsList = queryResults.asList();
-        final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(queryResultsAsList);
+        final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = resultConsumer.reduce();
         final boolean queryAndFetchOptimization = queryResults.length() == 1;
         final Runnable finishPhase = ()
             -> moveToNextPhase(searchPhaseController, sortedShardDocs, scrollId, reducedQueryPhase, queryAndFetchOptimization ?
@@ -119,7 +121,7 @@ private void innerRun() throws IOException {
         final ScoreDoc[] lastEmittedDocPerShard = isScrollSearch ?
             searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, sortedShardDocs, numShards)
             : null;
-        final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(fetchResults,
+        final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(fetchResults::set,
             docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
             finishPhase, context);
         for (int i = 0; i < docIdsToLoad.length; i++) {

core/src/main/java/org/elasticsearch/action/search/InitialSearchPhase.java

Lines changed: 51 additions & 0 deletions
@@ -28,12 +28,14 @@
 import org.elasticsearch.cluster.routing.ShardIterator;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.transport.ConnectTransportException;
 
 import java.io.IOException;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Stream;
 
 /**
  * This is an abstract base class that encapsulates the logic to fan out to all shards in provided {@link GroupShardsIterator}
@@ -213,4 +215,53 @@ private void onShardResult(int shardIndex, String nodeId, FirstResult result, Sh
      * @param listener the listener to notify on response
      */
     protected abstract void executePhaseOnShard(ShardIterator shardIt, ShardRouting shard, ActionListener<FirstResult> listener);
+
+    /**
+     * This class acts as a basic result collection that can be extended to do on-the-fly reduction or result processing
+     */
+    static class SearchPhaseResults<Result extends SearchPhaseResult> {
+        final AtomicArray<Result> results;
+
+        SearchPhaseResults(int size) {
+            results = new AtomicArray<>(size);
+        }
+
+        /**
+         * Returns the number of expected results this class should collect
+         */
+        final int getNumShards() {
+            return results.length();
+        }
+
+        /**
+         * A stream of all non-null (successful) shard results
+         */
+        final Stream<Result> getSuccessfulResults() {
+            return results.asList().stream().map(e -> e.value);
+        }
+
+        /**
+         * Consumes a single shard result
+         * @param shardIndex the shards index, this is a 0-based id that is used to establish a 1 to 1 mapping to the searched shards
+         * @param result the shards result
+         */
+        void consumeResult(int shardIndex, Result result) {
+            assert results.get(shardIndex) == null : "shardIndex: " + shardIndex + " is already set";
+            results.set(shardIndex, result);
+        }
+
+        /**
+         * Returns <code>true</code> iff a result if present for the given shard ID.
+         */
+        final boolean hasResult(int shardIndex) {
+            return results.get(shardIndex) != null;
+        }
+
+        /**
+         * Reduces the collected results
+         */
+        SearchPhaseController.ReducedQueryPhase reduce() {
+            throw new UnsupportedOperationException("reduce is not supported");
+        }
+    }
 }
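consumeResult and reduce are the extension points the commit message refers to: the base class above only stores results and its reduce throws, while a reducing consumer (added elsewhere in this commit, not among the hunks quoted here) can override both. The skeleton below is purely hypothetical — names such as BufferingReducingResults and reduceBatchSize are invented, and the actual aggregation-reduction logic is left as comments:

// Hypothetical sketch only -- not the consumer this commit actually adds.
static class BufferingReducingResults extends SearchPhaseResults<QuerySearchResultProvider> {
    private final int reduceBatchSize;      // the per-request threshold from the commit message
    private int bufferedSinceLastReduce;    // results consumed since the last partial reduce

    BufferingReducingResults(int numShards, int reduceBatchSize) {
        super(numShards);
        this.reduceBatchSize = reduceBatchSize;
    }

    @Override
    synchronized void consumeResult(int shardIndex, QuerySearchResultProvider result) {
        super.consumeResult(shardIndex, result);
        if (++bufferedSinceLastReduce >= reduceBatchSize) {
            // here the aggregations buffered so far would be reduced into a single
            // partial aggregation and the per-shard copies released for GC
            bufferedSinceLastReduce = 0;
        }
    }

    @Override
    SearchPhaseController.ReducedQueryPhase reduce() {
        // final reduction: combine the partial aggregations with whatever arrived
        // after the last partial reduce; omitted in this sketch
        throw new UnsupportedOperationException("sketch only");
    }
}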

core/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

Lines changed: 3 additions & 4 deletions
@@ -24,7 +24,6 @@
 import org.elasticsearch.cluster.routing.GroupShardsIterator;
 import org.elasticsearch.cluster.routing.ShardIterator;
 import org.elasticsearch.cluster.routing.ShardRouting;
-import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.search.dfs.DfsSearchResult;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.transport.Transport;
@@ -43,7 +42,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
                                        ActionListener<SearchResponse> listener, GroupShardsIterator shardsIts, long startTime,
                                        long clusterStateVersion, SearchTask task) {
         super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, executor,
-            request, listener, shardsIts, startTime, clusterStateVersion, task);
+            request, listener, shardsIts, startTime, clusterStateVersion, task, new SearchPhaseResults<>(shardsIts.size()));
         this.searchPhaseController = searchPhaseController;
     }
 
@@ -54,8 +53,8 @@ protected void executePhaseOnShard(ShardIterator shardIt, ShardRouting shard, Ac
     }
 
     @Override
-    protected SearchPhase getNextPhase(AtomicArray<DfsSearchResult> results, SearchPhaseContext context) {
-        return new DfsQueryPhase(results, searchPhaseController,
+    protected SearchPhase getNextPhase(SearchPhaseResults<DfsSearchResult> results, SearchPhaseContext context) {
+        return new DfsQueryPhase(results.results, searchPhaseController,
            (queryResults) -> new FetchSearchPhase(queryResults, searchPhaseController, context), context);
     }
 }

core/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java

Lines changed: 1 addition & 0 deletions
@@ -114,4 +114,5 @@ default void sendReleaseSearchContext(long contextId, Transport.Connection conne
      * a response is returned to the user indicating that all shards have failed.
      */
     void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase);
+
 }
