Skip to content

Commit a325129

Browse files
committed
Handle trackTotalHitsUpTo and disabling local hits tracking
Adapt TopDocsStats so it can be reused.
1 parent 205f0aa commit a325129

File tree

4 files changed

+58
-31
lines changed

4 files changed

+58
-31
lines changed

server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
443443
Boolean terminatedEarly = null;
444444
if (queryResults.isEmpty()) { // early terminate we have nothing to reduce
445445
final TotalHits totalHits = topDocsStats.getTotalHits();
446-
return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.maxScore,
446+
return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.getMaxScore(),
447447
timedOut, terminatedEarly, null, null, null, SortedTopDocs.EMPTY, null, numReducePhases, 0, 0, true);
448448
}
449449
final QuerySearchResult firstResult = queryResults.stream().findFirst().get().queryResult();
@@ -508,7 +508,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
508508
final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
509509
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size);
510510
final TotalHits totalHits = topDocsStats.getTotalHits();
511-
return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.maxScore,
511+
return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.getMaxScore(),
512512
timedOut, terminatedEarly, suggest, aggregations, shardResults, sortedTopDocs,
513513
firstResult.sortValueFormats(), numReducePhases, size, from, false);
514514
}
@@ -577,11 +577,7 @@ public static final class ReducedQueryPhase {
577577
}
578578
this.totalHits = totalHits;
579579
this.fetchHits = fetchHits;
580-
if (Float.isInfinite(maxScore)) {
581-
this.maxScore = Float.NaN;
582-
} else {
583-
this.maxScore = maxScore;
584-
}
580+
this.maxScore = maxScore;
585581
this.timedOut = timedOut;
586582
this.terminatedEarly = terminatedEarly;
587583
this.suggest = suggest;
@@ -744,7 +740,7 @@ static final class TopDocsStats {
744740
private long totalHits;
745741
private TotalHits.Relation totalHitsRelation;
746742
long fetchHits;
747-
float maxScore = Float.NEGATIVE_INFINITY;
743+
private float maxScore = Float.NEGATIVE_INFINITY;
748744

749745
TopDocsStats() {
750746
this(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
@@ -756,6 +752,10 @@ static final class TopDocsStats {
756752
this.totalHitsRelation = Relation.EQUAL_TO;
757753
}
758754

755+
float getMaxScore() {
756+
return Float.isInfinite(maxScore) ? Float.NaN : maxScore;
757+
}
758+
759759
TotalHits getTotalHits() {
760760
if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) {
761761
return null;

server/src/main/java/org/elasticsearch/action/search/SearchResponseMerger.java

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
2929
import org.elasticsearch.ElasticsearchException;
3030
import org.elasticsearch.action.search.TransportSearchAction.SearchTimeProvider;
31+
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
3132
import org.elasticsearch.index.shard.ShardId;
3233
import org.elasticsearch.search.SearchHit;
3334
import org.elasticsearch.search.SearchHits;
@@ -52,6 +53,8 @@
5253
import java.util.concurrent.CopyOnWriteArrayList;
5354
import java.util.function.Function;
5455

56+
import static org.elasticsearch.action.search.SearchPhaseController.TopDocsStats;
57+
import static org.elasticsearch.action.search.SearchPhaseController.mergeTopDocs;
5558
import static org.elasticsearch.action.search.SearchResponse.Clusters;
5659

5760
/**
@@ -66,15 +69,17 @@
6669
final class SearchResponseMerger {
6770
private final int from;
6871
private final int size;
72+
int trackTotalHitsUpTo;
6973
private final SearchTimeProvider searchTimeProvider;
7074
private final Clusters clusters;
7175
private final Function<Boolean, ReduceContext> reduceContextFunction;
7276
private final List<SearchResponse> searchResponses = new CopyOnWriteArrayList<>();
7377

74-
SearchResponseMerger(int from, int size, SearchTimeProvider searchTimeProvider, Clusters clusters,
78+
SearchResponseMerger(int from, int size, int trackTotalHitsUpTo, SearchTimeProvider searchTimeProvider, Clusters clusters,
7579
Function<Boolean, ReduceContext> reduceContextFunction) {
7680
this.from = from;
7781
this.size = size;
82+
this.trackTotalHitsUpTo = trackTotalHitsUpTo;
7883
this.searchTimeProvider = Objects.requireNonNull(searchTimeProvider);
7984
this.clusters = Objects.requireNonNull(clusters);
8085
this.reduceContextFunction = Objects.requireNonNull(reduceContextFunction);
@@ -102,7 +107,6 @@ SearchResponse getMergedResponse() {
102107
Boolean terminatedEarly = null;
103108
//the current reduce phase counts as one
104109
int numReducePhases = 1;
105-
float maxScore = Float.NEGATIVE_INFINITY;
106110
List<ShardSearchFailure> failures = new ArrayList<>();
107111
Map<String, ProfileShardResult> profileResults = new HashMap<>();
108112
List<InternalAggregations> aggs = new ArrayList<>();
@@ -111,6 +115,8 @@ SearchResponse getMergedResponse() {
111115
Map<String, List<Suggest.Suggestion>> groupedSuggestions = new HashMap<>();
112116
Boolean trackTotalHits = null;
113117

118+
TopDocsStats topDocsStats = new TopDocsStats(trackTotalHitsUpTo);
119+
114120
for (SearchResponse searchResponse : searchResponses) {
115121
totalShards += searchResponse.getTotalShards();
116122
skippedShards += searchResponse.getSkippedShards();
@@ -139,12 +145,10 @@ SearchResponse getMergedResponse() {
139145
}
140146

141147
SearchHits searchHits = searchResponse.getHits();
142-
if (Float.isNaN(searchHits.getMaxScore()) == false) {
143-
maxScore = Math.max(maxScore, searchHits.getMaxScore());
144-
}
148+
145149
final TotalHits totalHits;
146150
if (searchHits.getTotalHits() == null) {
147-
//in case we did't track total hits, we get null from each cluster, but we need to set 0 eq to the TopDocs
151+
//in case we didn't track total hits, we get null from each cluster, but we need to set 0 eq to the TopDocs
148152
totalHits = new TotalHits(0, TotalHits.Relation.EQUAL_TO);
149153
assert trackTotalHits == null || trackTotalHits == false;
150154
trackTotalHits = false;
@@ -153,7 +157,9 @@ SearchResponse getMergedResponse() {
153157
assert trackTotalHits == null || trackTotalHits;
154158
trackTotalHits = true;
155159
}
156-
topDocsList.add(searchHitsToTopDocs(searchHits, totalHits, shards));
160+
TopDocs topDocs = searchHitsToTopDocs(searchHits, totalHits, shards);
161+
topDocsStats.add(new TopDocsAndMaxScore(topDocs, searchHits.getMaxScore()));
162+
topDocsList.add(topDocs);
157163
}
158164

159165
//now that we've gone through all the hits and we collected all the shards they come from, we can assign shardIndex to each shard
@@ -165,13 +171,15 @@ SearchResponse getMergedResponse() {
165171
for (TopDocs topDocs : topDocsList) {
166172
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
167173
FieldDocAndSearchHit fieldDocAndSearchHit = (FieldDocAndSearchHit) scoreDoc;
174+
//When hits come from the indices with same names on multiple clusters and same shard identifier, we rely on such indices
175+
//to have a different uuid across multiple clusters. That's how they will get a different shardIndex.
168176
ShardId shardId = fieldDocAndSearchHit.searchHit.getShard().getShardId();
169177
fieldDocAndSearchHit.shardIndex = shards.get(shardId);
170178
}
171179
}
172180

173-
TopDocs topDocs = SearchPhaseController.mergeTopDocs(topDocsList, size, from);
174-
SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, Float.isInfinite(maxScore) ? Float.NaN : maxScore, trackTotalHits);
181+
TopDocs topDocs = mergeTopDocs(topDocsList, size, from);
182+
SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, topDocsStats);
175183
Suggest suggest = groupedSuggestions.isEmpty() ? null : new Suggest(Suggest.reduce(groupedSuggestions));
176184
InternalAggregations reducedAggs = InternalAggregations.reduce(aggs, reduceContextFunction.apply(true));
177185
ShardSearchFailure[] shardFailures = failures.toArray(ShardSearchFailure.EMPTY_ARRAY);
@@ -250,7 +258,7 @@ private static TopDocs searchHitsToTopDocs(SearchHits searchHits, TotalHits tota
250258
return topDocs;
251259
}
252260

253-
private static SearchHits topDocsToSearchHits(TopDocs topDocs, float maxScore, boolean trackTotalHits) {
261+
private static SearchHits topDocsToSearchHits(TopDocs topDocs, TopDocsStats topDocsStats) {
254262
SearchHit[] searchHits = new SearchHit[topDocs.scoreDocs.length];
255263
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
256264
FieldDocAndSearchHit scoreDoc = (FieldDocAndSearchHit)topDocs.scoreDocs[i];
@@ -268,9 +276,8 @@ private static SearchHits topDocsToSearchHits(TopDocs topDocs, float maxScore, b
268276
collapseValues = collapseTopFieldDocs.collapseValues;
269277
}
270278
}
271-
//in case we didn't track total hits, we got null from each cluster, and we need to set null to the final response
272-
final TotalHits totalHits = trackTotalHits ? topDocs.totalHits : null;
273-
return new SearchHits(searchHits, totalHits, maxScore, sortFields, collapseField, collapseValues);
279+
return new SearchHits(searchHits, topDocsStats.getTotalHits(), topDocsStats.getMaxScore(),
280+
sortFields, collapseField, collapseValues);
274281
}
275282

276283
private static void setShardIndex(Collection<List<FieldDoc>> shardResults) {

server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ public void testSortIsIdempotent() throws Exception {
139139
assertEquals(sortedDocs[i].shardIndex, sortedDocs2[i].shardIndex);
140140
assertEquals(sortedDocs[i].score, sortedDocs2[i].score, 0.0f);
141141
}
142-
assertEquals(topDocsStats.maxScore, topDocsStats2.maxScore, 0.0f);
142+
assertEquals(topDocsStats.getMaxScore(), topDocsStats2.getMaxScore(), 0.0f);
143143
assertEquals(topDocsStats.getTotalHits().value, topDocsStats2.getTotalHits().value);
144144
assertEquals(topDocsStats.getTotalHits().relation, topDocsStats2.getTotalHits().relation);
145145
assertEquals(topDocsStats.fetchHits, topDocsStats2.fetchHits);

server/src/test/java/org/elasticsearch/action/search/SearchResponseMergerTests.java

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.search.SearchHits;
3333
import org.elasticsearch.search.SearchShardTarget;
3434
import org.elasticsearch.search.internal.InternalSearchResponse;
35+
import org.elasticsearch.search.internal.SearchContext;
3536
import org.elasticsearch.search.profile.ProfileShardResult;
3637
import org.elasticsearch.search.profile.SearchProfileShardResults;
3738
import org.elasticsearch.search.profile.SearchProfileShardResultsTests;
@@ -84,7 +85,7 @@ public void testMergeTookInMillis() throws InterruptedException {
8485
SearchTimeProvider timeProvider = new SearchTimeProvider(randomLong(), 0, () -> currentRelativeTime);
8586
SearchResponse.Clusters clusters = SearchResponseTests.randomClusters();
8687
SearchResponseMerger merger = new SearchResponseMerger(randomIntBetween(0, 1000), randomIntBetween(0, 10000),
87-
timeProvider, clusters, flag -> null);
88+
SearchContext.TRACK_TOTAL_HITS_ACCURATE, timeProvider, clusters, flag -> null);
8889
for (int i = 0; i < numResponses; i++) {
8990
SearchResponse searchResponse = new SearchResponse(InternalSearchResponse.empty(), null, 1, 1, 0, randomLong(),
9091
ShardSearchFailure.EMPTY_ARRAY, SearchResponseTests.randomClusters());
@@ -97,7 +98,8 @@ public void testMergeTookInMillis() throws InterruptedException {
9798

9899
public void testMergeShardFailures() throws InterruptedException {
99100
SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0);
100-
SearchResponseMerger merger = new SearchResponseMerger(0, 0, searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
101+
SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE,
102+
searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
101103
PriorityQueue<Tuple<ShardId, ShardSearchFailure>> priorityQueue = new PriorityQueue<>(Comparator.comparing(Tuple::v1));
102104
int numIndices = numResponses * randomIntBetween(1, 3);
103105
Iterator<Map.Entry<String, Index[]>> indicesPerCluster = randomRealisticIndices(numIndices, numResponses).entrySet().iterator();
@@ -136,7 +138,8 @@ public void testMergeShardFailures() throws InterruptedException {
136138

137139
public void testMergeShardFailuresNullShardId() throws InterruptedException {
138140
SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0);
139-
SearchResponseMerger merger = new SearchResponseMerger(0, 0, searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
141+
SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE,
142+
searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
140143
List<ShardSearchFailure> expectedFailures = new ArrayList<>();
141144
for (int i = 0; i < numResponses; i++) {
142145
int numFailures = randomIntBetween(1, 50);
@@ -157,7 +160,8 @@ public void testMergeShardFailuresNullShardId() throws InterruptedException {
157160

158161
public void testMergeProfileResults() throws InterruptedException {
159162
SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0);
160-
SearchResponseMerger merger = new SearchResponseMerger(0, 0, searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
163+
SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE,
164+
searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
161165
Map<String, ProfileShardResult> expectedProfile = new HashMap<>();
162166
for (int i = 0; i < numResponses; i++) {
163167
SearchProfileShardResults profile = SearchProfileShardResultsTests.createTestItem();
@@ -206,10 +210,14 @@ public void testMergeSearchHits() throws InterruptedException {
206210
sortFields = null;
207211
scoreSort = true;
208212
}
209-
TotalHits.Relation totalHitsRelation = frequently() ? randomFrom(TotalHits.Relation.values()) : null;
213+
Tuple<Integer, TotalHits.Relation> randomTrackTotalHits = randomTrackTotalHits();
214+
int trackTotalHitsUpTo = randomTrackTotalHits.v1();
215+
TotalHits.Relation totalHitsRelation = randomTrackTotalHits.v2();
210216

211217
PriorityQueue<SearchHit> priorityQueue = new PriorityQueue<>(new SearchHitComparator(sortFields));
212-
SearchResponseMerger searchResponseMerger = new SearchResponseMerger(from, size, timeProvider, clusters, flag -> null);
218+
SearchResponseMerger searchResponseMerger = new SearchResponseMerger(from, size, trackTotalHitsUpTo,
219+
timeProvider, clusters, flag -> null);
220+
213221
TotalHits expectedTotalHits = null;
214222
int expectedTotal = 0;
215223
int expectedSuccessful = 0;
@@ -232,11 +240,10 @@ public void testMergeSearchHits() throws InterruptedException {
232240
expectedSkipped += skipped;
233241

234242
TotalHits totalHits = null;
235-
if (totalHitsRelation != null) {
236-
//TODO totalHits may overflow if each cluster reports a very high number?
243+
if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED) {
237244
totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation);
238245
long previousValue = expectedTotalHits == null ? 0 : expectedTotalHits.value;
239-
expectedTotalHits = new TotalHits(previousValue + totalHits.value, totalHitsRelation);
246+
expectedTotalHits = new TotalHits(Math.min(previousValue + totalHits.value, trackTotalHitsUpTo), totalHitsRelation);
240247
}
241248

242249
final int numDocs = totalHits == null || totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value;
@@ -321,6 +328,19 @@ public void testMergeSearchHits() throws InterruptedException {
321328
}
322329
}
323330

331+
private static Tuple<Integer, TotalHits.Relation> randomTrackTotalHits() {
332+
switch(randomIntBetween(0, 2)) {
333+
case 0:
334+
return Tuple.tuple(SearchContext.TRACK_TOTAL_HITS_DISABLED, null);
335+
case 1:
336+
return Tuple.tuple(randomIntBetween(10, 1000), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
337+
case 2:
338+
return Tuple.tuple(SearchContext.TRACK_TOTAL_HITS_ACCURATE, TotalHits.Relation.EQUAL_TO);
339+
default:
340+
throw new UnsupportedOperationException();
341+
}
342+
}
343+
324344
private static SearchHit[] randomSearchHitArray(int numDocs, int numResponses, String clusterAlias, Index[] indices, float maxScore,
325345
int scoreFactor, SortField[] sortFields, PriorityQueue<SearchHit> priorityQueue) {
326346
SearchHit[] hits = new SearchHit[numDocs];

0 commit comments

Comments
 (0)