diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java index 4da8b7ca617b6..d5a91dd0e19c7 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsResults.java @@ -162,7 +162,7 @@ public Double getCorrelation(String fieldX, String fieldY) { } /** return the value for two fields in an upper triangular matrix, regardless of row col location. */ - private double getValFromUpperTriangularMatrix(Map> map, String fieldX, String fieldY) { + static > double getValFromUpperTriangularMatrix(Map map, String fieldX, String fieldY) { // for the co-value to exist, one of the two (or both) fields has to be a row key if (map.containsKey(fieldX) == false && map.containsKey(fieldY) == false) { throw new IllegalArgumentException("neither field " + fieldX + " nor " + fieldY + " exist"); diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/ParsedMatrixStats.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/ParsedMatrixStats.java new file mode 100644 index 0000000000000..62b51a8dd934b --- /dev/null +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/ParsedMatrixStats.java @@ -0,0 +1,224 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.matrix.stats; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.ParsedAggregation; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; + +public class ParsedMatrixStats extends ParsedAggregation implements MatrixStats { + + private final Map counts = new LinkedHashMap<>(); + private final Map means = new HashMap<>(); + private final Map variances = new HashMap<>(); + private final Map skewness = new HashMap<>(); + private final Map kurtosis = new HashMap<>(); + private final Map> covariances = new HashMap<>(); + private final Map> correlations = new HashMap<>(); + + @Override + public String getType() { + return MatrixStatsAggregationBuilder.NAME; + } + + @Override + public long getDocCount() { + throw new UnsupportedOperationException(); + } + + @Override + public long getFieldCount(String field) { + if (counts.containsKey(field) == false) { + return 0; + } + return counts.get(field); + } + + @Override + public double getMean(String field) { + return checkedGet(means, field); + } + + @Override + public double getVariance(String field) { + return checkedGet(variances, field); + } + + @Override + public double getSkewness(String field) { + return checkedGet(skewness, field); + } + + @Override + public double getKurtosis(String field) { + return checkedGet(kurtosis, field); + } + + @Override + public double getCovariance(String fieldX, String fieldY) { + if (fieldX.equals(fieldY)) { + return checkedGet(variances, fieldX); + } + return MatrixStatsResults.getValFromUpperTriangularMatrix(covariances, fieldX, fieldY); + } + + @Override + public double getCorrelation(String fieldX, String fieldY) { + if (fieldX.equals(fieldY)) { + return 1.0; + } + return MatrixStatsResults.getValFromUpperTriangularMatrix(correlations, fieldX, fieldY); + } + + @Override + protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + if (counts != null && counts.isEmpty() == false) { + builder.startArray(InternalMatrixStats.Fields.FIELDS); + for (String fieldName : counts.keySet()) { + builder.startObject(); + builder.field(InternalMatrixStats.Fields.NAME, fieldName); + builder.field(InternalMatrixStats.Fields.COUNT, getFieldCount(fieldName)); + builder.field(InternalMatrixStats.Fields.MEAN, getMean(fieldName)); + builder.field(InternalMatrixStats.Fields.VARIANCE, getVariance(fieldName)); + builder.field(InternalMatrixStats.Fields.SKEWNESS, getSkewness(fieldName)); + builder.field(InternalMatrixStats.Fields.KURTOSIS, getKurtosis(fieldName)); + { + builder.startObject(InternalMatrixStats.Fields.COVARIANCE); + Map covars = covariances.get(fieldName); + if (covars != null) { + for (Map.Entry covar : covars.entrySet()) { + builder.field(covar.getKey(), covar.getValue()); + } + } + builder.endObject(); + } + { + builder.startObject(InternalMatrixStats.Fields.CORRELATION); + Map correls = correlations.get(fieldName); + if (correls != null) { + for (Map.Entry correl : correls.entrySet()) { + builder.field(correl.getKey(), correl.getValue()); + } + } + builder.endObject(); + } + builder.endObject(); + } + builder.endArray(); + } + return builder; + } + + private static T checkedGet(final Map values, final String fieldName) { + if (fieldName == null) { + throw new IllegalArgumentException("field name cannot be null"); + } + if (values.containsKey(fieldName) == false) { + throw new IllegalArgumentException("field " + fieldName + " does not exist"); + } + return values.get(fieldName); + } + + private static ObjectParser PARSER = + new ObjectParser<>(ParsedMatrixStats.class.getSimpleName(), true, ParsedMatrixStats::new); + static { + declareAggregationFields(PARSER); + PARSER.declareObjectArray((matrixStats, results) -> { + for (ParsedMatrixStatsResult result : results) { + final String fieldName = result.name; + matrixStats.counts.put(fieldName, result.count); + matrixStats.means.put(fieldName, result.mean); + matrixStats.variances.put(fieldName, result.variance); + matrixStats.skewness.put(fieldName, result.skewness); + matrixStats.kurtosis.put(fieldName, result.kurtosis); + matrixStats.covariances.put(fieldName, result.covariances); + matrixStats.correlations.put(fieldName, result.correlations); + } + }, (p, c) -> ParsedMatrixStatsResult.fromXContent(p), new ParseField(InternalMatrixStats.Fields.FIELDS)); + } + + public static ParsedMatrixStats fromXContent(XContentParser parser, String name) throws IOException { + ParsedMatrixStats aggregation = PARSER.parse(parser, null); + aggregation.setName(name); + return aggregation; + } + + static class ParsedMatrixStatsResult { + + String name; + Long count; + Double mean; + Double variance; + Double skewness; + Double kurtosis; + Map covariances; + Map correlations; + + private static ObjectParser RESULT_PARSER = + new ObjectParser<>(ParsedMatrixStatsResult.class.getSimpleName(), true, ParsedMatrixStatsResult::new); + static { + RESULT_PARSER.declareString((result, name) -> result.name = name, + new ParseField(InternalMatrixStats.Fields.NAME)); + RESULT_PARSER.declareLong((result, count) -> result.count = count, + new ParseField(InternalMatrixStats.Fields.COUNT)); + RESULT_PARSER.declareDouble((result, mean) -> result.mean = mean, + new ParseField(InternalMatrixStats.Fields.MEAN)); + RESULT_PARSER.declareDouble((result, variance) -> result.variance = variance, + new ParseField(InternalMatrixStats.Fields.VARIANCE)); + RESULT_PARSER.declareDouble((result, skewness) -> result.skewness = skewness, + new ParseField(InternalMatrixStats.Fields.SKEWNESS)); + RESULT_PARSER.declareDouble((result, kurtosis) -> result.kurtosis = kurtosis, + new ParseField(InternalMatrixStats.Fields.KURTOSIS)); + + RESULT_PARSER.declareObject((ParsedMatrixStatsResult result, Map covars) -> { + result.covariances = new LinkedHashMap<>(covars.size()); + for (Map.Entry covar : covars.entrySet()) { + result.covariances.put(covar.getKey(), mapValueAsDouble(covar.getValue())); + } + }, (p, c) -> p.mapOrdered(), new ParseField(InternalMatrixStats.Fields.COVARIANCE)); + + RESULT_PARSER.declareObject((ParsedMatrixStatsResult result, Map correls) -> { + result.correlations = new LinkedHashMap<>(correls.size()); + for (Map.Entry correl : correls.entrySet()) { + result.correlations.put(correl.getKey(), mapValueAsDouble(correl.getValue())); + } + }, (p, c) -> p.mapOrdered(), new ParseField(InternalMatrixStats.Fields.CORRELATION)); + } + + private static Double mapValueAsDouble(Object value) { + if (value instanceof Double) { + return (Double) value; + } + return Double.valueOf(Objects.toString(value)); + } + + static ParsedMatrixStatsResult fromXContent(XContentParser parser) throws IOException { + return RESULT_PARSER.parse(parser, null); + } + } +} diff --git a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java index 277006da90d50..13c67b9dbb38d 100644 --- a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java +++ b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java @@ -18,36 +18,63 @@ */ package org.elasticsearch.search.aggregations.matrix.stats; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.xcontent.ContextParser; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.script.ScriptService; +import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.ParsedAggregation; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.test.InternalAggregationTestCase; +import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; public class InternalMatrixStatsTests extends InternalAggregationTestCase { + private String[] fields; + private boolean hasMatrixStatsResults; + + @Override + public void setUp() throws Exception { + super.setUp(); + hasMatrixStatsResults = frequently(); + int numFields = hasMatrixStatsResults ? randomInt(128) : 0; + fields = new String[numFields]; + for (int i = 0; i < numFields; i++) { + fields[i] = "field_" + i; + } + } + + @Override + protected List getNamedXContents() { + List namedXContents = new ArrayList<>(getDefaultNamedXContents()); + ContextParser parser = (p, c) -> ParsedMatrixStats.fromXContent(p, (String) c); + namedXContents.add(new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(MatrixStatsAggregationBuilder.NAME), parser)); + return namedXContents; + } + @Override protected InternalMatrixStats createTestInstance(String name, List pipelineAggregators, Map metaData) { - int numFields = randomInt(128); - String[] fieldNames = new String[numFields]; - double[] fieldValues = new double[numFields]; - for (int i = 0; i < numFields; i++) { - fieldNames[i] = Integer.toString(i); - fieldValues[i] = randomDouble(); + double[] values = new double[fields.length]; + for (int i = 0; i < fields.length; i++) { + values[i] = randomDouble(); } + RunningStats runningStats = new RunningStats(); - runningStats.add(fieldNames, fieldValues); - MatrixStatsResults matrixStatsResults = randomBoolean() ? new MatrixStatsResults(runningStats) : null; - return new InternalMatrixStats("_name", 1L, runningStats, matrixStatsResults, Collections.emptyList(), Collections.emptyMap()); + runningStats.add(fields, values); + MatrixStatsResults matrixStatsResults = hasMatrixStatsResults ? new MatrixStatsResults(runningStats) : null; + return new InternalMatrixStats(name, 1L, runningStats, matrixStatsResults, Collections.emptyList(), Collections.emptyMap()); } @Override @@ -100,4 +127,48 @@ public void testReduceRandom() { protected void assertReduced(InternalMatrixStats reduced, List inputs) { throw new UnsupportedOperationException(); } + + @Override + protected void assertFromXContent(InternalMatrixStats expected, ParsedAggregation parsedAggregation) throws IOException { + assertTrue(parsedAggregation instanceof ParsedMatrixStats); + ParsedMatrixStats actual = (ParsedMatrixStats) parsedAggregation; + + //norelease add parsing logic for doc count and enable this test once elastic/elasticsearch#24776 is merged + //assertEquals(expected.getDocCount(), actual.getDocCount()); + + for (String field : fields) { + assertEquals(expected.getFieldCount(field), actual.getFieldCount(field)); + assertEquals(expected.getMean(field), actual.getMean(field), 0.0); + assertEquals(expected.getVariance(field), actual.getVariance(field), 0.0); + assertEquals(expected.getSkewness(field), actual.getSkewness(field), 0.0); + assertEquals(expected.getKurtosis(field), actual.getKurtosis(field), 0.0); + + for (String other : fields) { + assertEquals(expected.getCovariance(field, other), actual.getCovariance(field, other), 0.0); + assertEquals(expected.getCorrelation(field, other), actual.getCorrelation(field, other), 0.0); + } + } + + final String unknownField = randomAlphaOfLength(3); + final String other = randomAlphaOfLength(3); + + for (MatrixStats matrix : Arrays.asList(actual)) { + + // getFieldCount returns 0 for unknown fields + assertEquals(0.0, matrix.getFieldCount(unknownField), 0.0); + + expectThrows(IllegalArgumentException.class, () -> matrix.getMean(unknownField)); + expectThrows(IllegalArgumentException.class, () -> matrix.getVariance(unknownField)); + expectThrows(IllegalArgumentException.class, () -> matrix.getSkewness(unknownField)); + expectThrows(IllegalArgumentException.class, () -> matrix.getKurtosis(unknownField)); + + expectThrows(IllegalArgumentException.class, () -> matrix.getCovariance(unknownField, unknownField)); + expectThrows(IllegalArgumentException.class, () -> matrix.getCovariance(unknownField, other)); + expectThrows(IllegalArgumentException.class, () -> matrix.getCovariance(other, unknownField)); + + assertEquals(1.0, matrix.getCorrelation(unknownField, unknownField), 0.0); + expectThrows(IllegalArgumentException.class, () -> matrix.getCorrelation(unknownField, other)); + expectThrows(IllegalArgumentException.class, () -> matrix.getCorrelation(other, unknownField)); + } + } }