diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java index 84dd870e3f06d..0f42118683a58 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java @@ -19,6 +19,7 @@ package org.elasticsearch.search.aggregations.metrics.tophits; +import com.carrotsearch.hppc.LongObjectHashMap; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.LeafCollector; @@ -54,21 +55,11 @@ public class TopHitsAggregator extends MetricsAggregator { - /** Simple wrapper around a top-level collector and the current leaf collector. */ - private static class TopDocsAndLeafCollector { - final TopDocsCollector topLevelCollector; - LeafCollector leafCollector; + private final FetchPhase fetchPhase; + private final SubSearchContext subSearchContext; + private final LongObjectPagedHashMap> topDocsCollectors; - TopDocsAndLeafCollector(TopDocsCollector topLevelCollector) { - this.topLevelCollector = topLevelCollector; - } - } - - final FetchPhase fetchPhase; - final SubSearchContext subSearchContext; - final LongObjectPagedHashMap topDocsCollectors; - - public TopHitsAggregator(FetchPhase fetchPhase, SubSearchContext subSearchContext, String name, SearchContext context, + TopHitsAggregator(FetchPhase fetchPhase, SubSearchContext subSearchContext, String name, SearchContext context, Aggregator parent, List pipelineAggregators, Map metaData) throws IOException { super(name, context, parent, pipelineAggregators, metaData); this.fetchPhase = fetchPhase; @@ -88,9 +79,12 @@ public boolean needsScores() { } @Override - public LeafBucketCollector getLeafCollector(final LeafReaderContext ctx, - final LeafBucketCollector sub) throws IOException { - + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + // Create leaf collectors here instead of at the aggregator level. Otherwise in case this collector get invoked + // when post collecting then we have already replaced the leaf readers on the aggregator level have already been + // replaced with the next leaf readers and then post collection pushes docids of the previous segement, which + // then causes assertions to trip or incorrect top docs to be computed. + final LongObjectHashMap leafCollectors = new LongObjectHashMap<>(1); return new LeafBucketCollectorBase(sub, null) { Scorer scorer; @@ -98,21 +92,13 @@ public LeafBucketCollector getLeafCollector(final LeafReaderContext ctx, @Override public void setScorer(Scorer scorer) throws IOException { this.scorer = scorer; - for (LongObjectPagedHashMap.Cursor cursor : topDocsCollectors) { - // Instantiate the leaf collector not in the getLeafCollector(...) method or in the constructor of this - // anonymous class. Otherwise in the case this leaf bucket collector gets invoked with post collection - // then we already have moved on to the next reader and then we may encounter assertion errors or - // incorrect results. - cursor.value.leafCollector = cursor.value.topLevelCollector.getLeafCollector(ctx); - cursor.value.leafCollector.setScorer(scorer); - } super.setScorer(scorer); } @Override public void collect(int docId, long bucket) throws IOException { - TopDocsAndLeafCollector collectors = topDocsCollectors.get(bucket); - if (collectors == null) { + TopDocsCollector topDocsCollector = topDocsCollectors.get(bucket); + if (topDocsCollector == null) { SortAndFormats sort = subSearchContext.sort(); int topN = subSearchContext.from() + subSearchContext.size(); if (sort == null) { @@ -123,31 +109,39 @@ public void collect(int docId, long bucket) throws IOException { // In the QueryPhase we don't need this protection, because it is build into the IndexSearcher, // but here we create collectors ourselves and we need prevent OOM because of crazy an offset and size. topN = Math.min(topN, subSearchContext.searcher().getIndexReader().maxDoc()); - TopDocsCollector topLevelCollector; if (sort == null) { - topLevelCollector = TopScoreDocCollector.create(topN); + topDocsCollector = TopScoreDocCollector.create(topN); } else { - topLevelCollector = TopFieldCollector.create(sort.sort, topN, true, subSearchContext.trackScores(), - subSearchContext.trackScores()); + topDocsCollector = TopFieldCollector.create(sort.sort, topN, true, subSearchContext.trackScores(), + subSearchContext.trackScores()); + } + topDocsCollectors.put(bucket, topDocsCollector); + } + + final LeafCollector leafCollector; + final int key = leafCollectors.indexOf(bucket); + if (key < 0) { + leafCollector = topDocsCollector.getLeafCollector(ctx); + if (scorer != null) { + leafCollector.setScorer(scorer); } - collectors = new TopDocsAndLeafCollector(topLevelCollector); - collectors.leafCollector = collectors.topLevelCollector.getLeafCollector(ctx); - collectors.leafCollector.setScorer(scorer); - topDocsCollectors.put(bucket, collectors); + leafCollectors.indexInsert(key, bucket, leafCollector); + } else { + leafCollector = leafCollectors.indexGet(key); } - collectors.leafCollector.collect(docId); + leafCollector.collect(docId); } }; } @Override public InternalAggregation buildAggregation(long owningBucketOrdinal) { - TopDocsAndLeafCollector topDocsCollector = topDocsCollectors.get(owningBucketOrdinal); + TopDocsCollector topDocsCollector = topDocsCollectors.get(owningBucketOrdinal); final InternalTopHits topHits; if (topDocsCollector == null) { topHits = buildEmptyAggregation(); } else { - TopDocs topDocs = topDocsCollector.topLevelCollector.topDocs(); + TopDocs topDocs = topDocsCollector.topDocs(); if (subSearchContext.sort() == null) { for (RescoreContext ctx : context().rescore()) { try { diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregatorFactory.java b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregatorFactory.java index 6a41cc97f8ec5..09c26b169e528 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregatorFactory.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregatorFactory.java @@ -51,7 +51,7 @@ public class TopHitsAggregatorFactory extends AggregatorFactory scriptFields; private final FetchSourceContext fetchSourceContext; - public TopHitsAggregatorFactory(String name, int from, int size, boolean explain, boolean version, boolean trackScores, + TopHitsAggregatorFactory(String name, int from, int size, boolean explain, boolean version, boolean trackScores, Optional sort, HighlightBuilder highlightBuilder, StoredFieldsContext storedFieldsContext, List docValueFields, List scriptFields, FetchSourceContext fetchSourceContext, SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactories, Map metaData) diff --git a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/TopHitsIT.java b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/TopHitsIT.java index d648146a47208..3822455b83c3a 100644 --- a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/TopHitsIT.java +++ b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/TopHitsIT.java @@ -703,7 +703,6 @@ public void testTrackScores() throws Exception { } } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/26738") public void testTopHitsInNestedSimple() throws Exception { SearchResponse searchResponse = client().prepareSearch("articles") .setQuery(matchQuery("title", "title"))