From 29634af68b605a0a7c138364713b687d30fe65a2 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Wed, 25 Sep 2019 17:40:33 +0200 Subject: [PATCH 1/2] Implement evaluation API for multiclass classification problem --- .../MlEvaluationNamedXContentProvider.java | 15 +- .../classification/Classification.java | 132 +++++++++ .../MulticlassConfusionMatrixMetric.java | 164 +++++++++++ .../client/MachineLearningIT.java | 160 +++++++--- .../client/RestHighLevelClientTests.java | 14 +- .../classification/ClassificationTests.java | 64 ++++ ...classConfusionMatrixMetricResultTests.java | 74 +++++ .../MulticlassConfusionMatrixMetricTests.java | 50 ++++ .../MlEvaluationNamedXContentProvider.java | 16 + .../classification/Classification.java | 172 +++++++++++ .../classification/ClassificationMetric.java | 30 ++ .../MulticlassConfusionMatrix.java | 276 ++++++++++++++++++ .../classification/ClassificationTests.java | 222 ++++++++++++++ .../MulticlassConfusionMatrixResultTests.java | 60 ++++ .../MulticlassConfusionMatrixTests.java | 187 ++++++++++++ .../ml/qa/ml-with-security/build.gradle | 3 + .../ml/integration/EvaluateDataFrameIT.java | 137 +++++++++ .../test/ml/evaluate_data_frame.yml | 118 ++++++++ 18 files changed, 1853 insertions(+), 41 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java create mode 100644 
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EvaluateDataFrameIT.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index a28c498b1d5af..dca644b663e0a 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -18,7 +18,9 @@ */ package org.elasticsearch.client.ml.dataframe.evaluation; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import 
org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; @@ -41,6 +43,7 @@ public List getNamedXContentParsers() { // Evaluations new NamedXContentRegistry.Entry( Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent), + new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent), new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent), // Evaluation metrics new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent), @@ -48,6 +51,10 @@ Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClass new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, + new ParseField(MulticlassConfusionMatrixMetric.NAME), + MulticlassConfusionMatrixMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent), new NamedXContentRegistry.Entry( @@ -60,10 +67,14 @@ EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMe new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent), new 
NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent), + EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, + new ParseField(MulticlassConfusionMatrixMetric.NAME), + MulticlassConfusionMatrixMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent)); + EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent)); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java new file mode 100644 index 0000000000000..d7466fcc023b5 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java @@ -0,0 +1,132 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Evaluation of classification results.
+ *
+ * Holds the name of the field with the actual class, the name of the field with the
+ * predicted class and an optional list of metrics, and writes them out as an object
+ * with the keys {@code actual_field}, {@code predicted_field} and {@code metrics}.
+ */
+public class Classification implements Evaluation {
+
+    public static final String NAME = "classification";
+
+    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
+    private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
+    private static final ParseField METRICS = new ParseField("metrics");
+
+    /**
+     * Parser for this evaluation: two required string fields followed by an optional
+     * object of named {@link EvaluationMetric}s.
+     */
+    // The cast of a[2] is unchecked; the list is produced by the declareNamedObjects call below,
+    // which resolves every entry through namedObject(EvaluationMetric.class, ...).
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
+        NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
+        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
+            (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
+    }
+
+    /**
+     * Parses a {@link Classification} from the given parser using {@link #PARSER}.
+     *
+     * @param parser parser positioned on the evaluation object
+     * @return the parsed evaluation
+     */
+    public static Classification fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * The field containing the actual value
+     * The value of this field is assumed to be categorical (a class label), not numeric
+     */
+    private final String actualField;
+
+    /**
+     * The field containing the predicted value
+     * The value of this field is assumed to be categorical (a class label), not numeric
+     */
+    private final String predictedField;
+
+    /**
+     * The list of metrics to calculate, sorted by metric name; {@code null} when none were given
+     */
+    private final List<EvaluationMetric> metrics;
+
+    /**
+     * Creates an evaluation without an explicit metric list.
+     *
+     * @param actualField    name of the field holding the actual class, must not be null
+     * @param predictedField name of the field holding the predicted class, must not be null
+     */
+    public Classification(String actualField, String predictedField) {
+        this(actualField, predictedField, (List<EvaluationMetric>) null);
+    }
+
+    /**
+     * Creates an evaluation for the given metrics.
+     *
+     * @param actualField    name of the field holding the actual class, must not be null
+     * @param predictedField name of the field holding the predicted class, must not be null
+     * @param metrics        metrics to calculate
+     */
+    public Classification(String actualField, String predictedField, EvaluationMetric... metrics) {
+        this(actualField, predictedField, Arrays.asList(metrics));
+    }
+
+    /**
+     * Creates an evaluation for the given metrics.
+     *
+     * The metrics are stored sorted by {@link EvaluationMetric#getName()}. Sorting is done on a
+     * private copy, so the list (or varargs array) supplied by the caller is never reordered and
+     * unmodifiable lists are accepted.
+     *
+     * @param actualField    name of the field holding the actual class, must not be null
+     * @param predictedField name of the field holding the predicted class, must not be null
+     * @param metrics        metrics to calculate, may be null
+     * @throws NullPointerException if {@code actualField} or {@code predictedField} is null
+     */
+    public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
+        this.actualField = Objects.requireNonNull(actualField);
+        this.predictedField = Objects.requireNonNull(predictedField);
+        if (metrics == null) {
+            this.metrics = null;
+        } else {
+            // Copy first: List.sort would otherwise mutate the caller's collection in place.
+            List<EvaluationMetric> sortedMetrics = new ArrayList<>(metrics);
+            sortedMetrics.sort(Comparator.comparing(EvaluationMetric::getName));
+            this.metrics = sortedMetrics;
+        }
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    /**
+     * Writes the two field names and, when metrics are present, a {@code metrics} object
+     * keyed by metric name.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
+        builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
+
+        if (metrics != null) {
+            builder.startObject(METRICS.getPreferredName());
+            for (EvaluationMetric metric : metrics) {
+                builder.field(metric.getName(), metric);
+            }
+            builder.endObject();
+        }
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Classification that = (Classification) o;
+        return Objects.equals(that.actualField, this.actualField)
+            && Objects.equals(that.predictedField, this.predictedField)
+            && Objects.equals(that.metrics, this.metrics);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(actualField, predictedField, metrics);
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java
new file mode 100644
index 0000000000000..ba09d8bb202af
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to
Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * Calculates the multiclass confusion matrix.
+ *
+ * The metric carries a single optional {@code size} setting, which is written to the
+ * request body only when it is set.
+ */
+public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
+
+    public static final String NAME = "multiclass_confusion_matrix";
+
+    public static final ParseField SIZE = new ParseField("size");
+
+    private static final ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> PARSER = createParser();
+
+    /** Builds the parser: one optional integer field, {@code size}. */
+    private static ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> createParser() {
+        ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> parser =
+            new ConstructingObjectParser<>(NAME, true, args -> new MulticlassConfusionMatrixMetric((Integer) args[0]));
+        parser.declareInt(optionalConstructorArg(), SIZE);
+        return parser;
+    }
+
+    /**
+     * Parses a {@link MulticlassConfusionMatrixMetric} from the given parser.
+     *
+     * @param parser parser positioned on the metric object
+     * @return the parsed metric
+     */
+    public static MulticlassConfusionMatrixMetric fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /** Requested size, or {@code null} when the caller did not set one. */
+    private final Integer size;
+
+    /** Creates the metric without an explicit size. */
+    public MulticlassConfusionMatrixMetric() {
+        this(null);
+    }
+
+    /**
+     * Creates the metric.
+     *
+     * @param size value for the {@code size} setting, may be null
+     */
+    public MulticlassConfusionMatrixMetric(@Nullable Integer size) {
+        this.size = size;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    /** Writes an object containing {@code size} when it is set, an empty object otherwise. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (size != null) {
+            builder.field(SIZE.getPreferredName(), size);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        MulticlassConfusionMatrixMetric that = (MulticlassConfusionMatrixMetric) o;
+        return Objects.equals(this.size, that.size);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(size);
+    }
+
+    /**
+     * Result of the metric: the parsed {@code confusion_matrix} object (a map of maps of
+     * long counts) and the {@code other_classes_count} value.
+     */
+    public static class Result implements EvaluationMetric.Result {
+
+        private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
+        private static final ParseField OTHER_CLASSES_COUNT = new ParseField("other_classes_count");
+
+        // The cast of a[0] is unchecked; the value is produced by the nested p.map(...) call
+        // declared below, which yields string-keyed maps of string-keyed maps of longs.
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>(
+                "multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
+
+        static {
+            PARSER.declareObject(
+                constructorArg(),
+                (p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
+                CONFUSION_MATRIX);
+            PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
+        }
+
+        /**
+         * Parses a {@link Result} from the given parser.
+         *
+         * @param parser parser positioned on the result object
+         * @return the parsed result
+         */
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        // Unmodifiable view of the outer map only: the map passed to the constructor is wrapped,
+        // not copied, and the inner maps are stored as given.
+        private final Map<String, Map<String, Long>> confusionMatrix;
+        private final long otherClassesCount;
+
+        /**
+         * Creates a result.
+         *
+         * @param confusionMatrix   matrix as a map of maps of counts, must not be null
+         * @param otherClassesCount value of {@code other_classes_count}
+         * @throws NullPointerException if {@code confusionMatrix} is null
+         */
+        public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
+            this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
+            this.otherClassesCount = otherClassesCount;
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        /** @return unmodifiable view of the confusion matrix */
+        public Map<String, Map<String, Long>> getConfusionMatrix() {
+            return confusionMatrix;
+        }
+
+        /** @return the {@code other_classes_count} value */
+        public long getOtherClassesCount() {
+            return otherClassesCount;
+        }
+
+        /** Writes {@code confusion_matrix} and {@code other_classes_count} as one object. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
+            builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(this.confusionMatrix, that.confusionMatrix)
+                && this.otherClassesCount == that.otherClassesCount;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(confusionMatrix, otherClassesCount);
+        }
+    }
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
index 9d1e04eb56309..bda4de1158928 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
+++
b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -125,7 +125,9 @@ import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.PhaseProgress; import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -1573,19 +1575,19 @@ public void testDeleteDataFrameAnalyticsConfig_ConfigNotFound() { public void testEvaluateDataFrame_BinarySoftClassification() throws IOException { String indexName = "evaluate-test-index"; - createIndex(indexName, mappingForClassification()); + createIndex(indexName, mappingForSoftClassification()); BulkRequest bulk = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(docForClassification(indexName, "blue", false, 0.1)) // #0 - .add(docForClassification(indexName, "blue", false, 0.2)) // #1 - .add(docForClassification(indexName, "blue", false, 0.3)) // #2 - .add(docForClassification(indexName, "blue", false, 0.4)) // #3 - .add(docForClassification(indexName, "blue", false, 0.7)) // #4 - .add(docForClassification(indexName, "blue", true, 0.2)) // #5 - .add(docForClassification(indexName, "green", true, 0.3)) // #6 - .add(docForClassification(indexName, "green", true, 0.4)) // #7 - .add(docForClassification(indexName, "green", true, 0.8)) // #8 - .add(docForClassification(indexName, "green", true, 0.9)); // #9 + .add(docForSoftClassification(indexName, "blue", false, 0.1)) // #0 + 
.add(docForSoftClassification(indexName, "blue", false, 0.2)) // #1 + .add(docForSoftClassification(indexName, "blue", false, 0.3)) // #2 + .add(docForSoftClassification(indexName, "blue", false, 0.4)) // #3 + .add(docForSoftClassification(indexName, "blue", false, 0.7)) // #4 + .add(docForSoftClassification(indexName, "blue", true, 0.2)) // #5 + .add(docForSoftClassification(indexName, "green", true, 0.3)) // #6 + .add(docForSoftClassification(indexName, "green", true, 0.4)) // #7 + .add(docForSoftClassification(indexName, "green", true, 0.8)) // #8 + .add(docForSoftClassification(indexName, "green", true, 0.9)); // #9 highLevelClient().bulk(bulk, RequestOptions.DEFAULT); MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); @@ -1647,19 +1649,19 @@ public void testEvaluateDataFrame_BinarySoftClassification() throws IOException public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException { String indexName = "evaluate-with-query-test-index"; - createIndex(indexName, mappingForClassification()); + createIndex(indexName, mappingForSoftClassification()); BulkRequest bulk = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(docForClassification(indexName, "blue", true, 1.0)) // #0 - .add(docForClassification(indexName, "blue", true, 1.0)) // #1 - .add(docForClassification(indexName, "blue", true, 1.0)) // #2 - .add(docForClassification(indexName, "blue", true, 1.0)) // #3 - .add(docForClassification(indexName, "blue", true, 0.0)) // #4 - .add(docForClassification(indexName, "blue", true, 0.0)) // #5 - .add(docForClassification(indexName, "green", true, 0.0)) // #6 - .add(docForClassification(indexName, "green", true, 0.0)) // #7 - .add(docForClassification(indexName, "green", true, 0.0)) // #8 - .add(docForClassification(indexName, "green", true, 1.0)); // #9 + .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #0 + .add(docForSoftClassification(indexName, "blue", true, 
1.0)) // #1
+            .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #2
+            .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #3
+            .add(docForSoftClassification(indexName, "blue", true, 0.0)) // #4
+            .add(docForSoftClassification(indexName, "blue", true, 0.0)) // #5
+            .add(docForSoftClassification(indexName, "green", true, 0.0)) // #6
+            .add(docForSoftClassification(indexName, "green", true, 0.0)) // #7
+            .add(docForSoftClassification(indexName, "green", true, 0.0)) // #8
+            .add(docForSoftClassification(indexName, "green", true, 1.0)); // #9
         highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
@@ -1722,6 +1724,74 @@ public void testEvaluateDataFrame_Regression() throws IOException {
         assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));
     }
+    // Indexes ten (actual class, predicted class) documents and evaluates them twice with the
+    // multiclass confusion matrix metric: once without a size and once with size 2.
+    public void testEvaluateDataFrame_Classification() throws IOException {
+        String indexName = "evaluate-classification-test-index";
+        createIndex(indexName, mappingForClassification());
+        // NOTE(review): "regressionBulk" holds classification documents; the name looks copied
+        // from a regression test - consider renaming.
+        BulkRequest regressionBulk = new BulkRequest()
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+            .add(docForClassification(indexName, "cat", "cat"))
+            .add(docForClassification(indexName, "cat", "cat"))
+            .add(docForClassification(indexName, "cat", "cat"))
+            .add(docForClassification(indexName, "cat", "dog"))
+            .add(docForClassification(indexName, "cat", "fish"))
+            .add(docForClassification(indexName, "dog", "cat"))
+            .add(docForClassification(indexName, "dog", "dog"))
+            .add(docForClassification(indexName, "dog", "dog"))
+            .add(docForClassification(indexName, "dog", "dog"))
+            .add(docForClassification(indexName, "horse", "cat"));
+        highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
+
+        MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+
+        { // No size provided for MulticlassConfusionMatrixMetric, default used instead
+            EvaluateDataFrameRequest evaluateDataFrameRequest =
+                new EvaluateDataFrameRequest(
+                    indexName,
+                    null,
+                    new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric()));
+
+            EvaluateDataFrameResponse evaluateDataFrameResponse =
+                execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
+            assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
+            assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+
+            MulticlassConfusionMatrixMetric.Result mcmResult =
+                evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
+            assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
+            // Expects one row per actual class (cat, dog, horse) and an other-classes count of 0.
+            assertThat(
+                mcmResult.getConfusionMatrix(),
+                equalTo(
+                    Map.of(
+                        "cat", Map.of("cat", 3L, "dog", 1L, "horse", 0L, "_other_", 1L),
+                        "dog", Map.of("cat", 1L, "dog", 3L, "horse", 0L),
+                        "horse", Map.of("cat", 1L, "dog", 0L, "horse", 0L))));
+            assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
+        }
+        { // Explicit size provided for MulticlassConfusionMatrixMetric metric
+            EvaluateDataFrameRequest evaluateDataFrameRequest =
+                new EvaluateDataFrameRequest(
+                    indexName,
+                    null,
+                    new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2)));
+
+            EvaluateDataFrameResponse evaluateDataFrameResponse =
+                execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
+            assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
+            assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+
+            MulticlassConfusionMatrixMetric.Result mcmResult =
+                evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
+            assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
+            // Expects rows for cat and dog only and an other-classes count of 1.
+            assertThat(
+                mcmResult.getConfusionMatrix(),
+                equalTo(
+                    Map.of(
+                        "cat", Map.of("cat", 3L, "dog", 1L, "_other_", 1L),
+                        "dog", Map.of("cat", 1L, "dog", 3L))));
+            assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
+        }
+    }
+
     private static XContentBuilder defaultMappingForTest() throws IOException {
         return XContentFactory.jsonBuilder().startObject()
             .startObject("properties")
@@ -1739,7 +1809,7 @@ private static XContentBuilder defaultMappingForTest() throws IOException {
     private static final String actualField = "label";
     private static final String probabilityField = "p";
-    private static XContentBuilder mappingForClassification() throws IOException {
+    private static XContentBuilder mappingForSoftClassification() throws IOException {
         return XContentFactory.jsonBuilder().startObject()
             .startObject("properties")
                 .startObject(datasetField)
@@ -1755,26 +1825,48 @@ private static XContentBuilder mappingForClassification() throws IOException {
             .endObject();
     }
-    private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
+    private static IndexRequest docForSoftClassification(String indexName, String dataset, boolean isTrue, double p) {
         return new IndexRequest()
             .index(indexName)
             .source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
     }
+    private static final String actualClassField = "actual_class";
+    private static final String predictedClassField = "predicted_class";
+
+    // Mapping for the classification test index: both class fields are typed as keyword.
+    private static XContentBuilder mappingForClassification() throws IOException {
+        return XContentFactory.jsonBuilder().startObject()
+            .startObject("properties")
+                .startObject(actualClassField)
+                    .field("type", "keyword")
+                .endObject()
+                .startObject(predictedClassField)
+                    .field("type", "keyword")
+                .endObject()
+            .endObject()
+        .endObject();
+    }
+
+    // Builds an index request for one document carrying an actual and a predicted class.
+    private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) {
+        return new IndexRequest()
+            .index(indexName)
+            .source(XContentType.JSON, actualClassField,
actualClass, predictedClassField, predictedClass); + } + private static final String actualRegression = "regression_actual"; private static final String probabilityRegression = "regression_prob"; private static XContentBuilder mappingForRegression() throws IOException { return XContentFactory.jsonBuilder().startObject() .startObject("properties") - .startObject(actualRegression) - .field("type", "double") - .endObject() - .startObject(probabilityRegression) - .field("type", "double") - .endObject() + .startObject(actualRegression) + .field("type", "double") + .endObject() + .startObject(probabilityRegression) + .field("type", "double") + .endObject() .endObject() - .endObject(); + .endObject(); } private static IndexRequest docForRegression(String indexName, double act, double p) { @@ -1789,11 +1881,11 @@ private void createIndex(String indexName, XContentBuilder mapping) throws IOExc public void testEstimateMemoryUsage() throws IOException { String indexName = "estimate-test-index"; - createIndex(indexName, mappingForClassification()); + createIndex(indexName, mappingForSoftClassification()); BulkRequest bulk1 = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < 10; ++i) { - bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); + bulk1.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); } highLevelClient().bulk(bulk1, RequestOptions.DEFAULT); @@ -1819,7 +1911,7 @@ public void testEstimateMemoryUsage() throws IOException { BulkRequest bulk2 = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 10; i < 100; ++i) { - bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); + bulk2.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 
1.0, true))); } highLevelClient().bulk(bulk2, RequestOptions.DEFAULT); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 7641dd3032c83..b5394e5dcbdf3 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -57,7 +57,9 @@ import org.elasticsearch.client.ilm.UnfollowAction; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -681,7 +683,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(41, namedXContents.size()); + assertEquals(44, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -720,22 +722,24 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName())); assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); assertTrue(names.contains(TimeSyncConfig.NAME)); - assertEquals(Integer.valueOf(2), 
categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); - assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME)); - assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); + assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME)); + assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, + MulticlassConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME, RSquaredMetric.NAME)); - assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, + MulticlassConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME, RSquaredMetric.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java new file mode 100644 index 0000000000000..a72b483518cb2 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -0,0 +1,64 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.function.Predicate; + +public class ClassificationTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + public static Classification createRandom() { + return new Classification( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomBoolean() ? 
null : Arrays.asList(new MulticlassConfusionMatrixMetric())); + } + + @Override + protected Classification createTestInstance() { + return createRandom(); + } + + @Override + protected Classification doParseInstance(XContentParser parser) throws IOException { + return Classification.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java new file mode 100644 index 0000000000000..800a2cf7b9836 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected MulticlassConfusionMatrixMetric.Result createTestInstance() { + int numClasses = randomIntBetween(2, 100); + List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); + Map> confusionMatrix = new TreeMap<>(); + for (int i = 0; i < numClasses; i++) { + Map row = new TreeMap<>(); + confusionMatrix.put(classNames.get(i), row); + for (int j = 0; j < numClasses; j++) { + if (randomBoolean()) { + row.put(classNames.get(j), randomNonNegativeLong()); + } + } + } + long otherClassesCount = randomNonNegativeLong(); + return new MulticlassConfusionMatrixMetric.Result(confusionMatrix, otherClassesCount); + } + + @Override + protected MulticlassConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException { + return MulticlassConfusionMatrixMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } +} diff --git
a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java new file mode 100644 index 0000000000000..f4de12796f087 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class MulticlassConfusionMatrixMetricTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected MulticlassConfusionMatrixMetric createTestInstance() { + Integer size = randomBoolean() ? randomIntBetween(1, 1000) : null; + return new MulticlassConfusionMatrixMetric(size); + } + + @Override + protected MulticlassConfusionMatrixMetric doParseInstance(XContentParser parser) throws IOException { + return MulticlassConfusionMatrixMetric.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index a2aa8e74918ac..8036c5ab8955b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -8,7 +8,10 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; +import 
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric; @@ -32,6 +35,7 @@ public List getNamedXContentParsers() { // Evaluations namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, BinarySoftClassification::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent)); // Soft classification metrics @@ -41,6 +45,10 @@ public List getNamedXContentParsers() { namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME, ConfusionMatrix::fromXContent)); + // Classification metrics + namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME, + MulticlassConfusionMatrix::fromXContent)); + // Regression metrics namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent)); @@ -54,6 +62,8 @@ public List getNamedWriteables() { // Evaluations namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, 
Classification.NAME.getPreferredName(), + Classification::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new)); // Evaluation Metrics @@ -65,6 +75,9 @@ public List getNamedWriteables() { Recall::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), ConfusionMatrix::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, + MulticlassConfusionMatrix.NAME.getPreferredName(), + MulticlassConfusionMatrix::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME.getPreferredName(), MeanSquaredError::new)); @@ -79,6 +92,9 @@ public List getNamedWriteables() { ScoreByThresholdResult::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), ConfusionMatrix.Result::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + MulticlassConfusionMatrix.NAME.getPreferredName(), + MulticlassConfusionMatrix.Result::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MeanSquaredError.NAME.getPreferredName(), MeanSquaredError.Result::new)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java new file mode 100644 index 0000000000000..a90de52ea15a9 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** + * Evaluation of classification results. 
+ */ +public class Classification implements Evaluation { + + public static final ParseField NAME = new ParseField("classification"); + + private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field"); + private static final ParseField METRICS = new ParseField("metrics"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD); + PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), + (p, c, n) -> p.namedObject(ClassificationMetric.class, n, c), METRICS); + } + + public static Classification fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * The field containing the actual value + * The value of this field is assumed to be categorical (a class name) + */ + private final String actualField; + + /** + * The field containing the predicted value + * The value of this field is assumed to be categorical (a class name) + */ + private final String predictedField; + + /** + * The list of metrics to calculate + */ + private final List metrics; + + public Classification(String actualField, String predictedField, @Nullable List metrics) { + this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); + this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD); + this.metrics = initMetrics(metrics); + } + + public Classification(StreamInput in) throws IOException { + this.actualField = in.readString(); + this.predictedField = in.readString(); + this.metrics = in.readNamedWriteableList(ClassificationMetric.class); + } + + private static List initMetrics(@Nullable List
parsedMetrics) { + List metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics); + if (metrics.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName()); + } + Collections.sort(metrics, Comparator.comparing(ClassificationMetric::getName)); + return metrics; + } + + private static List defaultMetrics() { + return Arrays.asList(new MulticlassConfusionMatrix()); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public List getMetrics() { + return metrics; + } + + @Override + public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { + ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder"); + SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder(List.of(actualField, predictedField), userProvidedQueryBuilder); + for (ClassificationMetric metric : metrics) { + List aggs = metric.aggs(actualField, predictedField); + aggs.forEach(searchSourceBuilder::aggregation); + } + return searchSourceBuilder; + } + + @Override + public void process(SearchResponse searchResponse) { + ExceptionsHelper.requireNonNull(searchResponse, "searchResponse"); + if (searchResponse.getHits().getTotalHits().value == 0) { + throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField); + } + for (ClassificationMetric metric : metrics) { + metric.process(searchResponse.getAggregations()); + } + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(actualField); + out.writeString(predictedField); + out.writeNamedWriteableList(metrics); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + 
builder.field(ACTUAL_FIELD.getPreferredName(), actualField); + builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); + + builder.startObject(METRICS.getPreferredName()); + for (ClassificationMetric metric : metrics) { + builder.field(metric.getWriteableName(), metric); + } + builder.endObject(); + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Classification that = (Classification) o; + return Objects.equals(that.actualField, this.actualField) + && Objects.equals(that.predictedField, this.predictedField) + && Objects.equals(that.metrics, this.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(actualField, predictedField, metrics); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java new file mode 100644 index 0000000000000..220942a4838a5 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; + +import java.util.List; + +public interface ClassificationMetric extends EvaluationMetric { + + /** + * Builds the aggregations that collect the data required to compute the metric + * @param actualField the field that stores the actual value + * @param predictedField the field that stores the predicted value + * @return the aggregations required to compute the metric + */ + List aggs(String actualField, String predictedField); + + /** + * Processes the given aggregations as a step towards computing the result + * @param aggs aggregations from {@link SearchResponse} + */ + void process(Aggregations aggs); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java new file mode 100644 index 0000000000000..f4e0c723da5d4 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -0,0 +1,276 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License.
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.BucketOrder; +import org.elasticsearch.search.aggregations.bucket.filter.Filters; +import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.metrics.Cardinality; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.TreeMap; +import java.util.stream.Collectors; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * {@link MulticlassConfusionMatrix} is a metric that answers the question: + * "How many examples belonging to class X were classified as Y by the classifier?" + * for all the possible class pairs {X, Y}. 
+ */ +public class MulticlassConfusionMatrix implements ClassificationMetric { + + public static final ParseField NAME = new ParseField("multiclass_confusion_matrix"); + + public static final ParseField SIZE = new ParseField("size"); + + private static final ConstructingObjectParser PARSER = createParser(); + + private static ConstructingObjectParser createParser() { + ConstructingObjectParser parser = + new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0])); + parser.declareInt(optionalConstructorArg(), SIZE); + return parser; + } + + public static MulticlassConfusionMatrix fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; + private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; + private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; + private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; + private static final String OTHER_BUCKET_KEY = "_other_"; + private static final int DEFAULT_SIZE = 10; + private static final int MAX_SIZE = 1000; + + private final int size; + private List topActualClassNames; + private Result result; + + public MulticlassConfusionMatrix() { + this((Integer) null); + } + + public MulticlassConfusionMatrix(@Nullable Integer size) { + if (size != null && (size <= 0 || size > MAX_SIZE)) { + throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE); + } + this.size = size != null ? 
size : DEFAULT_SIZE; + } + + public MulticlassConfusionMatrix(StreamInput in) throws IOException { + this.size = in.readVInt(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + public int getSize() { + return size; + } + + @Override + public final List aggs(String actualField, String predictedField) { + if (topActualClassNames == null) { + return List.of( + AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) + .field(actualField) + .order(List.of(BucketOrder.count(false), BucketOrder.key(true))) + .size(size)); + } else if (result == null) { + KeyedFilter[] keyedFilters = + topActualClassNames.stream() + .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) + .toArray(KeyedFilter[]::new); + return List.of( + AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS) + .field(actualField), + AggregationBuilders.terms(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) + .field(actualField) + .order(List.of(BucketOrder.count(false), BucketOrder.key(true))) + .size(size) + .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters) + .otherBucket(true) + .otherBucketKey(OTHER_BUCKET_KEY))); + } else { + return List.of(); + } + } + + @Override + public void process(Aggregations aggs) { + if (aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null && topActualClassNames == null) { + Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS); + topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList()); + } + if (aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null && result == null) { + Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS); + Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); + Map> counts = new TreeMap<>(); + for (Terms.Bucket bucket : termsAgg.getBuckets()) { + String 
actualClass = bucket.getKeyAsString(); + Map subCounts = new TreeMap<>(); + counts.put(actualClass, subCounts); + Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); + for (Filters.Bucket subBucket : subAgg.getBuckets()) { + String predictedClass = subBucket.getKeyAsString(); + Long docCount = subBucket.getDocCount(); + if ((OTHER_BUCKET_KEY.equals(predictedClass) && docCount == 0L) == false) { + subCounts.put(predictedClass, docCount); + } + } + } + result = new Result(counts, termsAgg.getSumOfOtherDocCounts() == 0 ? 0 : Math.max(cardinalityAgg.getValue() - size, 0)); + } + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(size); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(SIZE.getPreferredName(), size); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o; + return Objects.equals(this.size, that.size); + } + + @Override + public int hashCode() { + return Objects.hash(size); + } + + public static class Result implements EvaluationMetricResult { + + private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); + private static final ParseField OTHER_CLASSES_COUNT = new ParseField("other_classes_count"); + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "multiclass_confusion_matrix_result", true, a -> new Result((Map>) a[0], (long) a[1])); + + static { + PARSER.declareObject( + constructorArg(), + (p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)), + CONFUSION_MATRIX); + PARSER.declareLong(constructorArg(),
OTHER_CLASSES_COUNT); + } + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + // Immutable + private final Map> confusionMatrix; + private final long otherClassesCount; + + public Result(Map> confusionMatrix, long otherClassesCount) { + this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix)); + this.otherClassesCount = otherClassesCount; + } + + public Result(StreamInput in) throws IOException { + this.confusionMatrix = Collections.unmodifiableMap( + in.readMap(StreamInput::readString, in2 -> in2.readMap(StreamInput::readString, StreamInput::readLong))); + this.otherClassesCount = in.readLong(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + public Map> getConfusionMatrix() { + return confusionMatrix; + } + + public long getOtherClassesCount() { + return otherClassesCount; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap( + confusionMatrix, + StreamOutput::writeString, + (out2, row) -> out2.writeMap(row, StreamOutput::writeString, StreamOutput::writeLong)); + out.writeLong(otherClassesCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); + builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(this.confusionMatrix, that.confusionMatrix) + && this.otherClassesCount == that.otherClassesCount; + } + + @Override + public int hashCode() { + return Objects.hash(confusionMatrix, 
otherClassesCount); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java new file mode 100644 index 0000000000000..fb5752f398bbc --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -0,0 +1,222 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; + +import java.io.IOException; +import java.util.Arrays; 
+import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ClassificationTests extends AbstractSerializingTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + public static Classification createRandom() { + return new Classification( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomBoolean() ? 
null : Arrays.asList(MulticlassConfusionMatrixTests.createRandom())); + } + + @Override + protected Classification doParseInstance(XContentParser parser) throws IOException { + return Classification.fromXContent(parser); + } + + @Override + protected Classification createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Classification::new; + } + + public void testConstructor_GivenEmptyMetrics() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Classification("foo", "bar", Collections.emptyList())); + assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics")); + } + + public void testBuildSearch() { + QueryBuilder userProvidedQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value")); + QueryBuilder expectedSearchQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("act")) + .filter(QueryBuilders.existsQuery("pred")) + .filter(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value"))); + + Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix())); + + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery); + assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); + assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); + } + + public void testProcess_MultipleMetricsWithDifferentNumberOfSteps() { + ClassificationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2); + ClassificationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3); + ClassificationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4); + ClassificationMetric metric4 = new 
FakeClassificationMetric("fake_metric_4", 5); + + Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4)); + assertThat(metric1.getResult(), isEmpty()); + assertThat(metric2.getResult(), isEmpty()); + assertThat(metric3.getResult(), isEmpty()); + assertThat(metric4.getResult(), isEmpty()); + assertThat(evaluation.hasAllResults(), is(false)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isEmpty()); + assertThat(metric2.getResult(), isEmpty()); + assertThat(metric3.getResult(), isEmpty()); + assertThat(metric4.getResult(), isEmpty()); + assertThat(evaluation.hasAllResults(), is(false)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isPresent()); + assertThat(metric2.getResult(), isEmpty()); + assertThat(metric3.getResult(), isEmpty()); + assertThat(metric4.getResult(), isEmpty()); + assertThat(evaluation.hasAllResults(), is(false)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isPresent()); + assertThat(metric2.getResult(), isPresent()); + assertThat(metric3.getResult(), isEmpty()); + assertThat(metric4.getResult(), isEmpty()); + assertThat(evaluation.hasAllResults(), is(false)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isPresent()); + assertThat(metric2.getResult(), isPresent()); + assertThat(metric3.getResult(), isPresent()); + assertThat(metric4.getResult(), isEmpty()); + assertThat(evaluation.hasAllResults(), is(false)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isPresent()); + assertThat(metric2.getResult(), isPresent()); + assertThat(metric3.getResult(), isPresent()); + assertThat(metric4.getResult(), isPresent()); + assertThat(evaluation.hasAllResults(), is(true)); + + 
evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isPresent()); + assertThat(metric2.getResult(), isPresent()); + assertThat(metric3.getResult(), isPresent()); + assertThat(metric4.getResult(), isPresent()); + assertThat(evaluation.hasAllResults(), is(true)); + } + + private static SearchResponse mockSearchResponseWithNonZeroTotalHits() { + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHits hits = new SearchHits(SearchHits.EMPTY, new TotalHits(10, TotalHits.Relation.EQUAL_TO), 0); + when(searchResponse.getHits()).thenReturn(hits); + return searchResponse; + } + + /** + * Metric which iterates through its steps in {@link #process} method. + * Number of steps is configurable. + * Upon reaching the last step, the result is produced. + */ + private static class FakeClassificationMetric implements ClassificationMetric { + + private final String name; + private final int numSteps; + private int currentStepIndex; + private EvaluationMetricResult result; + + FakeClassificationMetric(String name, int numSteps) { + this.name = name; + this.numSteps = numSteps; + } + + @Override + public String getName() { + return name; + } + + @Override + public String getWriteableName() { + return name; + } + + @Override + public List aggs(String actualField, String predictedField) { + return List.of(); + } + + @Override + public void process(Aggregations aggs) { + if (result != null) { + return; + } + currentStepIndex++; + if (currentStepIndex == numSteps) { + // This is the last step, time to write evaluation result + result = mock(EvaluationMetricResult.class); + } + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) { + return builder; + } + + @Override + public void writeTo(StreamOutput out) { + } + } +} diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java
new file mode 100644
index 0000000000000..b99e45ad7e013
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * Serialization tests for {@link MulticlassConfusionMatrix.Result}, driven by {@link AbstractSerializingTestCase}.
+ */
+public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix.Result> {
+
+    @Override
+    protected MulticlassConfusionMatrix.Result doParseInstance(XContentParser parser) throws IOException {
+        return MulticlassConfusionMatrix.Result.fromXContent(parser);
+    }
+
+    /**
+     * Builds a random sparse confusion matrix: every (actual, predicted) cell is randomly either present or absent.
+     */
+    @Override
+    protected MulticlassConfusionMatrix.Result createTestInstance() {
+        int numClasses = randomIntBetween(2, 100);
+        List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
+        Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
+        for (int i = 0; i < numClasses; i++) {
+            Map<String, Long> row = new TreeMap<>();
+            confusionMatrix.put(classNames.get(i), row);
+            for (int j = 0; j < numClasses; j++) {
+                if (randomBoolean()) {
+                    row.put(classNames.get(j), randomNonNegativeLong()); // column key is the j-th class, not the row's own
+                }
+            }
+        }
+        long otherClassesCount = randomNonNegativeLong();
+        return new MulticlassConfusionMatrix.Result(confusionMatrix, otherClassesCount);
+    }
+
+    @Override
+    protected Writeable.Reader<MulticlassConfusionMatrix.Result> instanceReader() {
+        return MulticlassConfusionMatrix.Result::new;
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // allow unknown fields in the root of the object only
+        return field -> !field.isEmpty();
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java
new file mode 100644
index 0000000000000..a4e989bce898a
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java
@@ -0,0 +1,187 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filters; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.metrics.Cardinality; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase { + + @Override + protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException { + return MulticlassConfusionMatrix.fromXContent(parser); + } + + @Override + protected MulticlassConfusionMatrix createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return MulticlassConfusionMatrix::new; + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + public static MulticlassConfusionMatrix createRandom() { + Integer size = randomBoolean() ? 
null : randomIntBetween(1, 1000); + return new MulticlassConfusionMatrix(size); + } + + public void testConstructor_SizeValidationFailures() { + { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1)); + assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); + } + { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0)); + assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); + } + { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001)); + assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); + } + } + + public void testAggs() { + MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(); + List aggs = confusionMatrix.aggs("act", "pred"); + assertThat(aggs, is(not(empty()))); + assertThat(confusionMatrix.getResult(), equalTo(Optional.empty())); + } + + public void testEvaluate() { + Aggregations aggs = new Aggregations(List.of( + mockTerms( + "multiclass_confusion_matrix_step_1_by_actual_class", + List.of( + mockTermsBucket("dog", new Aggregations(List.of())), + mockTermsBucket("cat", new Aggregations(List.of()))), + 0L), + mockTerms( + "multiclass_confusion_matrix_step_2_by_actual_class", + List.of( + mockTermsBucket( + "dog", + new Aggregations(List.of(mockFilters( + "multiclass_confusion_matrix_step_2_by_predicted_class", + List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), + mockTermsBucket( + "cat", + new Aggregations(List.of(mockFilters( + "multiclass_confusion_matrix_step_2_by_predicted_class", + List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))), + 0L), + mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L))); + 
+ MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); + confusionMatrix.process(aggs); + + assertThat(confusionMatrix.aggs("act", "pred"), is(empty())); + MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); + assertThat( + result.getConfusionMatrix(), + equalTo(Map.of("dog", Map.of("cat", 10L, "dog", 20L), "cat", Map.of("cat", 30L, "dog", 40L)))); + assertThat(result.getOtherClassesCount(), equalTo(0L)); + } + + public void testEvaluate_OtherClassesCountGreaterThanZero() { + Aggregations aggs = new Aggregations(List.of( + mockTerms( + "multiclass_confusion_matrix_step_1_by_actual_class", + List.of( + mockTermsBucket("dog", new Aggregations(List.of())), + mockTermsBucket("cat", new Aggregations(List.of()))), + 100L), + mockTerms( + "multiclass_confusion_matrix_step_2_by_actual_class", + List.of( + mockTermsBucket( + "dog", + new Aggregations(List.of(mockFilters( + "multiclass_confusion_matrix_step_2_by_predicted_class", + List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), + mockTermsBucket( + "cat", + new Aggregations(List.of(mockFilters( + "multiclass_confusion_matrix_step_2_by_predicted_class", + List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))), + 100L), + mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L))); + + MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); + confusionMatrix.process(aggs); + + assertThat(confusionMatrix.aggs("act", "pred"), is(empty())); + MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); + assertThat( + result.getConfusionMatrix(), + equalTo(Map.of("dog", 
Map.of("cat", 10L, "dog", 20L), "cat", Map.of("cat", 30L, "dog", 40L, "_other_", 15L)))); + assertThat(result.getOtherClassesCount(), equalTo(3L)); + } + + private static Terms mockTerms(String name, List buckets, long sumOfOtherDocCounts) { + Terms aggregation = mock(Terms.class); + when(aggregation.getName()).thenReturn(name); + doReturn(buckets).when(aggregation).getBuckets(); + when(aggregation.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts); + return aggregation; + } + + private static Terms.Bucket mockTermsBucket(String actualClass, Aggregations subAggs) { + Terms.Bucket bucket = mock(Terms.Bucket.class); + when(bucket.getKeyAsString()).thenReturn(actualClass); + when(bucket.getAggregations()).thenReturn(subAggs); + return bucket; + } + + private static Filters mockFilters(String name, List buckets) { + Filters aggregation = mock(Filters.class); + when(aggregation.getName()).thenReturn(name); + doReturn(buckets).when(aggregation).getBuckets(); + return aggregation; + } + + private static Filters.Bucket mockFiltersBucket(String predictedClass, long docCount) { + Filters.Bucket bucket = mock(Filters.Bucket.class); + when(bucket.getKeyAsString()).thenReturn(predictedClass); + when(bucket.getDocCount()).thenReturn(docCount); + return bucket; + } + + private static Cardinality mockCardinality(String name, long value) { + Cardinality aggregation = mock(Cardinality.class); + when(aggregation.getName()).thenReturn(name); + when(aggregation.getValue()).thenReturn(value); + return aggregation; + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 3793cffbf0e1e..8911a0dc2781c 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -90,6 +90,9 @@ integTest.runner { 'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given 
recall with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds', + 'ml/evaluate_data_frame/Test classification given evaluation with empty metrics', + 'ml/evaluate_data_frame/Test classification given missing actual_field', + 'ml/evaluate_data_frame/Test classification given missing predicted_field', 'ml/evaluate_data_frame/Test regression given evaluation with empty metrics', 'ml/evaluate_data_frame/Test regression given missing actual_field', 'ml/evaluate_data_frame/Test regression given missing predicted_field', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EvaluateDataFrameIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EvaluateDataFrameIT.java new file mode 100644 index 0000000000000..5defc6740df99 --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EvaluateDataFrameIT.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.junit.After; +import org.junit.Before; + +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; + +public class EvaluateDataFrameIT extends MlNativeDataFrameAnalyticsIntegTestCase { + + private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; + + private static final String ACTUAL_CLASS_FIELD = "actual_class_field"; + private static final String PREDICTED_CLASS_FIELD = "predicted_class_field"; + + @Before + public void setup() { + indexAnimalsData(ANIMALS_DATA_INDEX); + } + + @After + public void cleanup() { + cleanUp(); + } + + public void testEvaluate_MulticlassClassification_DefaultMetrics() { + EvaluateDataFrameAction.Request evaluateDataFrameRequest = + new EvaluateDataFrameAction.Request() + .setIndices(List.of(ANIMALS_DATA_INDEX)) + .setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null)); + + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet(); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + MulticlassConfusionMatrix.Result confusionMatrixResult = + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(confusionMatrixResult.getMetricName(), 
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + assertThat( + confusionMatrixResult.getConfusionMatrix(), + equalTo(Map.of( + "ant", Map.of("ant", 1L, "cat", 4L, "dog", 3L, "fox", 2L, "mouse", 5L), + "cat", Map.of("ant", 3L, "cat", 1L, "dog", 5L, "fox", 4L, "mouse", 2L), + "dog", Map.of("ant", 4L, "cat", 2L, "dog", 1L, "fox", 5L, "mouse", 3L), + "fox", Map.of("ant", 5L, "cat", 3L, "dog", 2L, "fox", 1L, "mouse", 4L), + "mouse", Map.of("ant", 2L, "cat", 5L, "dog", 4L, "fox", 3L, "mouse", 1L)))); + assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); + } + + public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() { + EvaluateDataFrameAction.Request evaluateDataFrameRequest = + new EvaluateDataFrameAction.Request() + .setIndices(List.of(ANIMALS_DATA_INDEX)) + .setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix()))); + + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet(); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + MulticlassConfusionMatrix.Result confusionMatrixResult = + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + assertThat( + confusionMatrixResult.getConfusionMatrix(), + equalTo(Map.of( + "ant", Map.of("ant", 1L, "cat", 4L, "dog", 3L, "fox", 2L, "mouse", 5L), + "cat", Map.of("ant", 3L, "cat", 1L, "dog", 5L, "fox", 4L, "mouse", 2L), + "dog", Map.of("ant", 4L, "cat", 2L, "dog", 1L, "fox", 5L, "mouse", 3L), + "fox", Map.of("ant", 5L, "cat", 3L, "dog", 2L, "fox", 1L, "mouse", 4L), + "mouse", Map.of("ant", 2L, "cat", 5L, "dog", 4L, "fox", 3L, "mouse", 
1L)))); + assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); + } + + public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() { + EvaluateDataFrameAction.Request evaluateDataFrameRequest = + new EvaluateDataFrameAction.Request() + .setIndices(List.of(ANIMALS_DATA_INDEX)) + .setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix(3)))); + + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet(); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + MulticlassConfusionMatrix.Result confusionMatrixResult = + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + assertThat( + confusionMatrixResult.getConfusionMatrix(), + equalTo(Map.of( + "ant", Map.of("ant", 1L, "cat", 4L, "dog", 3L, "_other_", 7L), + "cat", Map.of("ant", 3L, "cat", 1L, "dog", 5L, "_other_", 6L), + "dog", Map.of("ant", 4L, "cat", 2L, "dog", 1L, "_other_", 8L)))); + assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L)); + } + + private static void indexAnimalsData(String indexName) { + client().admin().indices().prepareCreate(indexName) + .addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword") + .get(); + + List classNames = List.of("dog", "cat", "mouse", "ant", "fox"); + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < classNames.size(); i++) { + for (int j = 0; j < classNames.size(); j++) { + for (int k = 0; k < j + 1; k++) { + bulkRequestBuilder.add( + new 
IndexRequest(indexName) + .source( + ACTUAL_CLASS_FIELD, classNames.get(i), + PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size()))); + } + } + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 7459e6959016b..92e816c492d28 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -11,6 +11,8 @@ setup: "outlier_score": 0.0, "regression_field_act": 10.9, "regression_field_pred": 10.9, + "classification_field_act": "dog", + "classification_field_pred": "dog", "all_true_field": true, "all_false_field": false } @@ -26,6 +28,8 @@ setup: "outlier_score": 0.2, "regression_field_act": 12.0, "regression_field_pred": 9.9, + "classification_field_act": "cat", + "classification_field_pred": "cat", "all_true_field": true, "all_false_field": false } @@ -41,6 +45,8 @@ setup: "outlier_score": 0.3, "regression_field_act": 20.9, "regression_field_pred": 5.9, + "classification_field_act": "mouse", + "classification_field_pred": "mouse", "all_true_field": true, "all_false_field": false } @@ -56,6 +62,8 @@ setup: "outlier_score": 0.3, "regression_field_act": 11.9, "regression_field_pred": 11.9, + "classification_field_act": "dog", + "classification_field_pred": "cat", "all_true_field": true, "all_false_field": false } @@ -71,6 +79,8 @@ setup: "outlier_score": 0.4, "regression_field_act": 42.9, "regression_field_pred": 42.9, + "classification_field_act": "cat", + "classification_field_pred": "dog", "all_true_field": true, "all_false_field": false } @@ -86,6 +96,8 @@ setup: "outlier_score": 0.5, "regression_field_act": 0.42, "regression_field_pred": 0.42, 
+ "classification_field_act": "dog", + "classification_field_pred": "dog", "all_true_field": true, "all_false_field": false } @@ -101,6 +113,8 @@ setup: "outlier_score": 0.9, "regression_field_act": 1.1235813, "regression_field_pred": 1.12358, + "classification_field_act": "cat", + "classification_field_pred": "cat", "all_true_field": true, "all_false_field": false } @@ -116,6 +130,8 @@ setup: "outlier_score": 0.95, "regression_field_act": -5.20, "regression_field_pred": -5.1, + "classification_field_act": "mouse", + "classification_field_pred": "cat", "all_true_field": true, "all_false_field": false } @@ -569,6 +585,108 @@ setup: } } } + +--- +"Test classification given evaluation with empty metrics": + - do: + catch: /\[classification\] must have one or more metrics/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword", + "metrics": { } + } + } + } +--- +"Test classification multiclass_confusion_matrix": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword", + "metrics": { "multiclass_confusion_matrix": {} } + } + } + } + + - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } + - match: { classification.multiclass_confusion_matrix.other_classes_count: 0 } +--- +"Test classification multiclass_confusion_matrix with explicit size": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword", + "metrics": { "multiclass_confusion_matrix": { "size": 2 } } + } + } + } + + 
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } } + - match: { classification.multiclass_confusion_matrix.other_classes_count: 1 } +--- +"Test classification with null metrics": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword" + } + } + } + + - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } + - match: { classification.multiclass_confusion_matrix.other_classes_count: 0 } +--- +"Test classification given missing actual_field": + - do: + catch: /No documents found containing both \[missing, classification_field_pred.keyword\] fields/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "missing", + "predicted_field": "classification_field_pred.keyword" + } + } + } + +--- +"Test classification given missing predicted_field": + - do: + catch: /No documents found containing both \[classification_field_act.keyword, missing\] fields/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "missing" + } + } + } + --- "Test regression given evaluation with empty metrics": - do: From d8472f0f07dc62398664c193607fce5d7ca6036b Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 27 Sep 2019 12:19:43 +0200 Subject: [PATCH 2/2] Apply review comments --- .../MulticlassConfusionMatrixMetric.java | 2 +- .../classification/MulticlassConfusionMatrix.java | 14 +++++++------- ...rameIT.java => ClassificationEvaluationIT.java} | 2 +- .../rest-api-spec/test/ml/evaluate_data_frame.yml | 6 +++--- 4 files changed, 12 insertions(+), 12 
deletions(-) rename x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/{EvaluateDataFrameIT.java => ClassificationEvaluationIT.java} (98%) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java index ba09d8bb202af..a8e8545009b25 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java @@ -97,7 +97,7 @@ public int hashCode() { public static class Result implements EvaluationMetric.Result { private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); - private static final ParseField OTHER_CLASSES_COUNT = new ParseField("other_classes_count"); + private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index f4e0c723da5d4..d9d47ab9aab20 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -103,13 +103,14 @@ public int getSize() { @Override public final List aggs(String actualField, String predictedField) { - if 
(topActualClassNames == null) { + if (topActualClassNames == null) { // This is step 1 return List.of( AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) .field(actualField) .order(List.of(BucketOrder.count(false), BucketOrder.key(true))) .size(size)); - } else if (result == null) { + } + if (result == null) { // This is step 2 KeyedFilter[] keyedFilters = topActualClassNames.stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) @@ -124,18 +125,17 @@ public final List aggs(String actualField, String predictedF .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters) .otherBucket(true) .otherBucketKey(OTHER_BUCKET_KEY))); - } else { - return List.of(); } + return List.of(); } @Override public void process(Aggregations aggs) { - if (aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null && topActualClassNames == null) { + if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) { Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS); topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList()); } - if (aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null && result == null) { + if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) { Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS); Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); Map> counts = new TreeMap<>(); @@ -190,7 +190,7 @@ public int hashCode() { public static class Result implements EvaluationMetricResult { private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); - private static final ParseField OTHER_CLASSES_COUNT = new ParseField("other_classes_count"); + private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( diff --git 
a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EvaluateDataFrameIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java similarity index 98% rename from x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EvaluateDataFrameIT.java rename to x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 5defc6740df99..2cfa98a28aaa9 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EvaluateDataFrameIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -20,7 +20,7 @@ import static org.hamcrest.Matchers.equalTo; -public class EvaluateDataFrameIT extends MlNativeDataFrameAnalyticsIntegTestCase { +public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 92e816c492d28..1bcde11f2fb74 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -619,7 +619,7 @@ setup: } - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } - - match: { classification.multiclass_confusion_matrix.other_classes_count: 0 } + - match: { classification.multiclass_confusion_matrix._other_: 0 } --- "Test classification multiclass_confusion_matrix with explicit size": - do: @@ 
-637,7 +637,7 @@ setup: } - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } } - - match: { classification.multiclass_confusion_matrix.other_classes_count: 1 } + - match: { classification.multiclass_confusion_matrix._other_: 1 } --- "Test classification with null metrics": - do: @@ -654,7 +654,7 @@ setup: } - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } - - match: { classification.multiclass_confusion_matrix.other_classes_count: 0 } + - match: { classification.multiclass_confusion_matrix._other_: 0 } --- "Test classification given missing actual_field": - do: