
Commit 28ca127

Improve vwh's distant bucket handling (#59094)
This modifies the `variable_width_histogram`'s distant bucket handling to:

1. Properly handle integer overflows.
2. Recalculate the average distance when new buckets are added on the ends. This should slow down the rate at which we build extra buckets as we build more of them.
1 parent 30be215 commit 28ca127
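
For context, here is a minimal, self-contained sketch (plain Java, not the Elasticsearch source; the class name and sample centroids are invented for illustration) of the two behaviours the commit message refers to. The old code accumulated the gaps between adjacent centroids into an `int`, which breaks down for very large values because each double gap is narrowed back to `int`; the new code keeps the average as a `double` and, since centroids are sorted, computes it from the end points only (the adjacent gaps telescope to the same sum).

```java
import java.util.Arrays;

// Sketch only: compares the old int-based average gap with the new
// double-based end-to-end formula used by updateAvgBucketDistance below.
public class AvgDistanceSketch {
    public static void main(String[] args) {
        // Sorted centroids scaled near Long.MAX_VALUE, like the big-key case added to the test.
        double[] centroids = Arrays.stream(new double[] { -1, 19, 29, 42, 60, 77.5 })
            .map(c -> c * Long.MAX_VALUE)
            .toArray();

        // Old approach: each compound assignment narrows the double sum back to int,
        // clamping at Integer.MAX_VALUE, so the result is far too small to be useful.
        int intSum = 0;
        for (int i = 0; i < centroids.length - 1; i++) {
            intSum += centroids[i + 1] - centroids[i]; // implicit narrowing cast to int
        }
        int oldAvg = intSum / (centroids.length - 1);

        // New approach: centroids are sorted, so the adjacent gaps telescope and the
        // average gap is simply (last - first) / (n - 1), kept as a double.
        double newAvg = (centroids[centroids.length - 1] - centroids[0]) / (centroids.length - 1);

        System.out.println("old int-based average: " + oldAvg);   // stuck near Integer.MAX_VALUE / 5
        System.out.println("new double average:    " + newAvg);   // ~1.4e20 for these centroids
    }
}
```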

File tree

2 files changed: +53 −28 lines changed

server/src/main/java/org/elasticsearch/search/aggregations/bucket/histogram/VariableWidthHistogramAggregator.java

Lines changed: 18 additions & 14 deletions
@@ -166,23 +166,15 @@ private class MergeBucketsPhase extends CollectionPhase{
         public DoubleArray clusterSizes; // clusterSizes != bucketDocCounts when clusters are in the middle of a merge
         public int numClusters;
 
-        private int avgBucketDistance;
+        private double avgBucketDistance;
 
         MergeBucketsPhase(DoubleArray buffer, int bufferSize) {
             // Cluster the documents to reduce the number of buckets
             // Target shardSizes * (3/4) buckets so that there's room for more distant buckets to be added during rest of collection
             bucketBufferedDocs(buffer, bufferSize, shardSize * 3 / 4);
 
             if(bufferSize > 1) {
-                // Calculate the average distance between buckets
-                // Subsequent documents will be compared with this value to determine if they should be collected into
-                // an existing bucket or into a new bucket
-                // This can be done in a single linear scan because buckets are sorted by centroid
-                int sum = 0;
-                for (int i = 0; i < numClusters - 1; i++) {
-                    sum += clusterCentroids.get(i + 1) - clusterCentroids.get(i);
-                }
-                avgBucketDistance = (sum / (numClusters - 1));
+                updateAvgBucketDistance();
             }
         }
 
@@ -194,11 +186,9 @@ private class ClusterSorter extends InPlaceMergeSorter {
 
             final DoubleArray values;
             final long[] indexes;
-            int length;
 
             ClusterSorter(DoubleArray values, int length){
                 this.values = values;
-                this.length = length;
 
                 this.indexes = new long[length];
                 for(int i = 0; i < indexes.length; i++){
@@ -284,7 +274,7 @@ private void bucketBufferedDocs(final DoubleArray buffer, final int bufferSize,
         @Override
         public CollectionPhase collectValue(LeafBucketCollector sub, int doc, double val) throws IOException{
             int bucketOrd = getNearestBucket(val);
-            double distance = Math.abs(clusterCentroids.get(bucketOrd)- val);
+            double distance = Math.abs(clusterCentroids.get(bucketOrd) - val);
             if(bucketOrd == -1 || distance > (2 * avgBucketDistance) && numClusters < shardSize) {
                 // Make a new bucket since the document is distant from all existing buckets
                 // TODO: (maybe) Create a new bucket for <b>all</b> distant docs and merge down to shardSize buckets at end
@@ -293,17 +283,31 @@ public CollectionPhase collectValue(LeafBucketCollector sub, int doc, double val
                 collectBucket(sub, doc, numClusters - 1);
 
                 if(val > clusterCentroids.get(bucketOrd)){
-                    // Insert just ahead of bucketOrd so that the array remains sorted
+                    /*
+                     * If the new value is bigger than the nearest bucket then insert
+                     * just ahead of bucketOrd so that the array remains sorted.
+                     */
                     bucketOrd += 1;
                 }
                 moveLastCluster(bucketOrd);
+                // We've added a new bucket so update the average distance between the buckets
+                updateAvgBucketDistance();
             } else {
                 addToCluster(bucketOrd, val);
                 collectExistingBucket(sub, doc, bucketOrd);
+                if (bucketOrd == 0 || bucketOrd == numClusters - 1) {
+                    // Only update average distance if the centroid of one of the end buckets is modifed.
+                    updateAvgBucketDistance();
+                }
             }
             return this;
         }
 
+        private void updateAvgBucketDistance() {
+            // Centroids are sorted so the average distance is the difference between the first and last.
+            avgBucketDistance = (clusterCentroids.get(numClusters - 1) - clusterCentroids.get(0)) / (numClusters - 1);
+        }
+
         /**
          * Creates a new cluster with <code>value</code> and appends it to the cluster arrays
          */
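
To see why recalculating the average distance slows down new-bucket creation, here is a hedged, self-contained simulation of the decision rule in `collectValue` above. The `TreeSet` model, the sample values, and the simplification that absorbed values do not move a centroid are all illustrative, not the aggregator's actual data structures. Each new end bucket stretches `(last - first) / (n - 1)`, so later values have to be farther from their nearest centroid before they earn a bucket of their own.

```java
import java.util.TreeSet;

// Sketch only: simulates "new bucket when the nearest centroid is more than
// 2 * avgBucketDistance away", recomputing the average like updateAvgBucketDistance
// each time a bucket is added.
public class DistantBucketSketch {
    public static void main(String[] args) {
        TreeSet<Double> centroids = new TreeSet<>();
        for (double c : new double[] { 1, 3, 5, 7, 9 }) {
            centroids.add(c);
        }
        double avgDistance = avg(centroids);

        // Values arriving at a constant spacing of 30: the first few create buckets,
        // but the growing threshold eventually absorbs later ones into an existing bucket.
        for (double val : new double[] { 40, 70, 100, 130, 160 }) {
            double nearest = nearest(centroids, val);
            boolean distant = Math.abs(nearest - val) > 2 * avgDistance;
            System.out.printf("val=%.0f avg=%.2f threshold=%.2f -> %s%n",
                val, avgDistance, 2 * avgDistance, distant ? "new bucket" : "existing bucket");
            if (distant) {
                centroids.add(val);
                avgDistance = avg(centroids); // recalculated, so the threshold grows
            }
        }
    }

    // Same formula as updateAvgBucketDistance: centroids are sorted, so the gaps telescope.
    static double avg(TreeSet<Double> centroids) {
        return (centroids.last() - centroids.first()) / (centroids.size() - 1);
    }

    static double nearest(TreeSet<Double> centroids, double val) {
        Double lo = centroids.floor(val);
        Double hi = centroids.ceiling(val);
        if (lo == null) return hi;
        if (hi == null) return lo;
        return (val - lo <= hi - val) ? lo : hi;
    }
}
```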

server/src/test/java/org/elasticsearch/search/aggregations/bucket/histogram/VariableWidthHistogramAggregatorTests.java

Lines changed: 35 additions & 14 deletions
@@ -39,7 +39,6 @@
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
-import org.elasticsearch.search.aggregations.bucket.terms.InternalTerms;
 import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
 import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.InternalStats;
@@ -218,25 +217,25 @@ public void testDoubles() throws Exception {
     // Once the cache limit is reached, cached documents are collected into (3/4 * shard_size) buckets
     // A new bucket should be added when there is a document that is distant from all existing buckets
     public void testNewBucketCreation() throws Exception {
-        final List<Number> dataset = Arrays.asList(-1, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 40, 30, 25, 32, 38, 80, 50, 75);
+        final List<Number> dataset = Arrays.asList(-1, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 40, 30, 25, 32, 36, 80, 50, 75, 60);
         double doubleError = 1d / 10000d;
 
         // Search (no reduce)
 
         // Expected clusters: [ (-1), (1), (3), (5), (7), (9), (11), (13), (15), (17),
-        //                      (19), (25, 30, 32), (38, 40), (50), (75, 80) ]
-        // Corresponding keys (centroids): [ -1, 1, 3, ..., 17, 19, 29, 39, 50, 77.5]
+        //                      (19), (25, 30, 32), (36, 40, 50), (60), (75, 80) ]
+        // Corresponding keys (centroids): [ -1, 1, 3, ..., 17, 19, 29, 42, 77.5]
         // Note: New buckets are created for 30, 50, and 80 because they are distant from the other buckets
-        final List<Double> keys = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 29d, 39d, 50d, 77.5d);
-        final List<Double> mins = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 25d, 38d, 50d, 75d);
-        final List<Double> maxes = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 32d, 40d, 50d, 80d);
-        final List<Integer> docCounts = Arrays.asList(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 2, 1, 2);
+        final List<Double> keys = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 29d, 42d, 60d, 77.5d);
+        final List<Double> mins = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 25d, 36d, 60d, 75d);
+        final List<Double> maxes = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 32d, 50d, 60d, 80d);
+        final List<Integer> docCounts = Arrays.asList(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 2);
         assert keys.size() == docCounts.size() && keys.size() == keys.size();
 
         final Map<Double, Integer> expectedDocCountOnlySearch = new HashMap<>();
         final Map<Double, Double> expectedMinsOnlySearch = new HashMap<>();
         final Map<Double, Double> expectedMaxesOnlySearch = new HashMap<>();
-        for(int i=0; i<keys.size(); i++){
+        for(int i=0; i < keys.size(); i++){
             expectedDocCountOnlySearch.put(keys.get(i), docCounts.get(i));
             expectedMinsOnlySearch.put(keys.get(i), mins.get(i));
             expectedMaxesOnlySearch.put(keys.get(i), maxes.get(i));
@@ -251,6 +250,31 @@ public void testNewBucketCreation() throws Exception {
                 long expectedDocCount = expectedDocCountOnlySearch.getOrDefault(bucket.getKey(), 0).longValue();
                 double expectedCentroid = expectedMinsOnlySearch.getOrDefault(bucket.getKey(), 0d).doubleValue();
                 double expectedMax = expectedMaxesOnlySearch.getOrDefault(bucket.getKey(), 0d).doubleValue();
+                assertEquals(bucket.getKeyAsString(), expectedDocCount, bucket.getDocCount());
+                assertEquals(bucket.getKeyAsString(), expectedCentroid, bucket.min(), doubleError);
+                assertEquals(bucket.getKeyAsString(), expectedMax, bucket.max(), doubleError);
+            });
+        });
+
+        // Rerun the test with very large keys which can cause an overflow
+        final Map<Double, Integer> expectedDocCountBigKeys = new HashMap<>();
+        final Map<Double, Double> expectedMinsBigKeys = new HashMap<>();
+        final Map<Double, Double> expectedMaxesBigKeys = new HashMap<>();
+        for(int i=0; i< keys.size(); i++){
+            expectedDocCountBigKeys.put(Long.MAX_VALUE * keys.get(i), docCounts.get(i));
+            expectedMinsBigKeys.put(Long.MAX_VALUE * keys.get(i), Long.MAX_VALUE * mins.get(i));
+            expectedMaxesBigKeys.put(Long.MAX_VALUE * keys.get(i), Long.MAX_VALUE * maxes.get(i));
+        }
+
+        testSearchCase(DEFAULT_QUERY, dataset.stream().map(n -> Double.valueOf(n.doubleValue() * Long.MAX_VALUE)).collect(toList()), false,
+            aggregation -> aggregation.field(NUMERIC_FIELD).setNumBuckets(2).setShardSize(16).setInitialBuffer(12),
+            histogram -> {
+                final List<InternalVariableWidthHistogram.Bucket> buckets = histogram.getBuckets();
+                assertEquals(expectedDocCountOnlySearch.size(), buckets.size());
+                buckets.forEach(bucket -> {
+                    long expectedDocCount = expectedDocCountBigKeys.getOrDefault(bucket.getKey(), 0).longValue();
+                    double expectedCentroid = expectedMinsBigKeys.getOrDefault(bucket.getKey(), 0d).doubleValue();
+                    double expectedMax = expectedMaxesBigKeys.getOrDefault(bucket.getKey(), 0d).doubleValue();
                 assertEquals(expectedDocCount, bucket.getDocCount());
                 assertEquals(expectedCentroid, bucket.min(), doubleError);
                 assertEquals(expectedMax, bucket.max(), doubleError);
@@ -308,7 +332,6 @@ public void testSimpleSubAggregations() throws IOException{
                 .setShardSize(4)
                 .subAggregation(AggregationBuilders.stats("stats").field(NUMERIC_FIELD)),
             histogram -> {
-                final List<InternalVariableWidthHistogram.Bucket> buckets = histogram.getBuckets();
                 double deltaError = 1d/10000d;
 
                 // Expected clusters: [ (1, 2), (5), (8,9) ]
@@ -343,7 +366,6 @@ public void testComplexSubAggregations() throws IOException{
                 .setShardSize(4)
                 .subAggregation(new StatsAggregationBuilder("stats").field(NUMERIC_FIELD)),
            histogram -> {
-                final List<InternalVariableWidthHistogram.Bucket> buckets = histogram.getBuckets();
                 double deltaError = 1d / 10000d;
 
                 // Expected clusters: [ (0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11) ]
@@ -381,16 +403,15 @@ public void testSubAggregationReduction() throws IOException{
                 .shardSize(2)
                 .size(1)),
             histogram -> {
-                final List<InternalVariableWidthHistogram.Bucket> buckets = histogram.getBuckets();
                 double deltaError = 1d / 10000d;
 
                 // This is a test to make sure that the sub aggregations get reduced
                 // This terms sub aggregation has shardSize (2) != size (1), so we will get 1 bucket only if
                 // InternalVariableWidthHistogram reduces the sub aggregations.
 
-                InternalTerms terms = histogram.getBuckets().get(0).getAggregations().get("terms");
+                LongTerms terms = histogram.getBuckets().get(0).getAggregations().get("terms");
                 assertEquals(1L, terms.getBuckets().size(), deltaError);
-                assertEquals(1L, ((InternalTerms.Bucket) terms.getBuckets().get(0)).getKey());
+                assertEquals(1L, terms.getBuckets().get(0).getKey());
             });
     }