Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.RoaringDocIdSet;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
Expand Down Expand Up @@ -87,6 +90,12 @@ public InternalAggregation buildAggregation(long zeroBucket) throws IOException

// Replay all documents that contain at least one top bucket (collected during the first pass).
grow(keys.size()+1);
final boolean needsScores = needsScores();
Weight weight = null;
if (needsScores) {
Query query = context.query();
weight = context.searcher().createNormalizedWeight(query, true);
}
for (LeafContext context : contexts) {
DocIdSetIterator docIdSetIterator = context.docIdSet.iterator();
if (docIdSetIterator == null) {
Expand All @@ -95,7 +104,21 @@ public InternalAggregation buildAggregation(long zeroBucket) throws IOException
final CompositeValuesSource.Collector collector =
array.getLeafCollector(context.ctx, getSecondPassCollector(context.subCollector));
int docID;
DocIdSetIterator scorerIt = null;
if (needsScores) {
Scorer scorer = weight.scorer(context.ctx);
// We don't need to check if the scorer is null
// since we are sure that there are documents to replay (docIdSetIterator is not empty).
scorerIt = scorer.iterator();
context.subCollector.setScorer(scorer);
}
while ((docID = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (needsScores) {
assert scorerIt.docID() < docID;
scorerIt.advance(docID);
// aggregations should only be replayed on matching documents
assert scorerIt.docID() == docID;
}
collector.collect(docID);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.IndexSettingsModule;
import org.joda.time.DateTimeZone;
Expand Down Expand Up @@ -1065,8 +1067,73 @@ public void testWithKeywordAndDateHistogram() throws IOException {
);
}

private void testSearchCase(Query query,
Sort sort,
/**
 * Verifies that a {@code composite} aggregation can host a {@code top_hits}
 * sub-aggregation (which needs scores during the second-pass replay), both
 * without and with an {@code after} key.
 */
public void testWithKeywordAndTopHits() throws Exception {
    // Five documents over three distinct keyword values: a=2, c=2, d=1.
    final List<Map<String, List<Object>>> dataset = new ArrayList<>();
    dataset.addAll(
        Arrays.asList(
            createDocument("keyword", "a"),
            createDocument("keyword", "c"),
            createDocument("keyword", "a"),
            createDocument("keyword", "d"),
            createDocument("keyword", "c")
        )
    );
    final Sort sort = new Sort(new SortedSetSortField("keyword", false));
    // Case 1: no 'after' key — expect all three buckets, each carrying top_hits.
    testSearchCase(new MatchAllDocsQuery(), sort, dataset,
        () -> {
            TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
                .field("keyword");
            return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
                .subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
        }, (result) -> {
            assertEquals(3, result.getBuckets().size());
            assertEquals("{keyword=a}", result.getBuckets().get(0).getKeyAsString());
            assertEquals(2L, result.getBuckets().get(0).getDocCount());
            TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
            assertNotNull(topHits);
            // JUnit convention: expected value first, actual second.
            assertEquals(2, topHits.getHits().getHits().length);
            assertEquals(2L, topHits.getHits().getTotalHits());
            assertEquals("{keyword=c}", result.getBuckets().get(1).getKeyAsString());
            assertEquals(2L, result.getBuckets().get(1).getDocCount());
            topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
            assertNotNull(topHits);
            assertEquals(2, topHits.getHits().getHits().length);
            assertEquals(2L, topHits.getHits().getTotalHits());
            assertEquals("{keyword=d}", result.getBuckets().get(2).getKeyAsString());
            assertEquals(1L, result.getBuckets().get(2).getDocCount());
            topHits = result.getBuckets().get(2).getAggregations().get("top_hits");
            assertNotNull(topHits);
            assertEquals(1, topHits.getHits().getHits().length);
            assertEquals(1L, topHits.getHits().getTotalHits());
        }
    );

    // Case 2: 'after' the {keyword=a} bucket — expect only the c and d buckets.
    testSearchCase(new MatchAllDocsQuery(), sort, dataset,
        () -> {
            TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
                .field("keyword");
            return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
                .aggregateAfter(Collections.singletonMap("keyword", "a"))
                .subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
        }, (result) -> {
            assertEquals(2, result.getBuckets().size());
            assertEquals("{keyword=c}", result.getBuckets().get(0).getKeyAsString());
            assertEquals(2L, result.getBuckets().get(0).getDocCount());
            TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
            assertNotNull(topHits);
            assertEquals(2, topHits.getHits().getHits().length);
            assertEquals(2L, topHits.getHits().getTotalHits());
            assertEquals("{keyword=d}", result.getBuckets().get(1).getKeyAsString());
            assertEquals(1L, result.getBuckets().get(1).getDocCount());
            topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
            assertNotNull(topHits);
            assertEquals(1, topHits.getHits().getHits().length);
            assertEquals(1L, topHits.getHits().getTotalHits());
        }
    );
}

private void testSearchCase(Query query, Sort sort,
List<Map<String, List<Object>>> dataset,
Supplier<CompositeAggregationBuilder> create,
Consumer<InternalComposite> verify) throws IOException {
Expand Down Expand Up @@ -1107,7 +1174,7 @@ private void executeTestCase(boolean reduced,
IndexSearcher indexSearcher = newSearcher(indexReader, sort == null, sort == null);
CompositeAggregationBuilder aggregationBuilder = create.get();
if (sort != null) {
CompositeAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
CompositeAggregator aggregator = createAggregator(query, aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
assertTrue(aggregator.canEarlyTerminate());
}
final InternalComposite composite;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,27 @@ protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggreg
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
}

/** Create a factory for the given aggregation builder. */

/**
 * Creates an aggregator factory for the given builder with no query context;
 * delegates to the {@link Query}-accepting overload with a {@code null} query.
 */
protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
return createAggregatorFactory(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
}

/** Create a factory for the given aggregation builder. */
protected AggregatorFactory<?> createAggregatorFactory(Query query,
AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
SearchContext searchContext = createSearchContext(indexSearcher, indexSettings);
CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService();
when(searchContext.aggregations())
.thenReturn(new SearchContextAggregations(AggregatorFactories.EMPTY, bucketConsumer));
when(searchContext.query()).thenReturn(query);
when(searchContext.bigArrays()).thenReturn(new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService));
// TODO: now just needed for top_hits, this will need to be revised for other agg unit tests:
MapperService mapperService = mapperServiceMock();
Expand Down Expand Up @@ -146,28 +157,38 @@ protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregati
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
}

protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
protected <A extends Aggregator> A createAggregator(Query query,
AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MappedFieldType... fieldTypes) throws IOException {
return createAggregator(aggregationBuilder, indexSearcher, indexSettings,
return createAggregator(query, aggregationBuilder, indexSearcher, indexSettings,
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
}

protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
protected <A extends Aggregator> A createAggregator(Query query, AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
return createAggregator(aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
return createAggregator(query, aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
}

/**
 * Creates an aggregator for the given builder with no query context;
 * delegates to the {@link Query}-accepting overload with a {@code null} query.
 */
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
return createAggregator(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
}

protected <A extends Aggregator> A createAggregator(Query query,
AggregationBuilder aggregationBuilder,
IndexSearcher indexSearcher,
IndexSettings indexSettings,
MultiBucketConsumer bucketConsumer,
MappedFieldType... fieldTypes) throws IOException {
@SuppressWarnings("unchecked")
A aggregator = (A) createAggregatorFactory(aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
A aggregator = (A) createAggregatorFactory(query, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
.create(null, true);
return aggregator;
}
Expand Down Expand Up @@ -262,7 +283,7 @@ protected <A extends InternalAggregation, C extends Aggregator> A search(IndexSe
int maxBucket,
MappedFieldType... fieldTypes) throws IOException {
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
C a = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
C a = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
a.preCollection();
searcher.search(query, a);
a.postCollection();
Expand Down Expand Up @@ -310,11 +331,11 @@ protected <A extends InternalAggregation, C extends Aggregator> A searchAndReduc
Query rewritten = searcher.rewrite(query);
Weight weight = searcher.createWeight(rewritten, true, 1f);
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
C root = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
C root = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);

for (ShardSearcher subSearcher : subSearchers) {
MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer(maxBucket);
C a = createAggregator(builder, subSearcher, shardBucketConsumer, fieldTypes);
C a = createAggregator(query, builder, subSearcher, shardBucketConsumer, fieldTypes);
a.preCollection();
subSearcher.search(weight, a);
a.postCollection();
Expand Down