Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public Double getCorrelation(String fieldX, String fieldY) {
}

/** return the value for two fields in an upper triangular matrix, regardless of row col location. */
private double getValFromUpperTriangularMatrix(Map<String, HashMap<String, Double>> map, String fieldX, String fieldY) {
static <M extends Map<String, Double>> double getValFromUpperTriangularMatrix(Map<String, M> map, String fieldX, String fieldY) {
// for the co-value to exist, one of the two (or both) fields has to be a row key
if (map.containsKey(fieldX) == false && map.containsKey(fieldY) == false) {
throw new IllegalArgumentException("neither field " + fieldX + " nor " + fieldY + " exist");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.search.aggregations.matrix.stats;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.ParsedAggregation;

import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

public class ParsedMatrixStats extends ParsedAggregation implements MatrixStats {

private final Map<String, Long> counts = new LinkedHashMap<>();
private final Map<String, Double> means = new HashMap<>();
private final Map<String, Double> variances = new HashMap<>();
private final Map<String, Double> skewness = new HashMap<>();
private final Map<String, Double> kurtosis = new HashMap<>();
private final Map<String, Map<String, Double>> covariances = new HashMap<>();
private final Map<String, Map<String, Double>> correlations = new HashMap<>();

/** Returns the type name of this aggregation, which is {@link MatrixStatsAggregationBuilder#NAME}. */
@Override
public String getType() {
return MatrixStatsAggregationBuilder.NAME;
}

@Override
public long getDocCount() {
throw new UnsupportedOperationException();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't render this to the REST response, but it's part of the MatrixStats interface, should we either add this value to the output or remove the method from the interface?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to add this field to the REST response. I can do that in another PR against master.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++

}

/**
 * Returns the number of observations parsed for the given field.
 *
 * @param field the field name to look up
 * @return the parsed count, or {@code 0} when the field is unknown or no count was parsed for it
 */
@Override
public long getFieldCount(String field) {
    // One lookup instead of containsKey + get. A null result covers both an absent key and a
    // key mapped to null (the parser stores the result's count as-is, and that Long stays null
    // when the response object carried no count), so unboxing can no longer throw.
    final Long count = counts.get(field);
    return count == null ? 0L : count;
}

/**
 * Returns the mean parsed for the given field.
 *
 * @param field the field name to look up
 * @return the value stored for {@code field} in the parsed means
 * @throws IllegalArgumentException if {@code field} is null or has no entry
 */
@Override
public double getMean(String field) {
return checkedGet(means, field);
}

/**
 * Returns the variance parsed for the given field.
 *
 * @param field the field name to look up
 * @return the value stored for {@code field} in the parsed variances
 * @throws IllegalArgumentException if {@code field} is null or has no entry
 */
@Override
public double getVariance(String field) {
return checkedGet(variances, field);
}

/**
 * Returns the skewness parsed for the given field.
 *
 * @param field the field name to look up
 * @return the value stored for {@code field} in the parsed skewness values
 * @throws IllegalArgumentException if {@code field} is null or has no entry
 */
@Override
public double getSkewness(String field) {
return checkedGet(skewness, field);
}

/**
 * Returns the kurtosis parsed for the given field.
 *
 * @param field the field name to look up
 * @return the value stored for {@code field} in the parsed kurtosis values
 * @throws IllegalArgumentException if {@code field} is null or has no entry
 */
@Override
public double getKurtosis(String field) {
return checkedGet(kurtosis, field);
}

/**
 * Returns the covariance between two fields.
 *
 * @param fieldX name of the first field
 * @param fieldY name of the second field
 * @return the variance of the field when both names are equal, otherwise the value looked up
 *         in the parsed covariances
 * @throws IllegalArgumentException if both names are equal but no variance is known for the field
 */
@Override
public double getCovariance(String fieldX, String fieldY) {
    final boolean distinctFields = fieldX.equals(fieldY) == false;
    if (distinctFields) {
        return MatrixStatsResults.getValFromUpperTriangularMatrix(covariances, fieldX, fieldY);
    }
    // Same field on both sides: the covariance of a field with itself is its variance.
    return checkedGet(variances, fieldX);
}

/**
 * Returns the correlation between two fields.
 *
 * @param fieldX name of the first field
 * @param fieldY name of the second field
 * @return {@code 1.0} when both names are equal (no lookup is performed in that case),
 *         otherwise the value looked up in the parsed correlations
 */
@Override
public double getCorrelation(String fieldX, String fieldY) {
    final boolean distinctFields = fieldX.equals(fieldY) == false;
    if (distinctFields) {
        return MatrixStatsResults.getValFromUpperTriangularMatrix(correlations, fieldX, fieldY);
    }
    // Same field on both sides: answered without consulting the parsed data.
    return 1.0;
}

/**
 * Writes the parsed statistics back out as an array with one object per field, in the
 * iteration order of the parsed counts. Nothing is written when no field was parsed.
 *
 * @param builder the builder to write to
 * @param params  serialization parameters (not used here)
 * @return the same builder, for chaining
 * @throws IOException if the builder fails to write
 */
@Override
protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
    if (counts == null || counts.isEmpty()) {
        return builder;
    }
    builder.startArray(InternalMatrixStats.Fields.FIELDS);
    for (String field : counts.keySet()) {
        builder.startObject();
        builder.field(InternalMatrixStats.Fields.NAME, field);
        builder.field(InternalMatrixStats.Fields.COUNT, getFieldCount(field));
        builder.field(InternalMatrixStats.Fields.MEAN, getMean(field));
        builder.field(InternalMatrixStats.Fields.VARIANCE, getVariance(field));
        builder.field(InternalMatrixStats.Fields.SKEWNESS, getSkewness(field));
        builder.field(InternalMatrixStats.Fields.KURTOSIS, getKurtosis(field));
        writeValueRow(builder, InternalMatrixStats.Fields.COVARIANCE, covariances.get(field));
        writeValueRow(builder, InternalMatrixStats.Fields.CORRELATION, correlations.get(field));
        builder.endObject();
    }
    builder.endArray();
    return builder;
}

/** Writes one named object holding the given per-field values; a null row yields an empty object. */
private static void writeValueRow(XContentBuilder builder, String objectName, Map<String, Double> row) throws IOException {
    builder.startObject(objectName);
    if (row != null) {
        for (Map.Entry<String, Double> entry : row.entrySet()) {
            builder.field(entry.getKey(), entry.getValue());
        }
    }
    builder.endObject();
}

/**
 * Looks up a field in one of the per-field value maps, rejecting null or unknown names.
 *
 * @param values    the map to read from
 * @param fieldName the field name to look up
 * @param <T>       the value type of the map
 * @return the value mapped to {@code fieldName}
 * @throws IllegalArgumentException if {@code fieldName} is null or is not a key of {@code values}
 */
private static <T> T checkedGet(final Map<String, T> values, final String fieldName) {
    if (fieldName != null && values.containsKey(fieldName)) {
        return values.get(fieldName);
    }
    // Only failure cases reach this point; the message depends on which check failed.
    final String reason = fieldName == null
        ? "field name cannot be null"
        : "field " + fieldName + " does not exist";
    throw new IllegalArgumentException(reason);
}

/**
 * Parser for the aggregation body. Declared {@code final} so the shared, statically configured
 * parser can never be reassigned after class initialization.
 * NOTE(review): the {@code true} constructor flag presumably makes the parser ignore unknown
 * fields - confirm against ObjectParser.
 */
private static final ObjectParser<ParsedMatrixStats, Void> PARSER =
    new ObjectParser<>(ParsedMatrixStats.class.getSimpleName(), true, ParsedMatrixStats::new);
static {
    declareAggregationFields(PARSER);
    // Each element of the fields array is parsed into a ParsedMatrixStatsResult, then spread
    // over the per-statistic maps, all keyed by the field name.
    PARSER.declareObjectArray((matrixStats, results) -> {
        for (ParsedMatrixStatsResult result : results) {
            final String fieldName = result.name;
            matrixStats.counts.put(fieldName, result.count);
            matrixStats.means.put(fieldName, result.mean);
            matrixStats.variances.put(fieldName, result.variance);
            matrixStats.skewness.put(fieldName, result.skewness);
            matrixStats.kurtosis.put(fieldName, result.kurtosis);
            matrixStats.covariances.put(fieldName, result.covariances);
            matrixStats.correlations.put(fieldName, result.correlations);
        }
    }, (p, c) -> ParsedMatrixStatsResult.fromXContent(p), new ParseField(InternalMatrixStats.Fields.FIELDS));
}

/**
 * Parses a matrix stats aggregation from the given parser and assigns it the given name.
 *
 * @param parser the parser to read the aggregation body from
 * @param name   the name to set on the parsed aggregation
 * @return the parsed aggregation
 * @throws IOException declared for failures while reading from the parser
 */
public static ParsedMatrixStats fromXContent(XContentParser parser, String name) throws IOException {
// No parse context is needed, hence the null second argument.
ParsedMatrixStats aggregation = PARSER.parse(parser, null);
aggregation.setName(name);
return aggregation;
}

/**
 * Holder for the statistics of a single field, parsed from one element of the fields array.
 * Every member stays {@code null} until the corresponding entry has been parsed.
 */
static class ParsedMatrixStatsResult {

    String name;
    Long count;
    Double mean;
    Double variance;
    Double skewness;
    Double kurtosis;
    Map<String, Double> covariances;
    Map<String, Double> correlations;

    /** Parser for one per-field object; {@code final} so it cannot be reassigned once configured. */
    private static final ObjectParser<ParsedMatrixStatsResult, Void> RESULT_PARSER =
        new ObjectParser<>(ParsedMatrixStatsResult.class.getSimpleName(), true, ParsedMatrixStatsResult::new);
    static {
        RESULT_PARSER.declareString((result, name) -> result.name = name,
            new ParseField(InternalMatrixStats.Fields.NAME));
        RESULT_PARSER.declareLong((result, count) -> result.count = count,
            new ParseField(InternalMatrixStats.Fields.COUNT));
        RESULT_PARSER.declareDouble((result, mean) -> result.mean = mean,
            new ParseField(InternalMatrixStats.Fields.MEAN));
        RESULT_PARSER.declareDouble((result, variance) -> result.variance = variance,
            new ParseField(InternalMatrixStats.Fields.VARIANCE));
        RESULT_PARSER.declareDouble((result, skewness) -> result.skewness = skewness,
            new ParseField(InternalMatrixStats.Fields.SKEWNESS));
        RESULT_PARSER.declareDouble((result, kurtosis) -> result.kurtosis = kurtosis,
            new ParseField(InternalMatrixStats.Fields.KURTOSIS));

        // Covariances and correlations are both read as ordered maps and converted the same way.
        RESULT_PARSER.declareObject((ParsedMatrixStatsResult result, Map<String, Object> covars) -> {
            result.covariances = toDoubleMap(covars);
        }, (p, c) -> p.mapOrdered(), new ParseField(InternalMatrixStats.Fields.COVARIANCE));

        RESULT_PARSER.declareObject((ParsedMatrixStatsResult result, Map<String, Object> correls) -> {
            result.correlations = toDoubleMap(correls);
        }, (p, c) -> p.mapOrdered(), new ParseField(InternalMatrixStats.Fields.CORRELATION));
    }

    /** Copies a parsed map into a new insertion-ordered map, converting every value to a Double. */
    private static Map<String, Double> toDoubleMap(Map<String, Object> raw) {
        final Map<String, Double> converted = new LinkedHashMap<>(raw.size());
        for (Map.Entry<String, Object> entry : raw.entrySet()) {
            converted.put(entry.getKey(), mapValueAsDouble(entry.getValue()));
        }
        return converted;
    }

    /**
     * Converts a parsed map value to a Double. Values that are already Doubles are returned
     * as-is; anything else is converted through its string form.
     * NOTE(review): the string path presumably covers values rendered as text, such as NaN - confirm.
     */
    private static Double mapValueAsDouble(Object value) {
        if (value instanceof Double) {
            return (Double) value;
        }
        return Double.valueOf(Objects.toString(value));
    }

    /** Parses one per-field object from the given parser. */
    static ParsedMatrixStatsResult fromXContent(XContentParser parser) throws IOException {
        return RESULT_PARSER.parse(parser, null);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,63 @@
*/
package org.elasticsearch.search.aggregations.matrix.stats;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.xcontent.ContextParser;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.ParsedAggregation;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class InternalMatrixStatsTests extends InternalAggregationTestCase<InternalMatrixStats> {

private String[] fields;
private boolean hasMatrixStatsResults;

/**
 * Picks, once per test, whether the instances under test carry computed results and which
 * field names they use ({@code field_0}, {@code field_1}, ...). No fields are generated when
 * there are no results.
 */
@Override
public void setUp() throws Exception {
    super.setUp();
    // Draw order is kept as flag first, then size, so seeded runs produce the same sequence.
    hasMatrixStatsResults = frequently();
    fields = new String[hasMatrixStatsResults ? randomInt(128) : 0];
    Arrays.setAll(fields, index -> "field_" + index);
}

/**
 * Returns the default named XContent entries plus one that parses aggregations registered under
 * {@link MatrixStatsAggregationBuilder#NAME} with {@link ParsedMatrixStats#fromXContent}.
 */
@Override
protected List<NamedXContentRegistry.Entry> getNamedXContents() {
// Copy into a new list before adding the extra entry.
List<NamedXContentRegistry.Entry> namedXContents = new ArrayList<>(getDefaultNamedXContents());
// The parse context is cast to String and passed on as the aggregation name.
ContextParser<Object, Aggregation> parser = (p, c) -> ParsedMatrixStats.fromXContent(p, (String) c);
namedXContents.add(new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(MatrixStatsAggregationBuilder.NAME), parser));
return namedXContents;
}

@Override
protected InternalMatrixStats createTestInstance(String name, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) {
int numFields = randomInt(128);
String[] fieldNames = new String[numFields];
double[] fieldValues = new double[numFields];
for (int i = 0; i < numFields; i++) {
fieldNames[i] = Integer.toString(i);
fieldValues[i] = randomDouble();
double[] values = new double[fields.length];
for (int i = 0; i < fields.length; i++) {
values[i] = randomDouble();
}

RunningStats runningStats = new RunningStats();
runningStats.add(fieldNames, fieldValues);
MatrixStatsResults matrixStatsResults = randomBoolean() ? new MatrixStatsResults(runningStats) : null;
return new InternalMatrixStats("_name", 1L, runningStats, matrixStatsResults, Collections.emptyList(), Collections.emptyMap());
runningStats.add(fields, values);
MatrixStatsResults matrixStatsResults = hasMatrixStatsResults ? new MatrixStatsResults(runningStats) : null;
return new InternalMatrixStats(name, 1L, runningStats, matrixStatsResults, Collections.emptyList(), Collections.emptyMap());
}

@Override
Expand Down Expand Up @@ -100,4 +127,48 @@ public void testReduceRandom() {
protected void assertReduced(InternalMatrixStats reduced, List<InternalMatrixStats> inputs) {
// Not implemented for this test case: any call fails immediately.
// NOTE(review): presumably the reduce test is overridden elsewhere in this class so this is never reached - confirm.
throw new UnsupportedOperationException();
}

/**
 * Checks that the aggregation parsed back from XContent exposes the same per-field statistics
 * as the original instance, and that lookups of unknown fields behave as expected.
 *
 * @param expected          the original aggregation that was rendered
 * @param parsedAggregation the aggregation parsed from the rendered output
 * @throws IOException declared by the overridden method
 */
@Override
protected void assertFromXContent(InternalMatrixStats expected, ParsedAggregation parsedAggregation) throws IOException {
    assertTrue(parsedAggregation instanceof ParsedMatrixStats);
    ParsedMatrixStats actual = (ParsedMatrixStats) parsedAggregation;

    //norelease add parsing logic for doc count and enable this test once elastic/elasticsearch#24776 is merged
    //assertEquals(expected.getDocCount(), actual.getDocCount());

    for (String field : fields) {
        assertEquals(expected.getFieldCount(field), actual.getFieldCount(field));
        assertEquals(expected.getMean(field), actual.getMean(field), 0.0);
        assertEquals(expected.getVariance(field), actual.getVariance(field), 0.0);
        assertEquals(expected.getSkewness(field), actual.getSkewness(field), 0.0);
        assertEquals(expected.getKurtosis(field), actual.getKurtosis(field), 0.0);

        for (String other : fields) {
            assertEquals(expected.getCovariance(field, other), actual.getCovariance(field, other), 0.0);
            assertEquals(expected.getCorrelation(field, other), actual.getCorrelation(field, other), 0.0);
        }
    }

    // The two unknown names use different lengths so they can never be equal. If they were,
    // getCorrelation(unknownField, other) would return 1.0 instead of throwing and the test
    // would fail at random. Both are also shorter than any generated "field_N" name.
    final String unknownField = randomAlphaOfLength(3);
    final String other = randomAlphaOfLength(4);

    for (MatrixStats matrix : Arrays.asList(actual)) {

        // getFieldCount returns 0 for unknown fields
        assertEquals(0L, matrix.getFieldCount(unknownField));

        expectThrows(IllegalArgumentException.class, () -> matrix.getMean(unknownField));
        expectThrows(IllegalArgumentException.class, () -> matrix.getVariance(unknownField));
        expectThrows(IllegalArgumentException.class, () -> matrix.getSkewness(unknownField));
        expectThrows(IllegalArgumentException.class, () -> matrix.getKurtosis(unknownField));

        expectThrows(IllegalArgumentException.class, () -> matrix.getCovariance(unknownField, unknownField));
        expectThrows(IllegalArgumentException.class, () -> matrix.getCovariance(unknownField, other));
        expectThrows(IllegalArgumentException.class, () -> matrix.getCovariance(other, unknownField));

        // A field paired with itself has correlation 1.0 even when the field is unknown.
        assertEquals(1.0, matrix.getCorrelation(unknownField, unknownField), 0.0);
        expectThrows(IllegalArgumentException.class, () -> matrix.getCorrelation(unknownField, other));
        expectThrows(IllegalArgumentException.class, () -> matrix.getCorrelation(other, unknownField));
    }
}
}