ChildrenAggregatorFactory.java
@@ -79,12 +79,8 @@ protected Aggregator doCreateInternal(ValuesSource rawValuesSource,
}
WithOrdinals valuesSource = (WithOrdinals) rawValuesSource;
long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher());
- if (collectsFromSingleBucket) {
-     return new ParentToChildrenAggregator(name, factories, searchContext, parent, childFilter,
-         parentFilter, valuesSource, maxOrd, metadata);
- } else {
-     return asMultiBucketAggregator(this, searchContext, parent);
- }
+ return new ParentToChildrenAggregator(name, factories, searchContext, parent, childFilter,
+     parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}

@Override
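Both factories (this file and ParentAggregatorFactory below) drop the asMultiBucketAggregator fallback: instead of wrapping the aggregator so each owning bucket gets its own instance, the collectsFromSingleBucket flag is threaded down and one instance serves every bucket. A minimal sketch of the two dispatch styles, using hypothetical Agg/PerBucketWrapper/MultiBucketAgg names rather than the real Elasticsearch types:

    import java.util.HashMap;
    import java.util.Map;
    import java.util.function.Supplier;

    // Hypothetical sketch, not Elasticsearch's API: contrast wrapping one
    // single-bucket aggregator per owning bucket with a single aggregator
    // that keys its state by owningBucketOrd.
    interface Agg {
        void collect(int doc, long owningBucketOrd);
    }

    class PerBucketWrapper implements Agg {
        // Old style: asMultiBucketAggregator-like wrapper, one delegate per bucket.
        private final Map<Long, Agg> delegates = new HashMap<>();
        private final Supplier<Agg> factory;

        PerBucketWrapper(Supplier<Agg> factory) {
            this.factory = factory;
        }

        @Override
        public void collect(int doc, long owningBucketOrd) {
            // Each owning bucket gets a fresh delegate, which always sees ord 0.
            delegates.computeIfAbsent(owningBucketOrd, k -> factory.get()).collect(doc, 0);
        }
    }

    class MultiBucketAgg implements Agg {
        // New style: one instance, state keyed by owningBucketOrd directly.
        private final Map<Long, Long> docCounts = new HashMap<>();

        @Override
        public void collect(int doc, long owningBucketOrd) {
            docCounts.merge(owningBucketOrd, 1L, Long::sum);
        }
    }

The wrapper duplicates the whole per-aggregator state once per owning bucket; keying state by owningBucketOrd inside a single instance is what saves memory here.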
ChildrenToParentAggregator.java
@@ -40,8 +40,8 @@ public class ChildrenToParentAggregator extends ParentJoinAggregator {
public ChildrenToParentAggregator(String name, AggregatorFactories factories,
SearchContext context, Aggregator parent, Query childFilter,
Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource,
- long maxOrd, Map<String, Object> metadata) throws IOException {
- super(name, factories, context, parent, childFilter, parentFilter, valuesSource, maxOrd, metadata);
+ long maxOrd, boolean collectsFromSingleBucket, Map<String, Object> metadata) throws IOException {
+ super(name, factories, context, parent, childFilter, parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}

@Override
ParentAggregatorFactory.java
@@ -80,12 +80,8 @@ protected Aggregator doCreateInternal(ValuesSource rawValuesSource,
}
WithOrdinals valuesSource = (WithOrdinals) rawValuesSource;
long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher());
- if (collectsFromSingleBucket) {
-     return new ChildrenToParentAggregator(name, factories, searchContext, children, childFilter,
-         parentFilter, valuesSource, maxOrd, metadata);
- } else {
-     return asMultiBucketAggregator(this, searchContext, children);
- }
+ return new ChildrenToParentAggregator(name, factories, searchContext, children, childFilter,
+     parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}

@Override
ParentJoinAggregator.java
@@ -33,12 +33,12 @@
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
- import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
+ import org.elasticsearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.internal.SearchContext;

@@ -68,6 +68,7 @@ public ParentJoinAggregator(String name,
Query outFilter,
ValuesSource.Bytes.WithOrdinals valuesSource,
long maxOrd,
+ boolean collectsFromSingleBucket,
Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, metadata);

@@ -81,8 +82,9 @@ public ParentJoinAggregator(String name,
this.outFilter = context.searcher().createWeight(context.searcher().rewrite(outFilter), ScoreMode.COMPLETE_NO_SCORES, 1f);
this.valuesSource = valuesSource;
boolean singleAggregator = parent == null;
- collectionStrategy = singleAggregator ?
-     new DenseCollectionStrategy(maxOrd, context.bigArrays()) : new SparseCollectionStrategy(context.bigArrays());
+ collectionStrategy = singleAggregator && collectsFromSingleBucket
+     ? new DenseCollectionStrategy(maxOrd, context.bigArrays())
+     : new SparseCollectionStrategy(context.bigArrays(), collectsFromSingleBucket);
}

@Override
@@ -95,19 +97,18 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
final Bits parentDocs = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), inFilter.scorerSupplier(ctx));
return new LeafBucketCollector() {
@Override
- public void collect(int docId, long bucket) throws IOException {
-     assert bucket == 0;
+ public void collect(int docId, long owningBucketOrd) throws IOException {
if (parentDocs.get(docId) && globalOrdinals.advanceExact(docId)) {
int globalOrdinal = (int) globalOrdinals.nextOrd();
assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
- collectionStrategy.addGlobalOrdinal(globalOrdinal);
+ collectionStrategy.add(owningBucketOrd, globalOrdinal);
}
}
};
}

@Override
- protected final void doPostCollection() throws IOException {
+ protected void beforeBuildingBuckets(long[] ordsToCollect) throws IOException {
IndexReader indexReader = context().searcher().getIndexReader();
for (LeafReaderContext ctx : indexReader.leaves()) {
Scorer childDocsScorer = outFilter.scorer(ctx);
@@ -137,11 +138,21 @@ public int docID() {
if (liveDocs != null && liveDocs.get(docId) == false) {
continue;
}
- if (globalOrdinals.advanceExact(docId)) {
-     int globalOrdinal = (int) globalOrdinals.nextOrd();
-     assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
-     if (collectionStrategy.existsGlobalOrdinal(globalOrdinal)) {
-         collectBucket(sub, docId, 0);
+ if (false == globalOrdinals.advanceExact(docId)) {
+     continue;
+ }
+ int globalOrdinal = (int) globalOrdinals.nextOrd();
+ assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
+ /*
+  * Check if we contain every ordinal. It would almost certainly be
+  * faster to replay all the matching ordinals and filter them down
+  * to just those listed in ordsToCollect, but we don't have a data
+  * structure that maps a primitive long to a list of primitive
+  * longs.
+  */
[Review comment — Member Author] This is pretty much the algorithm that we used to use. It ain't perfect, but it gets the job done.
[Review reply — Contributor] 👍
+ for (long owningBucketOrd : ordsToCollect) {
+     if (collectionStrategy.exists(owningBucketOrd, globalOrdinal)) {
+         collectBucket(sub, docId, owningBucketOrd);
}
}
}
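The comment above motivates this nested loop: for every matching child doc, the code scans all of ordsToCollect and keeps the owning buckets whose collection strategy contains the ordinal. A toy sketch of the missing data structure the comment describes — a map from a global ordinal to the owning buckets that added it (boxed collections for brevity; the lack of a primitive equivalent is exactly the comment's point):

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    // Hypothetical: map each global ordinal to the owning bucket ords that
    // added it, so replay could visit only matching buckets instead of
    // scanning all of ordsToCollect per doc. Boxing Long keys and values is
    // what makes this a non-starter in the hot path.
    class OrdinalToOwningBuckets {
        private final Map<Long, List<Long>> byOrdinal = new HashMap<>();

        void add(long owningBucketOrd, long globalOrdinal) {
            byOrdinal.computeIfAbsent(globalOrdinal, k -> new ArrayList<>()).add(owningBucketOrd);
        }

        List<Long> owningBuckets(long globalOrdinal) {
            return byOrdinal.getOrDefault(globalOrdinal, List.of());
        }
    }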
@@ -160,8 +171,8 @@ protected void doClose() {
* {@code ParentJoinAggregator#outFilter} also have the ordinal.
*/
protected interface CollectionStrategy extends Releasable {
- void addGlobalOrdinal(int globalOrdinal);
- boolean existsGlobalOrdinal(int globalOrdinal);
+ void add(long owningBucketOrd, int globalOrdinal);
+ boolean exists(long owningBucketOrd, int globalOrdinal);
}

/**
@@ -178,12 +189,14 @@ public DenseCollectionStrategy(long maxOrd, BigArrays bigArrays) {
}

@Override
- public void addGlobalOrdinal(int globalOrdinal) {
+ public void add(long owningBucketOrd, int globalOrdinal) {
+     assert owningBucketOrd == 0;
ordsBits.set(globalOrdinal);
}

@Override
- public boolean existsGlobalOrdinal(int globalOrdinal) {
+ public boolean exists(long owningBucketOrd, int globalOrdinal) {
+     assert owningBucketOrd == 0;
return ordsBits.get(globalOrdinal);
}

@@ -200,20 +213,20 @@ public void close() {
* when only some docs might match.
*/
protected class SparseCollectionStrategy implements CollectionStrategy {
- private final LongHash ordsHash;
+ private final LongKeyedBucketOrds ordsHash;

- public SparseCollectionStrategy(BigArrays bigArrays) {
-     ordsHash = new LongHash(1, bigArrays);
+ public SparseCollectionStrategy(BigArrays bigArrays, boolean collectsFromSingleBucket) {
+     ordsHash = LongKeyedBucketOrds.build(bigArrays, collectsFromSingleBucket);
}

@Override
- public void addGlobalOrdinal(int globalOrdinal) {
-     ordsHash.add(globalOrdinal);
+ public void add(long owningBucketOrd, int globalOrdinal) {
+     ordsHash.add(owningBucketOrd, globalOrdinal);
}

@Override
- public boolean existsGlobalOrdinal(int globalOrdinal) {
-     return ordsHash.find(globalOrdinal) >= 0;
+ public boolean exists(long owningBucketOrd, int globalOrdinal) {
+     return ordsHash.find(owningBucketOrd, globalOrdinal) >= 0;
}

@Override
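SparseCollectionStrategy now keys its hash by (owningBucketOrd, globalOrdinal) pairs through LongKeyedBucketOrds, while the dense bitset strategy is reserved for an aggregator that is both top-level and collecting from a single bucket. A toy model of the semantics the add/find calls above rely on — an assumption based only on this diff, not the real implementation:

    import java.util.HashMap;
    import java.util.Map;

    // Toy model of the LongKeyedBucketOrds behavior used above: a set of
    // (owningBucketOrd, value) pairs where find() returns a non-negative slot
    // for pairs previously added and a negative value for unknown pairs.
    class ToyLongKeyedBucketOrds {
        private final Map<Long, Map<Long, Long>> slots = new HashMap<>();
        private long nextSlot = 0;

        long add(long owningBucketOrd, long value) {
            Map<Long, Long> perBucket = slots.computeIfAbsent(owningBucketOrd, k -> new HashMap<>());
            Long existing = perBucket.get(value);
            if (existing != null) {
                return -1 - existing; // already present
            }
            perBucket.put(value, nextSlot);
            return nextSlot++;
        }

        long find(long owningBucketOrd, long value) {
            Map<Long, Long> perBucket = slots.get(owningBucketOrd);
            Long slot = perBucket == null ? null : perBucket.get(value);
            return slot == null ? -1 : slot;
        }
    }

Presumably build(bigArrays, collectsFromSingleBucket) lets the real class pick a cheaper single-bucket layout when the flag is true, which is why the flag is passed through the constructor.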
ParentToChildrenAggregator.java
@@ -36,8 +36,8 @@ public class ParentToChildrenAggregator extends ParentJoinAggregator {
public ParentToChildrenAggregator(String name, AggregatorFactories factories,
SearchContext context, Aggregator parent, Query childFilter,
Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource,
- long maxOrd, Map<String, Object> metadata) throws IOException {
- super(name, factories, context, parent, parentFilter, childFilter, valuesSource, maxOrd, metadata);
+ long maxOrd, boolean collectsFromSingleBucket, Map<String, Object> metadata) throws IOException {
+ super(name, factories, context, parent, parentFilter, childFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}

@Override
ChildrenToParentAggregatorTests.java
@@ -59,7 +59,6 @@
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalMin;
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
- import org.elasticsearch.search.aggregations.support.ValueType;

import java.io.IOException;
import java.util.ArrayList;
@@ -313,8 +312,7 @@ private void testCaseTerms(Query query, IndexSearcher indexSearcher, Consumer<In
throws IOException {

ParentAggregationBuilder aggregationBuilder = new ParentAggregationBuilder("_name", CHILD_TYPE);
- aggregationBuilder.subAggregation(new TermsAggregationBuilder("value_terms").userValueTypeHint(ValueType.LONG)
-     .field("number"));
+ aggregationBuilder.subAggregation(new TermsAggregationBuilder("value_terms").field("number"));

MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
fieldType.setName("number");
@@ -326,9 +324,9 @@ private void testCaseTermsParentTerms(Query query, IndexSearcher indexSearcher, Consumer<LongTerms> verify)
private void testCaseTermsParentTerms(Query query, IndexSearcher indexSearcher, Consumer<LongTerms> verify)
throws IOException {
AggregationBuilder aggregationBuilder =
new TermsAggregationBuilder("subvalue_terms").userValueTypeHint(ValueType.LONG).field("subNumber").
new TermsAggregationBuilder("subvalue_terms").field("subNumber").
subAggregation(new ParentAggregationBuilder("to_parent", CHILD_TYPE).
subAggregation(new TermsAggregationBuilder("value_terms").userValueTypeHint(ValueType.LONG).field("number")));
subAggregation(new TermsAggregationBuilder("value_terms").field("number")));

MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
fieldType.setName("number");
ParentToChildrenAggregatorTests.java
@@ -22,6 +22,7 @@
import org.apache.lucene.document.Field;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
+ import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
@@ -52,7 +53,10 @@
import org.elasticsearch.join.mapper.MetaJoinFieldMapper;
import org.elasticsearch.join.mapper.ParentJoinFieldMapper;
import org.elasticsearch.plugins.SearchPlugin;
+ import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
+ import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
+ import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalMin;
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;

@@ -64,6 +68,7 @@
import java.util.Map;
import java.util.function.Consumer;

+ import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@@ -124,12 +129,68 @@ public void testParentChild() throws IOException {
directory.close();
}

+ public void testParentChildAsSubAgg() throws IOException {
+     try (Directory directory = newDirectory()) {
+         RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
+
+         final Map<String, Tuple<Integer, Integer>> expectedParentChildRelations = setupIndex(indexWriter);
+         indexWriter.close();
+
+         try (
+             IndexReader indexReader = ElasticsearchDirectoryReader.wrap(
+                 DirectoryReader.open(directory),
+                 new ShardId(new Index("foo", "_na_"), 1)
+             )
+         ) {
+             IndexSearcher indexSearcher = newSearcher(indexReader, false, true);
+
+             AggregationBuilder request = new TermsAggregationBuilder("t").field("kwd")
+                 .subAggregation(
+                     new ChildrenAggregationBuilder("children", CHILD_TYPE).subAggregation(
+                         new MinAggregationBuilder("min").field("number")
+                     )
+                 );
+
+             long expectedEvenChildCount = 0;
+             double expectedEvenMin = Double.MAX_VALUE;
+             long expectedOddChildCount = 0;
+             double expectedOddMin = Double.MAX_VALUE;
+             for (Map.Entry<String, Tuple<Integer, Integer>> e : expectedParentChildRelations.entrySet()) {
+                 if (Integer.valueOf(e.getKey().substring("parent".length())) % 2 == 0) {
+                     expectedEvenChildCount += e.getValue().v1();
+                     expectedEvenMin = Math.min(expectedEvenMin, e.getValue().v2());
+                 } else {
+                     expectedOddChildCount += e.getValue().v1();
+                     expectedOddMin = Math.min(expectedOddMin, e.getValue().v2());
+                 }
+             }
+             StringTerms result = search(indexSearcher, new MatchAllDocsQuery(), request, longField("number"), keywordField("kwd"));
+
+             StringTerms.Bucket evenBucket = result.getBucketByKey("even");
+             InternalChildren evenChildren = evenBucket.getAggregations().get("children");
+             InternalMin evenMin = evenChildren.getAggregations().get("min");
+             assertThat(evenChildren.getDocCount(), equalTo(expectedEvenChildCount));
+             assertThat(evenMin.getValue(), equalTo(expectedEvenMin));
+
+             if (expectedOddChildCount > 0) {
+                 StringTerms.Bucket oddBucket = result.getBucketByKey("odd");
+                 InternalChildren oddChildren = oddBucket.getAggregations().get("children");
+                 InternalMin oddMin = oddChildren.getAggregations().get("min");
+                 assertThat(oddChildren.getDocCount(), equalTo(expectedOddChildCount));
+                 assertThat(oddMin.getValue(), equalTo(expectedOddMin));
+             } else {
+                 assertNull(result.getBucketByKey("odd"));
+             }
+         }
+     }
+ }

private static Map<String, Tuple<Integer, Integer>> setupIndex(RandomIndexWriter iw) throws IOException {
Map<String, Tuple<Integer, Integer>> expectedValues = new HashMap<>();
int numParents = randomIntBetween(1, 10);
for (int i = 0; i < numParents; i++) {
String parent = "parent" + i;
- iw.addDocument(createParentDocument(parent));
+ iw.addDocument(createParentDocument(parent, i % 2 == 0 ? "even" : "odd"));
int numChildren = randomIntBetween(1, 10);
int minValue = Integer.MAX_VALUE;
for (int c = 0; c < numChildren; c++) {
@@ -142,9 +203,10 @@ private static Map<String, Tuple<Integer, Integer>> setupIndex(RandomIndexWriter
return expectedValues;
}

- private static List<Field> createParentDocument(String id) {
+ private static List<Field> createParentDocument(String id, String kwd) {
return Arrays.asList(
new StringField(IdFieldMapper.NAME, Uid.encodeId(id), Field.Store.NO),
+ new SortedSetDocValuesField("kwd", new BytesRef(kwd)),
new StringField("join_field", PARENT_TYPE, Field.Store.NO),
createJoinField(PARENT_TYPE, id)
);