@@ -32,6 +32,7 @@
 import org.elasticsearch.search.aggregations.ParsedAggregation;
 import org.elasticsearch.search.aggregations.matrix.stats.InternalMatrixStats.Fields;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.test.InternalAggregationTestCase;
 
 import java.io.IOException;
@@ -162,8 +163,8 @@ public void testReduceRandom() {
 
         ScriptService mockScriptService = mockScriptService();
         MockBigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
-        InternalAggregation.ReduceContext context =
-            new InternalAggregation.ReduceContext(bigArrays, mockScriptService, true);
+        InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction(
+            bigArrays, mockScriptService, b -> {}, PipelineTree.EMPTY);
         InternalMatrixStats reduced = (InternalMatrixStats) shardResults.get(0).reduce(shardResults, context);
         multiPassStats.assertNearlyEqual(reduced.getResults());
     }
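For contrast with the final-phase factory the test now calls, here is a minimal sketch of the partial-phase counterpart, assuming ReduceContext.forPartialReduction(BigArrays, ScriptService) is its partner factory; the names bigArrays and mockScriptService come from the test above:

    // Sketch only: a partial-reduction context merges shard-level results
    // as they arrive. It takes no bucket consumer and no PipelineTree,
    // because pipeline aggregations and bucket-count enforcement are
    // assumed to run only in the final reduction.
    InternalAggregation.ReduceContext partialContext =
        InternalAggregation.ReduceContext.forPartialReduction(bigArrays, mockScriptService);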
@@ -21,6 +21,7 @@
 
 import com.carrotsearch.hppc.IntArrayList;
 import com.carrotsearch.hppc.ObjectObjectHashMap;
+
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.FieldDoc;
@@ -69,17 +70,13 @@
 import java.util.stream.Collectors;
 
 public final class SearchPhaseController {
-
     private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];
 
-    private final Function<Boolean, ReduceContext> reduceContextFunction;
+    private final Function<SearchRequest, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder;
 
-    /**
-     * Constructor.
-     * @param reduceContextFunction A function that builds a context for the reduce of an {@link InternalAggregation}
-     */
-    public SearchPhaseController(Function<Boolean, ReduceContext> reduceContextFunction) {
-        this.reduceContextFunction = reduceContextFunction;
+    public SearchPhaseController(
+            Function<SearchRequest, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder) {
+        this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder;
     }
 
     public AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
@@ -394,17 +391,30 @@ private SearchHits getHits(ReducedQueryPhase reducedQueryPhase, boolean ignoreFr
      * @param queryResults a list of non-null query shard results
      */
    ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
-        return reducedQueryPhase(queryResults, true, SearchContext.TRACK_TOTAL_HITS_ACCURATE, true);
+        InternalAggregation.ReduceContextBuilder aggReduceContextBuilder = new InternalAggregation.ReduceContextBuilder() {
+            @Override
+            public ReduceContext forPartialReduction() {
+                throw new UnsupportedOperationException("Scroll requests don't have aggs");
+            }
+
+            @Override
+            public ReduceContext forFinalReduction() {
+                throw new UnsupportedOperationException("Scroll requests don't have aggs");
+            }
+        };
+        return reducedQueryPhase(queryResults, true, SearchContext.TRACK_TOTAL_HITS_ACCURATE, aggReduceContextBuilder, true);
     }
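The anonymous class above implies the shape of the new builder contract: one object defers the partial-vs-final choice until reduction actually runs. A sketch of that contract as this diff uses it; whether the real type is an interface or an abstract class is not visible here, and the comments are interpretation, not source:

    public interface ReduceContextBuilder {
        ReduceContext forPartialReduction();   // merge-only context for buffered shard results
        ReduceContext forFinalReduction();     // context for the single, terminal reduction
    }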

     /**
      * Reduces the given query results and consumes all aggregations and profile results.
      * @param queryResults a list of non-null query shard results
      */
     public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
-                                               boolean isScrollRequest, int trackTotalHitsUpTo, boolean performFinalReduce) {
+                                               boolean isScrollRequest, int trackTotalHitsUpTo,
+                                               InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
+                                               boolean performFinalReduce) {
         return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(trackTotalHitsUpTo),
-            0, isScrollRequest, performFinalReduce);
+            0, isScrollRequest, aggReduceContextBuilder, performFinalReduce);
     }
 
     /**
@@ -421,6 +431,7 @@ public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResul
     private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
                                                 List<InternalAggregations> bufferedAggs, List<TopDocs> bufferedTopDocs,
                                                 TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
+                                                InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
                                                 boolean performFinalReduce) {
         assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases;
         numReducePhases++; // increment for this phase
@@ -496,9 +507,8 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
             reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions));
             reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class);
         }
-        ReduceContext reduceContext = reduceContextFunction.apply(performFinalReduce);
-        final InternalAggregations aggregations = aggregationsList.isEmpty() ? null :
-            InternalAggregations.topLevelReduce(aggregationsList, reduceContext);
+        final InternalAggregations aggregations = aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(aggregationsList,
+            performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction());
         final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
         final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size,
             reducedCompletionSuggestions);
@@ -600,6 +610,7 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
         private int numReducePhases = 0;
         private final TopDocsStats topDocsStats;
         private final int topNSize;
+        private final InternalAggregation.ReduceContextBuilder aggReduceContextBuilder;
         private final boolean performFinalReduce;
 
         /**
@@ -613,7 +624,9 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
          */
         private QueryPhaseResultConsumer(SearchProgressListener progressListener, SearchPhaseController controller,
                                          int expectedResultSize, int bufferSize, boolean hasTopDocs, boolean hasAggs,
-                                         int trackTotalHitsUpTo, int topNSize, boolean performFinalReduce) {
+                                         int trackTotalHitsUpTo, int topNSize,
+                                         InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
+                                         boolean performFinalReduce) {
             super(expectedResultSize);
             if (expectedResultSize != 1 && bufferSize < 2) {
                 throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result");
@@ -635,6 +648,7 @@ private QueryPhaseResultConsumer(SearchProgressListener progressListener, Search
             this.bufferSize = bufferSize;
             this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo);
             this.topNSize = topNSize;
+            this.aggReduceContextBuilder = aggReduceContextBuilder;
             this.performFinalReduce = performFinalReduce;
         }
 
@@ -650,7 +664,7 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
             if (querySearchResult.isNull() == false) {
                 if (index == bufferSize) {
                     if (hasAggs) {
-                        ReduceContext reduceContext = controller.reduceContextFunction.apply(false);
+                        ReduceContext reduceContext = aggReduceContextBuilder.forPartialReduction();
                         InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(Arrays.asList(aggsBuffer), reduceContext);
                         Arrays.fill(aggsBuffer, null);
                         aggsBuffer[0] = reducedAggs;
@@ -693,8 +707,8 @@ private synchronized List<TopDocs> getRemainingTopDocs() {
 
         @Override
         public ReducedQueryPhase reduce() {
-            ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(),
-                getRemainingAggs(), getRemainingTopDocs(), topDocsStats, numReducePhases, false, performFinalReduce);
+            ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), getRemainingAggs(), getRemainingTopDocs(),
+                topDocsStats, numReducePhases, false, aggReduceContextBuilder, performFinalReduce);
             progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()),
                 reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
             return reducePhase;
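Taken together, consumeInternal and reduce() give the consumer a two-tier flow. A hypothetical trace of that flow, using only calls that appear in this diff:

    // Each time the buffer fills (consumeInternal): merge what is buffered
    // with a partial context, keeping memory bounded to bufferSize results.
    ReduceContext partial = aggReduceContextBuilder.forPartialReduction();
    InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(Arrays.asList(aggsBuffer), partial);
    // Once all shards have responded (reduce()): one terminal pass, which is
    // final only when this node owns the end of the reduction chain.
    ReduceContext terminal = performFinalReduce
        ? aggReduceContextBuilder.forFinalReduction()
        : aggReduceContextBuilder.forPartialReduction();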
@@ -730,13 +744,14 @@ ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(SearchProgressL
         final boolean hasAggs = source != null && source.aggregations() != null;
         final boolean hasTopDocs = source == null || source.size() != 0;
         final int trackTotalHitsUpTo = resolveTrackTotalHits(request);
+        InternalAggregation.ReduceContextBuilder aggReduceContextBuilder = requestToAggReduceContextBuilder.apply(request);
         if (isScrollRequest == false && (hasAggs || hasTopDocs)) {
             // no incremental reduce if scroll is used - we only hit a single shard or sometimes more...
             if (request.getBatchedReduceSize() < numShards) {
                 int topNSize = getTopDocsSize(request);
                 // only use this if there are aggs and if there are more shards than we should reduce at once
                 return new QueryPhaseResultConsumer(listener, this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs,
-                    trackTotalHitsUpTo, topNSize, request.isFinalReduce());
+                    trackTotalHitsUpTo, topNSize, aggReduceContextBuilder, request.isFinalReduce());
             }
         }
         return new ArraySearchPhaseResults<SearchPhaseResult>(numShards) {
@@ -750,7 +765,7 @@ void consumeResult(SearchPhaseResult result) {
             ReducedQueryPhase reduce() {
                 List<SearchPhaseResult> resultList = results.asList();
                 final ReducedQueryPhase reducePhase =
-                    reducedQueryPhase(resultList, isScrollRequest, trackTotalHitsUpTo, request.isFinalReduce());
+                    reducedQueryPhase(resultList, isScrollRequest, trackTotalHitsUpTo, aggReduceContextBuilder, request.isFinalReduce());
                 listener.notifyFinalReduce(SearchProgressListener.buildSearchShards(resultList),
                     reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
                 return reducePhase;
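As the comments in newSearchPhaseResults note, the incremental consumer is only chosen when the batched reduce size is below the shard count. A usage sketch from the request side; the index name and field are made up, while setBatchedReduceSize is an existing SearchRequest setter:

    // With many shards and a small batch size, the QueryPhaseResultConsumer
    // above is selected and partial reductions run as results stream in.
    SearchRequest request = new SearchRequest("my-index");   // hypothetical index
    request.source(new SearchSourceBuilder()
        .aggregation(AggregationBuilders.terms("by_field").field("some_field")));
    request.setBatchedReduceSize(10);  // below the shard count => incremental reduce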
@@ -27,13 +27,15 @@
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.search.SearchPhaseController.TopDocsStats;
+import org.elasticsearch.action.search.SearchResponse.Clusters;
 import org.elasticsearch.action.search.TransportSearchAction.SearchTimeProvider;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchShardTarget;
-import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
+import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.internal.InternalSearchResponse;
 import org.elasticsearch.search.profile.ProfileShardResult;
@@ -51,11 +53,8 @@
 import java.util.Objects;
 import java.util.TreeMap;
 import java.util.concurrent.CopyOnWriteArrayList;
-import java.util.function.Function;
 
-import static org.elasticsearch.action.search.SearchPhaseController.TopDocsStats;
 import static org.elasticsearch.action.search.SearchPhaseController.mergeTopDocs;
-import static org.elasticsearch.action.search.SearchResponse.Clusters;
 
 /**
  * Merges multiple search responses into one. Used in cross-cluster search when reduction is performed locally on each cluster.
@@ -81,16 +80,16 @@ final class SearchResponseMerger {
     final int size;
     final int trackTotalHitsUpTo;
     private final SearchTimeProvider searchTimeProvider;
-    private final Function<Boolean, ReduceContext> reduceContextFunction;
+    private final InternalAggregation.ReduceContextBuilder aggReduceContextBuilder;
     private final List<SearchResponse> searchResponses = new CopyOnWriteArrayList<>();
 
     SearchResponseMerger(int from, int size, int trackTotalHitsUpTo, SearchTimeProvider searchTimeProvider,
-                         Function<Boolean, ReduceContext> reduceContextFunction) {
+                         InternalAggregation.ReduceContextBuilder aggReduceContextBuilder) {
         this.from = from;
         this.size = size;
         this.trackTotalHitsUpTo = trackTotalHitsUpTo;
         this.searchTimeProvider = Objects.requireNonNull(searchTimeProvider);
-        this.reduceContextFunction = Objects.requireNonNull(reduceContextFunction);
+        this.aggReduceContextBuilder = Objects.requireNonNull(aggReduceContextBuilder);
     }
 
     /**
@@ -196,7 +195,7 @@ SearchResponse getMergedResponse(Clusters clusters) {
         SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, topDocsStats);
         setSuggestShardIndex(shards, groupedSuggestions);
         Suggest suggest = groupedSuggestions.isEmpty() ? null : new Suggest(Suggest.reduce(groupedSuggestions));
-        InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(aggs, reduceContextFunction.apply(true));
+        InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forFinalReduction());
         ShardSearchFailure[] shardFailures = failures.toArray(ShardSearchFailure.EMPTY_ARRAY);
         SearchProfileShardResults profileShardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
         //make failures ordering consistent between ordinary search and CCS by looking at the shard they come from
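Because cross-cluster merging runs after each cluster has already reduced locally, the merger asks for a final-reduction context unconditionally. A hypothetical wiring of this class under the new API; getMergedResponse is shown above, while the add(...) accumulator and the variable names are assumptions:

    SearchResponseMerger merger = new SearchResponseMerger(
        from, size, trackTotalHitsUpTo, timeProvider, aggReduceContextBuilder);
    merger.add(responseFromClusterA);   // assumed accumulator for per-cluster responses
    merger.add(responseFromClusterB);
    SearchResponse merged = merger.getMergedResponse(clusters);  // runs forFinalReduction()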
@@ -212,9 +212,10 @@ protected void doExecute(Task task, SearchRequest searchRequest, ActionListener<
                 executeLocalSearch(task, timeProvider, searchRequest, localIndices, clusterState, listener);
             } else {
                 if (shouldMinimizeRoundtrips(searchRequest)) {
-                    ccsRemoteReduce(searchRequest, localIndices, remoteClusterIndices, timeProvider, searchService::createReduceContext,
-                        remoteClusterService, threadPool, listener,
-                        (r, l) -> executeLocalSearch(task, timeProvider, r, localIndices, clusterState, l));
+                    ccsRemoteReduce(searchRequest, localIndices, remoteClusterIndices, timeProvider,
+                        searchService.aggReduceContextBuilder(searchRequest),
+                        remoteClusterService, threadPool, listener,
+                        (r, l) -> executeLocalSearch(task, timeProvider, r, localIndices, clusterState, l));
                 } else {
                     AtomicInteger skippedClusters = new AtomicInteger(0);
                     collectSearchShards(searchRequest.indicesOptions(), searchRequest.preference(), searchRequest.routing(),
@@ -260,7 +261,7 @@ static boolean shouldMinimizeRoundtrips(SearchRequest searchRequest) {
     }
 
     static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIndices, Map<String, OriginalIndices> remoteIndices,
-                                SearchTimeProvider timeProvider, Function<Boolean, InternalAggregation.ReduceContext> reduceContext,
+                                SearchTimeProvider timeProvider, InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
                                 RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener<SearchResponse> listener,
                                 BiConsumer<SearchRequest, ActionListener<SearchResponse>> localSearchConsumer) {
 
@@ -298,7 +299,8 @@ public void onFailure(Exception e) {
                 }
             });
         } else {
-            SearchResponseMerger searchResponseMerger = createSearchResponseMerger(searchRequest.source(), timeProvider, reduceContext);
+            SearchResponseMerger searchResponseMerger = createSearchResponseMerger(
+                searchRequest.source(), timeProvider, aggReduceContextBuilder);
             AtomicInteger skippedClusters = new AtomicInteger(0);
             final AtomicReference<Exception> exceptions = new AtomicReference<>();
             int totalClusters = remoteIndices.size() + (localIndices == null ? 0 : 1);
@@ -325,7 +327,7 @@
     }
 
     static SearchResponseMerger createSearchResponseMerger(SearchSourceBuilder source, SearchTimeProvider timeProvider,
-                                                           Function<Boolean, InternalAggregation.ReduceContext> reduceContextFunction) {
+                                                           InternalAggregation.ReduceContextBuilder aggReduceContextBuilder) {
         final int from;
         final int size;
         final int trackTotalHitsUpTo;
@@ -342,7 +344,7 @@ static SearchResponseMerger createSearchResponseMerger(SearchSourceBuilde
             source.from(0);
             source.size(from + size);
         }
-        return new SearchResponseMerger(from, size, trackTotalHitsUpTo, timeProvider, reduceContextFunction);
+        return new SearchResponseMerger(from, size, trackTotalHitsUpTo, timeProvider, aggReduceContextBuilder);
     }
 
     static void collectSearchShards(IndicesOptions indicesOptions, String preference, String routing, AtomicInteger skippedClusters,
server/src/main/java/org/elasticsearch/node/Node.java (2 changes: 1 addition & 1 deletion)
@@ -575,7 +575,7 @@ protected Node(final Environment initialEnvironment,
             b.bind(MetaDataCreateIndexService.class).toInstance(metaDataCreateIndexService);
             b.bind(SearchService.class).toInstance(searchService);
             b.bind(SearchTransportService.class).toInstance(searchTransportService);
-            b.bind(SearchPhaseController.class).toInstance(new SearchPhaseController(searchService::createReduceContext));
+            b.bind(SearchPhaseController.class).toInstance(new SearchPhaseController(searchService::aggReduceContextBuilder));
             b.bind(Transport.class).toInstance(transport);
             b.bind(TransportService.class).toInstance(transportService);
             b.bind(NetworkService.class).toInstance(networkService);
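This binding is where the controller's new Function<SearchRequest, ReduceContextBuilder> is satisfied. A sketch of why the parameter moved from Boolean to SearchRequest; the expanded lambda is illustrative and equivalent to the method reference above:

    // Before: the controller could only ask "final reduce or not?" (a Boolean).
    // After: it hands over the whole request, so SearchService can inspect
    // its aggregations when building the partial and final reduce contexts.
    SearchPhaseController controller = new SearchPhaseController(
        searchRequest -> searchService.aggReduceContextBuilder(searchRequest));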