Skip to content

Commit 3409373

Browse files
committed
Added unit tests for MatrixStatsAggregator
1 parent b2ccb6b commit 3409373

File tree

5 files changed

+108
-5
lines changed

5 files changed

+108
-5
lines changed

modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ public double getCorrelation(String fieldX, String fieldY) {
139139
return results.getCorrelation(fieldX, fieldY);
140140
}
141141

142+
RunningStats getStats() {
143+
return stats;
144+
}
145+
142146
MatrixStatsResults getResults() {
143147
return results;
144148
}

modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@
4141
/**
4242
* Metric Aggregation for computing the pearson product correlation coefficient between multiple fields
4343
**/
44-
public class MatrixStatsAggregator extends MetricsAggregator {
44+
final class MatrixStatsAggregator extends MetricsAggregator {
4545
/** Multiple ValuesSource with field names */
46-
final NumericMultiValuesSource valuesSources;
46+
private final NumericMultiValuesSource valuesSources;
4747

4848
/** array of descriptive stats, per shard, needed to compute the correlation */
4949
ObjectArray<RunningStats> stats;
5050

51-
public MatrixStatsAggregator(String name, Map<String, ValuesSource.Numeric> valuesSources, SearchContext context,
51+
MatrixStatsAggregator(String name, Map<String, ValuesSource.Numeric> valuesSources, SearchContext context,
5252
Aggregator parent, MultiValueMode multiValueMode, List<PipelineAggregator> pipelineAggregators,
5353
Map<String,Object> metaData) throws IOException {
5454
super(name, context, parent, pipelineAggregators, metaData);

modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@
3232
import java.util.List;
3333
import java.util.Map;
3434

35-
public class MatrixStatsAggregatorFactory
35+
final class MatrixStatsAggregatorFactory
3636
extends MultiValuesSourceAggregatorFactory<ValuesSource.Numeric, MatrixStatsAggregatorFactory> {
3737

3838
private final MultiValueMode multiValueMode;
3939

40-
public MatrixStatsAggregatorFactory(String name,
40+
MatrixStatsAggregatorFactory(String name,
4141
Map<String, ValuesSourceConfig<ValuesSource.Numeric>> configs, MultiValueMode multiValueMode,
4242
SearchContext context, AggregatorFactory<?> parent, AggregatorFactories.Builder subFactoriesBuilder,
4343
Map<String, Object> metaData) throws IOException {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.search.aggregations.matrix.stats;
20+
21+
import org.apache.lucene.document.Document;
22+
import org.apache.lucene.document.Field;
23+
import org.apache.lucene.document.SortedNumericDocValuesField;
24+
import org.apache.lucene.document.StringField;
25+
import org.apache.lucene.index.IndexReader;
26+
import org.apache.lucene.index.RandomIndexWriter;
27+
import org.apache.lucene.search.IndexSearcher;
28+
import org.apache.lucene.search.MatchAllDocsQuery;
29+
import org.apache.lucene.store.Directory;
30+
import org.apache.lucene.util.NumericUtils;
31+
import org.elasticsearch.index.mapper.MappedFieldType;
32+
import org.elasticsearch.index.mapper.NumberFieldMapper;
33+
import org.elasticsearch.search.aggregations.AggregatorTestCase;
34+
35+
import java.util.Arrays;
36+
import java.util.Collections;
37+
38+
public class MatrixStatsAggregatorTests extends AggregatorTestCase {
39+
40+
public void testNoData() throws Exception {
41+
MappedFieldType ft =
42+
new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
43+
ft.setName("field");
44+
45+
try (Directory directory = newDirectory();
46+
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
47+
if (randomBoolean()) {
48+
indexWriter.addDocument(Collections.singleton(new StringField("another_field", "value", Field.Store.NO)));
49+
}
50+
try (IndexReader reader = indexWriter.getReader()) {
51+
IndexSearcher searcher = new IndexSearcher(reader);
52+
MatrixStatsAggregationBuilder aggBuilder = new MatrixStatsAggregationBuilder("my_agg")
53+
.fields(Collections.singletonList("field"));
54+
InternalMatrixStats stats = search(searcher, new MatchAllDocsQuery(), aggBuilder, ft);
55+
assertNull(stats.getStats());
56+
}
57+
}
58+
}
59+
60+
public void testTwoFields() throws Exception {
61+
String fieldA = "a";
62+
MappedFieldType ftA = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
63+
ftA.setName(fieldA);
64+
String fieldB = "b";
65+
MappedFieldType ftB = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
66+
ftB.setName(fieldB);
67+
68+
try (Directory directory = newDirectory();
69+
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
70+
71+
int numDocs = scaledRandomIntBetween(8192, 16384);
72+
Double[] fieldAValues = new Double[numDocs];
73+
Double[] fieldBValues = new Double[numDocs];
74+
for (int docId = 0; docId < numDocs; docId++) {
75+
Document document = new Document();
76+
fieldAValues[docId] = randomDouble();
77+
document.add(new SortedNumericDocValuesField(fieldA, NumericUtils.doubleToSortableLong(fieldAValues[docId])));
78+
79+
fieldBValues[docId] = randomDouble();
80+
document.add(new SortedNumericDocValuesField(fieldB, NumericUtils.doubleToSortableLong(fieldBValues[docId])));
81+
indexWriter.addDocument(document);
82+
}
83+
84+
MultiPassStats multiPassStats = new MultiPassStats(fieldA, fieldB);
85+
multiPassStats.computeStats(Arrays.asList(fieldAValues), Arrays.asList(fieldBValues));
86+
try (IndexReader reader = indexWriter.getReader()) {
87+
IndexSearcher searcher = new IndexSearcher(reader);
88+
MatrixStatsAggregationBuilder aggBuilder = new MatrixStatsAggregationBuilder("my_agg")
89+
.fields(Arrays.asList(fieldA, fieldB));
90+
InternalMatrixStats stats = search(searcher, new MatchAllDocsQuery(), aggBuilder, ftA, ftB);
91+
multiPassStats.assertNearlyEqual(new MatrixStatsResults(stats.getStats()));
92+
}
93+
}
94+
}
95+
96+
}

test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggreg
110110

111111
QueryShardContext queryShardContext = queryShardContextMock(mapperService, fieldTypes, circuitBreakerService);
112112
when(searchContext.getQueryShardContext()).thenReturn(queryShardContext);
113+
for (MappedFieldType fieldType : fieldTypes) {
114+
when(searchContext.smartNameFieldType(fieldType.name())).thenReturn(fieldType);
115+
}
113116

114117
return aggregationBuilder.build(searchContext, null);
115118
}

0 commit comments

Comments
 (0)