Skip to content

Commit 62d6bc3

Browse files
authored
Reduce memory for big aggs run against many shards (#54758) (#55024)
This changes the behavior of aggregations when search is performed against enough shards to enable "batch reduce" mode. In this case we always store aggregations in serialized form rather than as a traditional Java reference. This should shrink the memory usage of large aggregations at the cost of slightly slowing down aggregations where the coordinating node is also a data node. Because we're only doing this when there are many shards this is likely to be fairly rare. As a side effect this lets us add logs for the memory usage of the aggs buffer: ``` [2020-04-03T17:03:57,052][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [1320->448] max [1320] [2020-04-03T17:03:57,089][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [1328->448] max [1328] [2020-04-03T17:03:57,102][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [1328->448] max [1328] [2020-04-03T17:03:57,103][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [1328->448] max [1328] [2020-04-03T17:03:57,105][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs final reduction [888] max [1328] ``` These are useful, but you need to keep some things in mind before trusting them: 1. The buffers are oversized à la Lucene's ArrayUtil. This means that we are using more space than we need, but probably not much more. 2. Before they are merged the aggregations are inflated into their traditional Java objects which *probably* take up a lot more space than the serialized form. That is, after all, the reason why we store them in serialized form in the first place. 
And, just because I can, here is another example of the log: ``` [2020-04-03T17:06:18,731][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [147528->49176] max [147528] [2020-04-03T17:06:18,750][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [147528->49176] max [147528] [2020-04-03T17:06:18,809][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [147528->49176] max [147528] [2020-04-03T17:06:18,827][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs partial reduction [147528->49176] max [147528] [2020-04-03T17:06:18,829][TRACE][o.e.a.s.SearchPhaseController] [runTask-0] aggs final reduction [98352] max [147528] ``` I got that last one by building a ten-shard index with a million docs in it and running a `sum` in three layers of `terms` aggregations, all on `long` fields, and with a `batched_reduce_size` of `3`.
1 parent 850ea7c commit 62d6bc3

File tree

9 files changed

+155
-68
lines changed

9 files changed

+155
-68
lines changed

server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,11 @@
1919

2020
package org.elasticsearch.action.search;
2121

22-
import java.util.ArrayList;
23-
import java.util.Arrays;
24-
import java.util.Collection;
25-
import java.util.Collections;
26-
import java.util.HashMap;
27-
import java.util.List;
28-
import java.util.Map;
29-
import java.util.function.Function;
30-
import java.util.function.IntFunction;
31-
import java.util.function.Supplier;
32-
import java.util.stream.Collectors;
22+
import com.carrotsearch.hppc.IntArrayList;
23+
import com.carrotsearch.hppc.ObjectObjectHashMap;
3324

25+
import org.apache.logging.log4j.LogManager;
26+
import org.apache.logging.log4j.Logger;
3427
import org.apache.lucene.index.Term;
3528
import org.apache.lucene.search.CollectionStatistics;
3629
import org.apache.lucene.search.FieldDoc;
@@ -44,6 +37,8 @@
4437
import org.apache.lucene.search.TotalHits.Relation;
4538
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
4639
import org.elasticsearch.common.collect.HppcMaps;
40+
import org.elasticsearch.common.io.stream.DelayableWriteable;
41+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
4742
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
4843
import org.elasticsearch.search.DocValueFormat;
4944
import org.elasticsearch.search.SearchHit;
@@ -67,16 +62,28 @@
6762
import org.elasticsearch.search.suggest.Suggest.Suggestion;
6863
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
6964

70-
import com.carrotsearch.hppc.IntArrayList;
71-
import com.carrotsearch.hppc.ObjectObjectHashMap;
65+
import java.util.ArrayList;
66+
import java.util.Arrays;
67+
import java.util.Collection;
68+
import java.util.Collections;
69+
import java.util.HashMap;
70+
import java.util.List;
71+
import java.util.Map;
72+
import java.util.function.Function;
73+
import java.util.function.IntFunction;
74+
import java.util.function.Supplier;
75+
import java.util.stream.Collectors;
7276

7377
public final class SearchPhaseController {
78+
private static final Logger logger = LogManager.getLogger(SearchPhaseController.class);
7479
private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];
7580

81+
private final NamedWriteableRegistry namedWriteableRegistry;
7682
private final Function<SearchRequest, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder;
7783

78-
public SearchPhaseController(
84+
public SearchPhaseController(NamedWriteableRegistry namedWriteableRegistry,
7985
Function<SearchRequest, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder) {
86+
this.namedWriteableRegistry = namedWriteableRegistry;
8087
this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder;
8188
}
8289

@@ -430,7 +437,8 @@ public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResul
430437
* @see QuerySearchResult#consumeProfileResult()
431438
*/
432439
private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
433-
List<Supplier<InternalAggregations>> bufferedAggs, List<TopDocs> bufferedTopDocs,
440+
List<Supplier<InternalAggregations>> bufferedAggs,
441+
List<TopDocs> bufferedTopDocs,
434442
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
435443
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
436444
boolean performFinalReduce) {
@@ -522,7 +530,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
522530
private InternalAggregations reduceAggs(
523531
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
524532
boolean performFinalReduce,
525-
List<Supplier<InternalAggregations>> aggregationsList
533+
List<? extends Supplier<InternalAggregations>> aggregationsList
526534
) {
527535
/*
528536
* Parse the aggregations, clearing the list as we go so bits backing
@@ -617,8 +625,9 @@ public InternalSearchResponse buildResponse(SearchHits hits) {
617625
* iff the buffer is exhausted.
618626
*/
619627
static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
628+
private final NamedWriteableRegistry namedWriteableRegistry;
620629
private final SearchShardTarget[] processedShards;
621-
private final Supplier<InternalAggregations>[] aggsBuffer;
630+
private final DelayableWriteable.Serialized<InternalAggregations>[] aggsBuffer;
622631
private final TopDocs[] topDocsBuffer;
623632
private final boolean hasAggs;
624633
private final boolean hasTopDocs;
@@ -631,6 +640,8 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
631640
private final int topNSize;
632641
private final InternalAggregation.ReduceContextBuilder aggReduceContextBuilder;
633642
private final boolean performFinalReduce;
643+
private long aggsCurrentBufferSize;
644+
private long aggsMaxBufferSize;
634645

635646
/**
636647
* Creates a new {@link QueryPhaseResultConsumer}
@@ -641,12 +652,14 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
641652
* @param bufferSize the size of the reduce buffer. if the buffer size is smaller than the number of expected results
642653
* the buffer is used to incrementally reduce aggregation results before all shards responded.
643654
*/
644-
private QueryPhaseResultConsumer(SearchProgressListener progressListener, SearchPhaseController controller,
655+
private QueryPhaseResultConsumer(NamedWriteableRegistry namedWriteableRegistry, SearchProgressListener progressListener,
656+
SearchPhaseController controller,
645657
int expectedResultSize, int bufferSize, boolean hasTopDocs, boolean hasAggs,
646658
int trackTotalHitsUpTo, int topNSize,
647659
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
648660
boolean performFinalReduce) {
649661
super(expectedResultSize);
662+
this.namedWriteableRegistry = namedWriteableRegistry;
650663
if (expectedResultSize != 1 && bufferSize < 2) {
651664
throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result");
652665
}
@@ -661,7 +674,7 @@ private QueryPhaseResultConsumer(SearchProgressListener progressListener, Search
661674
this.processedShards = new SearchShardTarget[expectedResultSize];
662675
// no need to buffer anything if we have less expected results. in this case we don't consume any results ahead of time.
663676
@SuppressWarnings("unchecked")
664-
Supplier<InternalAggregations>[] aggsBuffer = new Supplier[hasAggs ? bufferSize : 0];
677+
DelayableWriteable.Serialized<InternalAggregations>[] aggsBuffer = new DelayableWriteable.Serialized[hasAggs ? bufferSize : 0];
665678
this.aggsBuffer = aggsBuffer;
666679
this.topDocsBuffer = new TopDocs[hasTopDocs ? bufferSize : 0];
667680
this.hasTopDocs = hasTopDocs;
@@ -684,15 +697,21 @@ public void consumeResult(SearchPhaseResult result) {
684697
private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
685698
if (querySearchResult.isNull() == false) {
686699
if (index == bufferSize) {
700+
InternalAggregations reducedAggs = null;
687701
if (hasAggs) {
688702
List<InternalAggregations> aggs = new ArrayList<>(aggsBuffer.length);
689703
for (int i = 0; i < aggsBuffer.length; i++) {
690704
aggs.add(aggsBuffer[i].get());
691705
aggsBuffer[i] = null; // null the buffer so it can be GCed now.
692706
}
693-
InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(
694-
aggs, aggReduceContextBuilder.forPartialReduction());
695-
aggsBuffer[0] = () -> reducedAggs;
707+
reducedAggs = InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forPartialReduction());
708+
aggsBuffer[0] = DelayableWriteable.referencing(reducedAggs)
709+
.asSerialized(InternalAggregations::new, namedWriteableRegistry);
710+
long previousBufferSize = aggsCurrentBufferSize;
711+
aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize);
712+
aggsCurrentBufferSize = aggsBuffer[0].ramBytesUsed();
713+
logger.trace("aggs partial reduction [{}->{}] max [{}]",
714+
previousBufferSize, aggsCurrentBufferSize, aggsMaxBufferSize);
696715
}
697716
if (hasTopDocs) {
698717
TopDocs reducedTopDocs = mergeTopDocs(Arrays.asList(topDocsBuffer),
@@ -705,12 +724,13 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
705724
index = 1;
706725
if (hasAggs || hasTopDocs) {
707726
progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards),
708-
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0].get() : null, numReducePhases);
727+
topDocsStats.getTotalHits(), reducedAggs, numReducePhases);
709728
}
710729
}
711730
final int i = index++;
712731
if (hasAggs) {
713-
aggsBuffer[i] = querySearchResult.consumeAggs();
732+
aggsBuffer[i] = querySearchResult.consumeAggs().asSerialized(InternalAggregations::new, namedWriteableRegistry);
733+
aggsCurrentBufferSize += aggsBuffer[i].ramBytesUsed();
714734
}
715735
if (hasTopDocs) {
716736
final TopDocsAndMaxScore topDocs = querySearchResult.consumeTopDocs(); // can't be null
@@ -723,7 +743,7 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
723743
}
724744

725745
private synchronized List<Supplier<InternalAggregations>> getRemainingAggs() {
726-
return hasAggs ? Arrays.asList(aggsBuffer).subList(0, index) : null;
746+
return hasAggs ? Arrays.asList((Supplier<InternalAggregations>[]) aggsBuffer).subList(0, index) : null;
727747
}
728748

729749
private synchronized List<TopDocs> getRemainingTopDocs() {
@@ -732,6 +752,8 @@ private synchronized List<TopDocs> getRemainingTopDocs() {
732752

733753
@Override
734754
public ReducedQueryPhase reduce() {
755+
aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize);
756+
logger.trace("aggs final reduction [{}] max [{}]", aggsCurrentBufferSize, aggsMaxBufferSize);
735757
ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(),
736758
getRemainingAggs(), getRemainingTopDocs(), topDocsStats, numReducePhases, false,
737759
aggReduceContextBuilder, performFinalReduce);
@@ -767,8 +789,8 @@ ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(SearchProgressL
767789
if (request.getBatchedReduceSize() < numShards) {
768790
int topNSize = getTopDocsSize(request);
769791
// only use this if there are aggs and if there are more shards than we should reduce at once
770-
return new QueryPhaseResultConsumer(listener, this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs,
771-
trackTotalHitsUpTo, topNSize, aggReduceContextBuilder, request.isFinalReduce());
792+
return new QueryPhaseResultConsumer(namedWriteableRegistry, listener, this, numShards, request.getBatchedReduceSize(),
793+
hasTopDocs, hasAggs, trackTotalHitsUpTo, topNSize, aggReduceContextBuilder, request.isFinalReduce());
772794
}
773795
}
774796
return new ArraySearchPhaseResults<SearchPhaseResult>(numShards) {

server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919

2020
package org.elasticsearch.common.io.stream;
2121

22-
import java.io.IOException;
23-
import java.util.function.Supplier;
24-
22+
import org.apache.lucene.util.Accountable;
23+
import org.apache.lucene.util.RamUsageEstimator;
2524
import org.elasticsearch.Version;
2625
import org.elasticsearch.common.bytes.BytesReference;
2726

27+
import java.io.IOException;
28+
import java.util.function.Supplier;
29+
2830
/**
2931
* A holder for {@link Writeable}s that can delays reading the underlying
3032
* {@linkplain Writeable} when it is read from a remote node.
@@ -43,12 +45,22 @@ public static <T extends Writeable> DelayableWriteable<T> referencing(T referenc
4345
* when {@link Supplier#get()} is called.
4446
*/
4547
public static <T extends Writeable> DelayableWriteable<T> delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
46-
return new Delayed<>(reader, in);
48+
return new Serialized<>(reader, in.getVersion(), in.namedWriteableRegistry(), in.readBytesReference());
4749
}
4850

4951
private DelayableWriteable() {}
5052

51-
public abstract boolean isDelayed();
53+
/**
54+
* Returns a {@linkplain DelayableWriteable} that stores its contents
55+
* in serialized form.
56+
*/
57+
public abstract Serialized<T> asSerialized(Writeable.Reader<T> reader, NamedWriteableRegistry registry);
58+
59+
/**
60+
* {@code true} if the {@linkplain Writeable} is being stored in
61+
* serialized form, {@code false} otherwise.
62+
*/
63+
abstract boolean isSerialized();
5264

5365
private static class Referencing<T extends Writeable> extends DelayableWriteable<T> {
5466
private T reference;
@@ -59,11 +71,7 @@ private static class Referencing<T extends Writeable> extends DelayableWriteable
5971

6072
@Override
6173
public void writeTo(StreamOutput out) throws IOException {
62-
try (BytesStreamOutput buffer = new BytesStreamOutput()) {
63-
buffer.setVersion(out.getVersion());
64-
reference.writeTo(buffer);
65-
out.writeBytesReference(buffer.bytes());
66-
}
74+
out.writeBytesReference(writeToBuffer(out.getVersion()).bytes());
6775
}
6876

6977
@Override
@@ -72,27 +80,48 @@ public T get() {
7280
}
7381

7482
@Override
75-
public boolean isDelayed() {
83+
public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry registry) {
84+
try {
85+
return new Serialized<T>(reader, Version.CURRENT, registry, writeToBuffer(Version.CURRENT).bytes());
86+
} catch (IOException e) {
87+
throw new RuntimeException("unexpected error expanding aggregations", e);
88+
}
89+
}
90+
91+
@Override
92+
boolean isSerialized() {
7693
return false;
7794
}
95+
96+
private BytesStreamOutput writeToBuffer(Version version) throws IOException {
97+
try (BytesStreamOutput buffer = new BytesStreamOutput()) {
98+
buffer.setVersion(version);
99+
reference.writeTo(buffer);
100+
return buffer;
101+
}
102+
}
78103
}
79104

80-
private static class Delayed<T extends Writeable> extends DelayableWriteable<T> {
105+
/**
106+
* A {@link Writeable} stored in serialized form.
107+
*/
108+
public static class Serialized<T extends Writeable> extends DelayableWriteable<T> implements Accountable {
81109
private final Writeable.Reader<T> reader;
82-
private final Version remoteVersion;
83-
private final BytesReference serialized;
110+
private final Version serializedAtVersion;
84111
private final NamedWriteableRegistry registry;
112+
private final BytesReference serialized;
85113

86-
Delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
114+
Serialized(Writeable.Reader<T> reader, Version serializedAtVersion,
115+
NamedWriteableRegistry registry, BytesReference serialized) throws IOException {
87116
this.reader = reader;
88-
remoteVersion = in.getVersion();
89-
serialized = in.readBytesReference();
90-
registry = in.namedWriteableRegistry();
117+
this.serializedAtVersion = serializedAtVersion;
118+
this.registry = registry;
119+
this.serialized = serialized;
91120
}
92121

93122
@Override
94123
public void writeTo(StreamOutput out) throws IOException {
95-
if (out.getVersion() == remoteVersion) {
124+
if (out.getVersion() == serializedAtVersion) {
96125
/*
97126
* If the version *does* line up we can just copy the bytes
98127
* which is good because this is how shard request caching
@@ -116,7 +145,7 @@ public T get() {
116145
try {
117146
try (StreamInput in = registry == null ?
118147
serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
119-
in.setVersion(remoteVersion);
148+
in.setVersion(serializedAtVersion);
120149
return reader.read(in);
121150
}
122151
} catch (IOException e) {
@@ -125,8 +154,18 @@ public T get() {
125154
}
126155

127156
@Override
128-
public boolean isDelayed() {
157+
public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry registry) {
158+
return this; // We're already serialized
159+
}
160+
161+
@Override
162+
boolean isSerialized() {
129163
return true;
130164
}
165+
166+
@Override
167+
public long ramBytesUsed() {
168+
return serialized.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 3 + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
169+
}
131170
}
132171
}

server/src/main/java/org/elasticsearch/node/Node.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,8 @@ protected Node(final Environment initialEnvironment,
585585
b.bind(MetadataCreateIndexService.class).toInstance(metadataCreateIndexService);
586586
b.bind(SearchService.class).toInstance(searchService);
587587
b.bind(SearchTransportService.class).toInstance(searchTransportService);
588-
b.bind(SearchPhaseController.class).toInstance(new SearchPhaseController(searchService::aggReduceContextBuilder));
588+
b.bind(SearchPhaseController.class).toInstance(new SearchPhaseController(
589+
namedWriteableRegistry, searchService::aggReduceContextBuilder));
589590
b.bind(Transport.class).toInstance(transport);
590591
b.bind(TransportService.class).toInstance(transportService);
591592
b.bind(NetworkService.class).toInstance(networkService);

server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,6 @@ public void run() throws IOException {
213213
}
214214

215215
private SearchPhaseController searchPhaseController() {
216-
return new SearchPhaseController(request -> InternalAggregationTestCase.emptyReduceContextBuilder());
216+
return new SearchPhaseController(writableRegistry(), request -> InternalAggregationTestCase.emptyReduceContextBuilder());
217217
}
218218
}

0 commit comments

Comments
 (0)