From 01e61852e6b0ea50c59ef843663da190c8673c31 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 17 Oct 2019 09:31:56 +0200 Subject: [PATCH 1/6] Change format of MulticlassConfusionMatrix result to be more self-explanatory. --- .../MulticlassConfusionMatrixMetric.java | 121 ++++++++++- .../client/MachineLearningIT.java | 23 +- .../MlClientDocumentationIT.java | 16 +- ...classConfusionMatrixMetricResultTests.java | 25 ++- .../MulticlassConfusionMatrix.java | 205 ++++++++++++++---- .../MulticlassConfusionMatrixResultTests.java | 31 +-- .../MulticlassConfusionMatrixTests.java | 38 ++-- .../ClassificationEvaluationIT.java | 110 ++++++++-- .../test/ml/evaluate_data_frame.yml | 14 +- 9 files changed, 463 insertions(+), 120 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java index a8e8545009b25..3bc0c3a7499c0 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java @@ -22,14 +22,14 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; import java.util.Collections; -import java.util.Map; +import java.util.List; import java.util.Objects; -import java.util.TreeMap; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static 
org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -102,13 +102,10 @@ public static class Result implements EvaluationMetric.Result { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "multiclass_confusion_matrix_result", true, a -> new Result((Map>) a[0], (long) a[1])); + "multiclass_confusion_matrix_result", true, a -> new Result((List) a[0], (long) a[1])); static { - PARSER.declareObject( - constructorArg(), - (p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)), - CONFUSION_MATRIX); + PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX); PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); } @@ -117,11 +114,11 @@ public static Result fromXContent(XContentParser parser) { } // Immutable - private final Map> confusionMatrix; + private final List confusionMatrix; private final long otherClassesCount; - public Result(Map> confusionMatrix, long otherClassesCount) { - this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix)); + public Result(List confusionMatrix, long otherClassesCount) { + this.confusionMatrix = Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)); this.otherClassesCount = otherClassesCount; } @@ -130,7 +127,7 @@ public String getMetricName() { return NAME; } - public Map> getConfusionMatrix() { + public List getConfusionMatrix() { return confusionMatrix; } @@ -161,4 +158,106 @@ public int hashCode() { return Objects.hash(confusionMatrix, otherClassesCount); } } + + public static class ActualClass implements ToXContentObject { + + private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); + private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes"); + private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + + @SuppressWarnings("unchecked") + private static 
final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "multiclass_confusion_matrix_actual_class", + true, + a -> new ActualClass((String) a[0], (List) a[1], (long) a[2])); + + static { + PARSER.declareString(constructorArg(), ACTUAL_CLASS); + PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); + PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); + } + + private final String actualClass; + private final List predictedClasses; + private final long otherClassesCount; + + public ActualClass(String actualClass, List predictedClasses, long otherClassesCount) { + this.actualClass = actualClass; + this.predictedClasses = Collections.unmodifiableList(predictedClasses); + this.otherClassesCount = otherClassesCount; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); + builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); + builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ActualClass that = (ActualClass) o; + return Objects.equals(this.actualClass, that.actualClass) + && Objects.equals(this.predictedClasses, that.predictedClasses) + && this.otherClassesCount == that.otherClassesCount; + } + + @Override + public int hashCode() { + return Objects.hash(actualClass, predictedClasses, otherClassesCount); + } + } + + public static class PredictedClass implements ToXContentObject { + + private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class"); + private static final ParseField COUNT = new ParseField("count"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new 
ConstructingObjectParser<>( + "multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (long) a[1])); + + static { + PARSER.declareString(constructorArg(), PREDICTED_CLASS); + PARSER.declareLong(constructorArg(), COUNT); + } + + private final String predictedClass; + private final Long count; + + public PredictedClass(String predictedClass, Long count) { + this.predictedClass = predictedClass; + this.count = count; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass); + builder.field(COUNT.getPreferredName(), count); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PredictedClass that = (PredictedClass) o; + return Objects.equals(this.predictedClass, that.predictedClass) + && Objects.equals(this.count, that.count); + } + + @Override + public int hashCode() { + return Objects.hash(predictedClass, count); + } + } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 7a96846dcdce8..47f6ad68a1594 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -127,6 +127,8 @@ import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; +import 
org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; @@ -1777,7 +1779,7 @@ public void testEvaluateDataFrame_Classification() throws IOException { .add(docForClassification(indexName, "dog", "dog")) .add(docForClassification(indexName, "dog", "dog")) .add(docForClassification(indexName, "dog", "dog")) - .add(docForClassification(indexName, "horse", "cat")); + .add(docForClassification(indexName, "ant", "cat")); highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); @@ -1800,10 +1802,14 @@ public void testEvaluateDataFrame_Classification() throws IOException { assertThat( mcmResult.getConfusionMatrix(), equalTo( - Map.of( - "cat", Map.of("cat", 3L, "dog", 1L, "horse", 0L, "_other_", 1L), - "dog", Map.of("cat", 1L, "dog", 3L, "horse", 0L), - "horse", Map.of("cat", 1L, "dog", 0L, "horse", 0L)))); + List.of( + new ActualClass( + "ant", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), 0), + new ActualClass( + "cat", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1), + new ActualClass( + "dog", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0) + ))); assertThat(mcmResult.getOtherClassesCount(), equalTo(0L)); } { // Explicit size provided for MulticlassConfusionMatrixMetric metric @@ -1824,9 +1830,10 @@ public void testEvaluateDataFrame_Classification() throws IOException { assertThat( mcmResult.getConfusionMatrix(), equalTo( - Map.of( - "cat", Map.of("cat", 3L, "dog", 1L, "_other_", 1L), - "dog", Map.of("cat", 1L, 
"dog", 3L)))); + List.of( + new ActualClass("cat", List.of(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1), + new ActualClass("dog", List.of(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0) + ))); assertThat(mcmResult.getOtherClassesCount(), equalTo(1L)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index cfc0d2a191942..d2f12e139878e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -142,6 +142,8 @@ import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -3355,7 +3357,7 @@ public void testEvaluateDataFrame_Classification() throws Exception { MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix = response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1> - Map> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2> + List confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2> long otherClassesCount = 
multiclassConfusionMatrix.getOtherClassesCount(); // <3> // end::evaluate-data-frame-results-classification @@ -3363,10 +3365,14 @@ public void testEvaluateDataFrame_Classification() throws Exception { assertThat( confusionMatrix, equalTo( - Map.of( - "cat", Map.of("cat", 3L, "dog", 1L, "ant", 0L, "_other_", 1L), - "dog", Map.of("cat", 1L, "dog", 3L, "ant", 0L), - "ant", Map.of("cat", 1L, "dog", 0L, "ant", 0L)))); + List.of( + new ActualClass( + "ant", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), 0), + new ActualClass( + "cat", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1), + new ActualClass( + "dog", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0) + ))); assertThat(otherClassesCount, equalTo(0L)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java index 800a2cf7b9836..4f39fbfbe4f42 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java @@ -19,19 +19,21 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; +import 
org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.Result; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; +import java.util.ArrayList; import java.util.List; -import java.util.Map; -import java.util.TreeMap; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; -public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase { +public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase { @Override protected NamedXContentRegistry xContentRegistry() { @@ -39,26 +41,25 @@ protected NamedXContentRegistry xContentRegistry() { } @Override - protected MulticlassConfusionMatrixMetric.Result createTestInstance() { + protected Result createTestInstance() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - Map> confusionMatrix = new TreeMap<>(); + List actualClasses = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { - Map row = new TreeMap<>(); - confusionMatrix.put(classNames.get(i), row); + List predictedClasses = new ArrayList<>(numClasses); for (int j = 0; j < numClasses; j++) { if (randomBoolean()) { - row.put(classNames.get(i), randomNonNegativeLong()); + predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong())); } } + actualClasses.add(new ActualClass(classNames.get(i), predictedClasses, randomNonNegativeLong())); } - long otherClassesCount = randomNonNegativeLong(); - return new MulticlassConfusionMatrixMetric.Result(confusionMatrix, otherClassesCount); + return new Result(actualClasses, randomNonNegativeLong()); } @Override - protected MulticlassConfusionMatrixMetric.Result doParseInstance(XContentParser parser) 
throws IOException { - return MulticlassConfusionMatrixMetric.Result.fromXContent(parser); + protected Result doParseInstance(XContentParser parser) throws IOException { + return Result.fromXContent(parser); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index d9d47ab9aab20..bc41e422cb1c5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -9,7 +9,9 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilders; @@ -25,14 +27,14 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.TreeMap; import java.util.stream.Collectors; +import static java.util.Comparator.comparing; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -111,18 +113,19 @@ public final List aggs(String actualField, String predictedF .size(size)); } 
if (result == null) { // This is step 2 - KeyedFilter[] keyedFilters = + KeyedFilter[] keyedFiltersActual = + topActualClassNames.stream() + .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className))) + .toArray(KeyedFilter[]::new); + KeyedFilter[] keyedFiltersPredicted = topActualClassNames.stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .toArray(KeyedFilter[]::new); return List.of( AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS) .field(actualField), - AggregationBuilders.terms(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) - .field(actualField) - .order(List.of(BucketOrder.count(false), BucketOrder.key(true))) - .size(size) - .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters) + AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual) + .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted) .otherBucket(true) .otherBucketKey(OTHER_BUCKET_KEY))); } @@ -133,26 +136,30 @@ public final List aggs(String actualField, String predictedF public void process(Aggregations aggs) { if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) { Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS); - topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList()); + topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()); } if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) { Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS); - Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); - Map> counts = new TreeMap<>(); - for (Terms.Bucket bucket : termsAgg.getBuckets()) { + Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); + List actualClasses = new 
ArrayList<>(filtersAgg.getBuckets().size()); + for (Filters.Bucket bucket : filtersAgg.getBuckets()) { String actualClass = bucket.getKeyAsString(); - Map subCounts = new TreeMap<>(); - counts.put(actualClass, subCounts); Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); + List predictedClasses = new ArrayList<>(); + long otherClassCount = 0; for (Filters.Bucket subBucket : subAgg.getBuckets()) { String predictedClass = subBucket.getKeyAsString(); - Long docCount = subBucket.getDocCount(); - if ((OTHER_BUCKET_KEY.equals(predictedClass) && docCount == 0L) == false) { - subCounts.put(predictedClass, docCount); + long docCount = subBucket.getDocCount(); + if (OTHER_BUCKET_KEY.equals(predictedClass)) { + otherClassCount = docCount; + } else { + predictedClasses.add(new PredictedClass(predictedClass, docCount)); } } + predictedClasses.sort(comparing(PredictedClass::getPredictedClass)); + actualClasses.add(new ActualClass(actualClass, predictedClasses, otherClassCount)); } - result = new Result(counts, termsAgg.getSumOfOtherDocCounts() == 0 ? 
0 : cardinalityAgg.getValue() - size); + result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)); } } @@ -192,15 +199,13 @@ public static class Result implements EvaluationMetricResult { private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "multiclass_confusion_matrix_result", true, a -> new Result((Map>) a[0], (long) a[1])); + "multiclass_confusion_matrix_result", true, a -> new Result((List) a[0], (long) a[1])); static { - PARSER.declareObject( - constructorArg(), - (p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)), - CONFUSION_MATRIX); + PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX); PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); } @@ -209,17 +214,16 @@ public static Result fromXContent(XContentParser parser) { } // Immutable - private final Map> confusionMatrix; + private final List actualClasses; private final long otherClassesCount; - public Result(Map> confusionMatrix, long otherClassesCount) { - this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix)); + public Result(List actualClasses, long otherClassesCount) { + this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses)); this.otherClassesCount = otherClassesCount; } public Result(StreamInput in) throws IOException { - this.confusionMatrix = Collections.unmodifiableMap( - in.readMap(StreamInput::readString, in2 -> in2.readMap(StreamInput::readString, StreamInput::readLong))); + this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new)); this.otherClassesCount = in.readLong(); } @@ -233,8 +237,8 @@ public String getMetricName() { return NAME.getPreferredName(); } - public Map> 
getConfusionMatrix() { - return confusionMatrix; + public List getConfusionMatrix() { + return actualClasses; } public long getOtherClassesCount() { @@ -243,17 +247,14 @@ public long getOtherClassesCount() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeMap( - confusionMatrix, - StreamOutput::writeString, - (out2, row) -> out2.writeMap(row, StreamOutput::writeString, StreamOutput::writeLong)); + out.writeList(actualClasses); out.writeLong(otherClassesCount); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); + builder.field(CONFUSION_MATRIX.getPreferredName(), actualClasses); builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); builder.endObject(); return builder; @@ -264,13 +265,143 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; - return Objects.equals(this.confusionMatrix, that.confusionMatrix) + return Objects.equals(this.actualClasses, that.actualClasses) + && this.otherClassesCount == that.otherClassesCount; + } + + @Override + public int hashCode() { + return Objects.hash(actualClasses, otherClassesCount); + } + } + + public static class ActualClass implements ToXContentObject, Writeable { + + private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); + private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes"); + private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "multiclass_confusion_matrix_actual_class", + true, + a -> new ActualClass((String) a[0], (List) a[1], (long) a[2])); + + static { + PARSER.declareString(constructorArg(), ACTUAL_CLASS); + 
PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); + PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); + } + + private final String actualClass; + private final List predictedClasses; + private final long otherClassesCount; + + public ActualClass(String actualClass, List predictedClasses, long otherClassesCount) { + this.actualClass = actualClass; + this.predictedClasses = Collections.unmodifiableList(predictedClasses); + this.otherClassesCount = otherClassesCount; + } + + public ActualClass(StreamInput in) throws IOException { + this.actualClass = in.readString(); + this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new)); + this.otherClassesCount = in.readLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(actualClass); + out.writeList(predictedClasses); + out.writeLong(otherClassesCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); + builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); + builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ActualClass that = (ActualClass) o; + return Objects.equals(this.actualClass, that.actualClass) + && Objects.equals(this.predictedClasses, that.predictedClasses) && this.otherClassesCount == that.otherClassesCount; } @Override public int hashCode() { - return Objects.hash(confusionMatrix, otherClassesCount); + return Objects.hash(actualClass, predictedClasses, otherClassesCount); + } + } + + public static class PredictedClass implements ToXContentObject, Writeable { + + private static final ParseField PREDICTED_CLASS = new 
ParseField("predicted_class"); + private static final ParseField COUNT = new ParseField("count"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (long) a[1])); + + static { + PARSER.declareString(constructorArg(), PREDICTED_CLASS); + PARSER.declareLong(constructorArg(), COUNT); + } + + private final String predictedClass; + private final long count; + + public PredictedClass(String predictedClass, long count) { + this.predictedClass = predictedClass; + this.count = count; + } + + public PredictedClass(StreamInput in) throws IOException { + this.predictedClass = in.readString(); + this.count = in.readLong(); + } + + public String getPredictedClass() { + return predictedClass; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(predictedClass); + out.writeLong(count); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass); + builder.field(COUNT.getPreferredName(), count); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PredictedClass that = (PredictedClass) o; + return Objects.equals(this.predictedClass, that.predictedClass) + && this.count == that.count; + } + + @Override + public int hashCode() { + return Objects.hash(predictedClass, count); } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java index 
24b13d372d528..adddbad9f2497 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java @@ -8,47 +8,48 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result; import java.io.IOException; +import java.util.ArrayList; import java.util.List; -import java.util.Map; -import java.util.TreeMap; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; -public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase { +public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase { - public static MulticlassConfusionMatrix.Result createRandom() { + public static Result createRandom() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - Map> confusionMatrix = new TreeMap<>(); + List actualClasses = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { - Map row = new TreeMap<>(); - confusionMatrix.put(classNames.get(i), row); + List predictedClasses = new ArrayList<>(numClasses); for (int j = 0; j < numClasses; j++) { if (randomBoolean()) { - row.put(classNames.get(i), randomNonNegativeLong()); + predictedClasses.add(new PredictedClass(classNames.get(j), 
randomNonNegativeLong())); } } + actualClasses.add(new ActualClass(classNames.get(i), predictedClasses, randomNonNegativeLong())); } - long otherClassesCount = randomNonNegativeLong(); - return new MulticlassConfusionMatrix.Result(confusionMatrix, otherClassesCount); + return new Result(actualClasses, randomNonNegativeLong()); } @Override - protected MulticlassConfusionMatrix.Result doParseInstance(XContentParser parser) throws IOException { - return MulticlassConfusionMatrix.Result.fromXContent(parser); + protected Result doParseInstance(XContentParser parser) throws IOException { + return Result.fromXContent(parser); } @Override - protected MulticlassConfusionMatrix.Result createTestInstance() { + protected Result createTestInstance() { return createRandom(); } @Override - protected Writeable.Reader instanceReader() { - return MulticlassConfusionMatrix.Result::new; + protected Writeable.Reader instanceReader() { + return Result::new; } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index a4e989bce898a..131a3e71147af 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -14,10 +14,11 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.metrics.Cardinality; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; +import 
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import java.io.IOException; import java.util.List; -import java.util.Map; import java.util.Optional; import static org.hamcrest.Matchers.empty; @@ -85,20 +86,19 @@ public void testEvaluate() { mockTermsBucket("dog", new Aggregations(List.of())), mockTermsBucket("cat", new Aggregations(List.of()))), 0L), - mockTerms( + mockFilters( "multiclass_confusion_matrix_step_2_by_actual_class", List.of( - mockTermsBucket( + mockFiltersBucket( "dog", new Aggregations(List.of(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), - mockTermsBucket( + mockFiltersBucket( "cat", new Aggregations(List.of(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", - List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))), - 0L), + List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L))); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); @@ -109,7 +109,10 @@ public void testEvaluate() { assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); assertThat( result.getConfusionMatrix(), - equalTo(Map.of("dog", Map.of("cat", 10L, "dog", 20L), "cat", Map.of("cat", 30L, "dog", 40L)))); + equalTo( + List.of( + new ActualClass("dog", List.of(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0), + new ActualClass("cat", List.of(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 0)))); assertThat(result.getOtherClassesCount(), equalTo(0L)); } @@ -121,20 +124,19 @@ public void testEvaluate_OtherClassesCountGreaterThanZero() { mockTermsBucket("dog", new Aggregations(List.of())), 
mockTermsBucket("cat", new Aggregations(List.of()))), 100L), - mockTerms( + mockFilters( "multiclass_confusion_matrix_step_2_by_actual_class", List.of( - mockTermsBucket( + mockFiltersBucket( "dog", new Aggregations(List.of(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), - mockTermsBucket( + mockFiltersBucket( "cat", new Aggregations(List.of(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", - List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))), - 100L), + List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))), mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L))); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); @@ -145,7 +147,10 @@ public void testEvaluate_OtherClassesCountGreaterThanZero() { assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); assertThat( result.getConfusionMatrix(), - equalTo(Map.of("dog", Map.of("cat", 10L, "dog", 20L), "cat", Map.of("cat", 30L, "dog", 40L, "_other_", 15L)))); + equalTo( + List.of( + new ActualClass("dog", List.of(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0), + new ActualClass("cat", List.of(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15)))); assertThat(result.getOtherClassesCount(), equalTo(3L)); } @@ -171,6 +176,13 @@ private static Filters mockFilters(String name, List buckets) { return aggregation; } + private static Filters.Bucket mockFiltersBucket(String actualClass, Aggregations subAggs) { + Filters.Bucket bucket = mock(Filters.Bucket.class); + when(bucket.getKeyAsString()).thenReturn(actualClass); + when(bucket.getAggregations()).thenReturn(subAggs); + return bucket; + } + private static Filters.Bucket mockFiltersBucket(String 
predictedClass, long docCount) { Filters.Bucket bucket = mock(Filters.Bucket.class); when(bucket.getKeyAsString()).thenReturn(predictedClass); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 2cfa98a28aaa9..a506433728394 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -12,6 +12,8 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import org.junit.After; import org.junit.Before; @@ -53,12 +55,47 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() { assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( confusionMatrixResult.getConfusionMatrix(), - equalTo(Map.of( - "ant", Map.of("ant", 1L, "cat", 4L, "dog", 3L, "fox", 2L, "mouse", 5L), - "cat", Map.of("ant", 3L, "cat", 1L, "dog", 5L, "fox", 4L, "mouse", 2L), - "dog", Map.of("ant", 4L, "cat", 2L, "dog", 1L, "fox", 5L, "mouse", 3L), - "fox", Map.of("ant", 5L, "cat", 3L, "dog", 2L, "fox", 1L, "mouse", 4L), - "mouse", Map.of("ant", 2L, "cat", 5L, "dog", 4L, "fox", 3L, "mouse", 1L)))); + equalTo(List.of( + new ActualClass("ant", + List.of( + 
new PredictedClass("ant", 1L), + new PredictedClass("cat", 4L), + new PredictedClass("dog", 3L), + new PredictedClass("fox", 2L), + new PredictedClass("mouse", 5L)), + 0), + new ActualClass("cat", + List.of( + new PredictedClass("ant", 3L), + new PredictedClass("cat", 1L), + new PredictedClass("dog", 5L), + new PredictedClass("fox", 4L), + new PredictedClass("mouse", 2L)), + 0), + new ActualClass("dog", + List.of( + new PredictedClass("ant", 4L), + new PredictedClass("cat", 2L), + new PredictedClass("dog", 1L), + new PredictedClass("fox", 5L), + new PredictedClass("mouse", 3L)), + 0), + new ActualClass("fox", + List.of( + new PredictedClass("ant", 5L), + new PredictedClass("cat", 3L), + new PredictedClass("dog", 2L), + new PredictedClass("fox", 1L), + new PredictedClass("mouse", 4L)), + 0), + new ActualClass("mouse", + List.of( + new PredictedClass("ant", 2L), + new PredictedClass("cat", 5L), + new PredictedClass("dog", 4L), + new PredictedClass("fox", 3L), + new PredictedClass("mouse", 1L)), + 0)))); assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); } @@ -78,12 +115,47 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefau assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( confusionMatrixResult.getConfusionMatrix(), - equalTo(Map.of( - "ant", Map.of("ant", 1L, "cat", 4L, "dog", 3L, "fox", 2L, "mouse", 5L), - "cat", Map.of("ant", 3L, "cat", 1L, "dog", 5L, "fox", 4L, "mouse", 2L), - "dog", Map.of("ant", 4L, "cat", 2L, "dog", 1L, "fox", 5L, "mouse", 3L), - "fox", Map.of("ant", 5L, "cat", 3L, "dog", 2L, "fox", 1L, "mouse", 4L), - "mouse", Map.of("ant", 2L, "cat", 5L, "dog", 4L, "fox", 3L, "mouse", 1L)))); + equalTo(List.of( + new ActualClass("ant", + List.of( + new PredictedClass("ant", 1L), + new PredictedClass("cat", 4L), + new PredictedClass("dog", 3L), + new PredictedClass("fox", 2L), + new PredictedClass("mouse", 5L)), + 0), + new 
ActualClass("cat", + List.of( + new PredictedClass("ant", 3L), + new PredictedClass("cat", 1L), + new PredictedClass("dog", 5L), + new PredictedClass("fox", 4L), + new PredictedClass("mouse", 2L)), + 0), + new ActualClass("dog", + List.of( + new PredictedClass("ant", 4L), + new PredictedClass("cat", 2L), + new PredictedClass("dog", 1L), + new PredictedClass("fox", 5L), + new PredictedClass("mouse", 3L)), + 0), + new ActualClass("fox", + List.of( + new PredictedClass("ant", 5L), + new PredictedClass("cat", 3L), + new PredictedClass("dog", 2L), + new PredictedClass("fox", 1L), + new PredictedClass("mouse", 4L)), + 0), + new ActualClass("mouse", + List.of( + new PredictedClass("ant", 2L), + new PredictedClass("cat", 5L), + new PredictedClass("dog", 4L), + new PredictedClass("fox", 3L), + new PredictedClass("mouse", 1L)), + 0)))); assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); } @@ -103,10 +175,16 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserP assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( confusionMatrixResult.getConfusionMatrix(), - equalTo(Map.of( - "ant", Map.of("ant", 1L, "cat", 4L, "dog", 3L, "_other_", 7L), - "cat", Map.of("ant", 3L, "cat", 1L, "dog", 5L, "_other_", 6L), - "dog", Map.of("ant", 4L, "cat", 2L, "dog", 1L, "_other_", 8L)))); + equalTo(List.of( + new ActualClass("ant", + List.of(new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), new PredictedClass("dog", 3L)), + 7), + new ActualClass("cat", + List.of(new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), new PredictedClass("dog", 5L)), + 6), + new ActualClass("dog", + List.of(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)), + 8)))); assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L)); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml 
b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 1bcde11f2fb74..6932f0ff63ef5 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -618,7 +618,10 @@ setup: } } - - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } + - match: { classification.multiclass_confusion_matrix.confusion_matrix: [ + { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 }, { predicted_class: "mouse", count: 0 } ], _other_: 0 }, + { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 }, { predicted_class: "mouse", count: 0 } ], _other_: 0 }, + { actual_class: "mouse", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 0 }, { predicted_class: "mouse", count: 1 } ], _other_: 0 }]} - match: { classification.multiclass_confusion_matrix._other_: 0 } --- "Test classification multiclass_confusion_matrix with explicit size": @@ -636,7 +639,9 @@ setup: } } - - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } } + - match: { classification.multiclass_confusion_matrix.confusion_matrix: [ + { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 } ], _other_: 0 }, + { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 } ], _other_: 0 }]} - match: { classification.multiclass_confusion_matrix._other_: 1 } --- "Test classification with null metrics": @@ -653,7 +658,10 @@ setup: } } - - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, 
dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } + - match: { classification.multiclass_confusion_matrix.confusion_matrix: [ + { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 }, { predicted_class: "mouse", count: 0 } ], _other_: 0 }, + { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 }, { predicted_class: "mouse", count: 0 } ], _other_: 0 }, + { actual_class: "mouse", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 0 }, { predicted_class: "mouse", count: 1 } ], _other_: 0 }]} - match: { classification.multiclass_confusion_matrix._other_: 0 } --- "Test classification given missing actual_field": From f150f5aa1d383b091195362f6ba5f0a26d30ece7 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 17 Oct 2019 11:59:31 +0200 Subject: [PATCH 2/6] Remove unused import --- .../xpack/ml/integration/ClassificationEvaluationIT.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index a506433728394..3e4b45d3a6fbb 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -18,7 +18,6 @@ import org.junit.Before; import java.util.List; -import java.util.Map; import static org.hamcrest.Matchers.equalTo; From 604b14e165b14dcf2bbc171b6306b24f98986e28 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 17 Oct 2019 12:15:09 +0200 Subject: [PATCH 3/6] Rename _other_ field --- 
.../MulticlassConfusionMatrixMetric.java | 36 ++++++------- .../client/MachineLearningIT.java | 4 +- .../MlClientDocumentationIT.java | 2 +- .../MulticlassConfusionMatrix.java | 50 +++++++++---------- .../MulticlassConfusionMatrixTests.java | 4 +- .../ClassificationEvaluationIT.java | 6 +-- .../test/ml/evaluate_data_frame.yml | 22 ++++---- 7 files changed, 62 insertions(+), 62 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java index 3bc0c3a7499c0..3da6bbae37932 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java @@ -97,7 +97,7 @@ public int hashCode() { public static class Result implements EvaluationMetric.Result { private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); - private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = @@ -106,7 +106,7 @@ public static class Result implements EvaluationMetric.Result { static { PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX); - PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); + PARSER.declareLong(constructorArg(), OTHER_ACTUAL_CLASS_COUNT); } public static Result fromXContent(XContentParser parser) { @@ -115,11 +115,11 @@ public static Result fromXContent(XContentParser parser) { // Immutable private final List confusionMatrix; - private final long 
otherClassesCount; + private final long otherActualClassCount; - public Result(List confusionMatrix, long otherClassesCount) { + public Result(List confusionMatrix, long otherActualClassCount) { this.confusionMatrix = Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)); - this.otherClassesCount = otherClassesCount; + this.otherActualClassCount = otherActualClassCount; } @Override @@ -131,15 +131,15 @@ public List getConfusionMatrix() { return confusionMatrix; } - public long getOtherClassesCount() { - return otherClassesCount; + public long getOtherActualClassCount() { + return otherActualClassCount; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); - builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount); builder.endObject(); return builder; } @@ -150,12 +150,12 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; return Objects.equals(this.confusionMatrix, that.confusionMatrix) - && this.otherClassesCount == that.otherClassesCount; + && this.otherActualClassCount == that.otherActualClassCount; } @Override public int hashCode() { - return Objects.hash(confusionMatrix, otherClassesCount); + return Objects.hash(confusionMatrix, otherActualClassCount); } } @@ -163,7 +163,7 @@ public static class ActualClass implements ToXContentObject { private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes"); - private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + private static final ParseField OTHER_PREDICTED_CLASS_COUNT = new ParseField("other_predicted_class_count"); @SuppressWarnings("unchecked") private 
static final ConstructingObjectParser PARSER = @@ -175,17 +175,17 @@ public static class ActualClass implements ToXContentObject { static { PARSER.declareString(constructorArg(), ACTUAL_CLASS); PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); - PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); + PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_COUNT); } private final String actualClass; private final List predictedClasses; - private final long otherClassesCount; + private final long otherPredictedClassCount; - public ActualClass(String actualClass, List predictedClasses, long otherClassesCount) { + public ActualClass(String actualClass, List predictedClasses, long otherPredictedClassCount) { this.actualClass = actualClass; this.predictedClasses = Collections.unmodifiableList(predictedClasses); - this.otherClassesCount = otherClassesCount; + this.otherPredictedClassCount = otherPredictedClassCount; } @Override @@ -193,7 +193,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); - builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + builder.field(OTHER_PREDICTED_CLASS_COUNT.getPreferredName(), otherPredictedClassCount); builder.endObject(); return builder; } @@ -205,12 +205,12 @@ public boolean equals(Object o) { ActualClass that = (ActualClass) o; return Objects.equals(this.actualClass, that.actualClass) && Objects.equals(this.predictedClasses, that.predictedClasses) - && this.otherClassesCount == that.otherClassesCount; + && this.otherPredictedClassCount == that.otherPredictedClassCount; } @Override public int hashCode() { - return Objects.hash(actualClass, predictedClasses, otherClassesCount); + return Objects.hash(actualClass, predictedClasses, otherPredictedClassCount); } } diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 47f6ad68a1594..3e7a9961e08aa 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1810,7 +1810,7 @@ public void testEvaluateDataFrame_Classification() throws IOException { new ActualClass( "dog", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0) ))); - assertThat(mcmResult.getOtherClassesCount(), equalTo(0L)); + assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L)); } { // Explicit size provided for MulticlassConfusionMatrixMetric metric EvaluateDataFrameRequest evaluateDataFrameRequest = @@ -1834,7 +1834,7 @@ public void testEvaluateDataFrame_Classification() throws IOException { new ActualClass("cat", List.of(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1), new ActualClass("dog", List.of(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0) ))); - assertThat(mcmResult.getOtherClassesCount(), equalTo(1L)); + assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index d2f12e139878e..60eab68faa8d6 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -3358,7 +3358,7 @@ public void testEvaluateDataFrame_Classification() throws Exception { response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1> List confusionMatrix = 
multiclassConfusionMatrix.getConfusionMatrix(); // <2> - long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3> + long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3> // end::evaluate-data-frame-results-classification assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index bc41e422cb1c5..d04c0e3412cb7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -146,18 +146,18 @@ public void process(Aggregations aggs) { String actualClass = bucket.getKeyAsString(); Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); List predictedClasses = new ArrayList<>(); - long otherClassCount = 0; + long otherPredictedClassCount = 0; for (Filters.Bucket subBucket : subAgg.getBuckets()) { String predictedClass = subBucket.getKeyAsString(); long docCount = subBucket.getDocCount(); if (OTHER_BUCKET_KEY.equals(predictedClass)) { - otherClassCount = docCount; + otherPredictedClassCount = docCount; } else { predictedClasses.add(new PredictedClass(predictedClass, docCount)); } } predictedClasses.sort(comparing(PredictedClass::getPredictedClass)); - actualClasses.add(new ActualClass(actualClass, predictedClasses, otherClassCount)); + actualClasses.add(new ActualClass(actualClass, predictedClasses, otherPredictedClassCount)); } result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)); } @@ -197,7 +197,7 @@ public int 
hashCode() { public static class Result implements EvaluationMetricResult { private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); - private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = @@ -206,7 +206,7 @@ public static class Result implements EvaluationMetricResult { static { PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX); - PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); + PARSER.declareLong(constructorArg(), OTHER_ACTUAL_CLASS_COUNT); } public static Result fromXContent(XContentParser parser) { @@ -215,16 +215,16 @@ public static Result fromXContent(XContentParser parser) { // Immutable private final List actualClasses; - private final long otherClassesCount; + private final long otherActualClassCount; - public Result(List actualClasses, long otherClassesCount) { + public Result(List actualClasses, long otherActualClassCount) { this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses)); - this.otherClassesCount = otherClassesCount; + this.otherActualClassCount = otherActualClassCount; } public Result(StreamInput in) throws IOException { this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new)); - this.otherClassesCount = in.readLong(); + this.otherActualClassCount = in.readLong(); } @Override @@ -241,21 +241,21 @@ public List getConfusionMatrix() { return actualClasses; } - public long getOtherClassesCount() { - return otherClassesCount; + public long getOtherActualClassCount() { + return otherActualClassCount; } @Override public void writeTo(StreamOutput out) throws IOException { out.writeList(actualClasses); - out.writeLong(otherClassesCount); + out.writeLong(otherActualClassCount); } @Override public XContentBuilder 
toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(CONFUSION_MATRIX.getPreferredName(), actualClasses); - builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount); builder.endObject(); return builder; } @@ -266,12 +266,12 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; return Objects.equals(this.actualClasses, that.actualClasses) - && this.otherClassesCount == that.otherClassesCount; + && this.otherActualClassCount == that.otherActualClassCount; } @Override public int hashCode() { - return Objects.hash(actualClasses, otherClassesCount); + return Objects.hash(actualClasses, otherActualClassCount); } } @@ -279,7 +279,7 @@ public static class ActualClass implements ToXContentObject, Writeable { private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes"); - private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + private static final ParseField OTHER_PREDICTED_CLASS_COUNT = new ParseField("other_predicted_class_count"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = @@ -291,30 +291,30 @@ public static class ActualClass implements ToXContentObject, Writeable { static { PARSER.declareString(constructorArg(), ACTUAL_CLASS); PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); - PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); + PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_COUNT); } private final String actualClass; private final List predictedClasses; - private final long otherClassesCount; + private final long otherPredictedClassCount; - public ActualClass(String actualClass, List predictedClasses, long 
otherClassesCount) { + public ActualClass(String actualClass, List predictedClasses, long otherPredictedClassCount) { this.actualClass = actualClass; this.predictedClasses = Collections.unmodifiableList(predictedClasses); - this.otherClassesCount = otherClassesCount; + this.otherPredictedClassCount = otherPredictedClassCount; } public ActualClass(StreamInput in) throws IOException { this.actualClass = in.readString(); this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new)); - this.otherClassesCount = in.readLong(); + this.otherPredictedClassCount = in.readLong(); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(actualClass); out.writeList(predictedClasses); - out.writeLong(otherClassesCount); + out.writeLong(otherPredictedClassCount); } @Override @@ -322,7 +322,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); - builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + builder.field(OTHER_PREDICTED_CLASS_COUNT.getPreferredName(), otherPredictedClassCount); builder.endObject(); return builder; } @@ -334,12 +334,12 @@ public boolean equals(Object o) { ActualClass that = (ActualClass) o; return Objects.equals(this.actualClass, that.actualClass) && Objects.equals(this.predictedClasses, that.predictedClasses) - && this.otherClassesCount == that.otherClassesCount; + && this.otherPredictedClassCount == that.otherPredictedClassCount; } @Override public int hashCode() { - return Objects.hash(actualClass, predictedClasses, otherClassesCount); + return Objects.hash(actualClass, predictedClasses, otherPredictedClassCount); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index 131a3e71147af..1e7012e3ecdce 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -113,7 +113,7 @@ public void testEvaluate() { List.of( new ActualClass("dog", List.of(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0), new ActualClass("cat", List.of(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 0)))); - assertThat(result.getOtherClassesCount(), equalTo(0L)); + assertThat(result.getOtherActualClassCount(), equalTo(0L)); } public void testEvaluate_OtherClassesCountGreaterThanZero() { @@ -151,7 +151,7 @@ public void testEvaluate_OtherClassesCountGreaterThanZero() { List.of( new ActualClass("dog", List.of(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0), new ActualClass("cat", List.of(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15)))); - assertThat(result.getOtherClassesCount(), equalTo(3L)); + assertThat(result.getOtherActualClassCount(), equalTo(3L)); } private static Terms mockTerms(String name, List buckets, long sumOfOtherDocCounts) { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 3e4b45d3a6fbb..1ed11ab59ef00 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java 
@@ -95,7 +95,7 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() { new PredictedClass("fox", 3L), new PredictedClass("mouse", 1L)), 0)))); - assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); + assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() { @@ -155,7 +155,7 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefau new PredictedClass("fox", 3L), new PredictedClass("mouse", 1L)), 0)))); - assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); + assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() { @@ -184,7 +184,7 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserP new ActualClass("dog", List.of(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)), 8)))); - assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L)); + assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L)); } private static void indexAnimalsData(String indexName) { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 6932f0ff63ef5..f5f05adbdf129 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -619,10 +619,10 @@ setup: } - match: { classification.multiclass_confusion_matrix.confusion_matrix: [ - { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 }, { predicted_class: "mouse", count: 0 } ], _other_: 0 }, - { actual_class: "dog", predicted_classes: [ { predicted_class: 
"cat", count: 1 }, { predicted_class: "dog", count: 2 }, { predicted_class: "mouse", count: 0 } ], _other_: 0 }, - { actual_class: "mouse", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 0 }, { predicted_class: "mouse", count: 1 } ], _other_: 0 }]} - - match: { classification.multiclass_confusion_matrix._other_: 0 } + { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 }, { predicted_class: "mouse", count: 0 } ], other_predicted_class_count: 0 }, + { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 }, { predicted_class: "mouse", count: 0 } ], other_predicted_class_count: 0 }, + { actual_class: "mouse", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 0 }, { predicted_class: "mouse", count: 1 } ], other_predicted_class_count: 0 }]} + - match: { classification.multiclass_confusion_matrix.other_actual_class_count: 0 } --- "Test classification multiclass_confusion_matrix with explicit size": - do: @@ -640,9 +640,9 @@ setup: } - match: { classification.multiclass_confusion_matrix.confusion_matrix: [ - { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 } ], _other_: 0 }, - { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 } ], _other_: 0 }]} - - match: { classification.multiclass_confusion_matrix._other_: 1 } + { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 } ], other_predicted_class_count: 0 }, + { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 } ], other_predicted_class_count: 0 }]} + - match: { classification.multiclass_confusion_matrix.other_actual_class_count: 1 } --- "Test 
classification with null metrics": - do: @@ -659,10 +659,10 @@ setup: } - match: { classification.multiclass_confusion_matrix.confusion_matrix: [ - { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 }, { predicted_class: "mouse", count: 0 } ], _other_: 0 }, - { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 }, { predicted_class: "mouse", count: 0 } ], _other_: 0 }, - { actual_class: "mouse", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 0 }, { predicted_class: "mouse", count: 1 } ], _other_: 0 }]} - - match: { classification.multiclass_confusion_matrix._other_: 0 } + { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 }, { predicted_class: "mouse", count: 0 } ], other_predicted_class_count: 0 }, + { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 }, { predicted_class: "mouse", count: 0 } ], other_predicted_class_count: 0 }, + { actual_class: "mouse", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 0 }, { predicted_class: "mouse", count: 1 } ], other_predicted_class_count: 0 }]} + - match: { classification.multiclass_confusion_matrix.other_actual_class_count: 0 } --- "Test classification given missing actual_field": - do: From c761f63df2382571e60a4a85944cc3ee61426c4a Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 18 Oct 2019 08:52:34 +0200 Subject: [PATCH 4/6] Add actual_class_doc_count field --- .../MulticlassConfusionMatrixMetric.java | 13 +++- .../client/MachineLearningIT.java | 20 ++++-- .../MlClientDocumentationIT.java | 16 +++-- ...classConfusionMatrixMetricResultTests.java | 2 +- .../MulticlassConfusionMatrix.java | 18 +++-- .../MulticlassConfusionMatrixResultTests.java | 2 +- 
.../MulticlassConfusionMatrixTests.java | 25 ++++--- .../ClassificationEvaluationIT.java | 13 ++++ .../test/ml/evaluate_data_frame.yml | 69 +++++++++++++++---- 9 files changed, 134 insertions(+), 44 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java index 3da6bbae37932..292cf0af433ad 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java @@ -162,6 +162,7 @@ public int hashCode() { public static class ActualClass implements ToXContentObject { private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); + private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes"); private static final ParseField OTHER_PREDICTED_CLASS_COUNT = new ParseField("other_predicted_class_count"); @@ -170,20 +171,24 @@ public static class ActualClass implements ToXContentObject { new ConstructingObjectParser<>( "multiclass_confusion_matrix_actual_class", true, - a -> new ActualClass((String) a[0], (List) a[1], (long) a[2])); + a -> new ActualClass((String) a[0], (long) a[1], (List) a[2], (long) a[3])); static { PARSER.declareString(constructorArg(), ACTUAL_CLASS); + PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_COUNT); } private final String actualClass; + private final long actualClassDocCount; private 
final List predictedClasses; private final long otherPredictedClassCount; - public ActualClass(String actualClass, List predictedClasses, long otherPredictedClassCount) { + public ActualClass( + String actualClass, long actualClassDocCount, List predictedClasses, long otherPredictedClassCount) { this.actualClass = actualClass; + this.actualClassDocCount = actualClassDocCount; this.predictedClasses = Collections.unmodifiableList(predictedClasses); this.otherPredictedClassCount = otherPredictedClassCount; } @@ -192,6 +197,7 @@ public ActualClass(String actualClass, List predictedClasses, lo public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); + builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); builder.field(OTHER_PREDICTED_CLASS_COUNT.getPreferredName(), otherPredictedClassCount); builder.endObject(); @@ -204,13 +210,14 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; ActualClass that = (ActualClass) o; return Objects.equals(this.actualClass, that.actualClass) + && this.actualClassDocCount == that.actualClassDocCount && Objects.equals(this.predictedClasses, that.predictedClasses) && this.otherPredictedClassCount == that.otherPredictedClassCount; } @Override public int hashCode() { - return Objects.hash(actualClass, predictedClasses, otherPredictedClassCount); + return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassCount); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 3e7a9961e08aa..26987ca7d161e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1804,12 +1804,20 @@ public void testEvaluateDataFrame_Classification() throws IOException { equalTo( List.of( new ActualClass( - "ant", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), 0), + "ant", + 1, + List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), + 0), new ActualClass( - "cat", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1), + "cat", + 5, + List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), + 1), new ActualClass( - "dog", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0) - ))); + "dog", + 4, + List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), + 0)))); assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L)); } { // Explicit size provided for MulticlassConfusionMatrixMetric metric @@ -1831,8 +1839,8 @@ public void testEvaluateDataFrame_Classification() throws IOException { mcmResult.getConfusionMatrix(), equalTo( List.of( - new ActualClass("cat", List.of(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1), - new ActualClass("dog", List.of(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0) + new ActualClass("cat", 5, List.of(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1), + new ActualClass("dog", 4, List.of(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0) ))); assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L)); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 60eab68faa8d6..c06b537d08927 100644 
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -3367,12 +3367,20 @@ public void testEvaluateDataFrame_Classification() throws Exception { equalTo( List.of( new ActualClass( - "ant", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), 0), + "ant", + 1, + List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), + 0), new ActualClass( - "cat", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1), + "cat", + 4, + List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), + 1), new ActualClass( - "dog", List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0) - ))); + "dog", + 4, + List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), + 0)))); assertThat(otherClassesCount, equalTo(0L)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java index 4f39fbfbe4f42..afc71cdd13976 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java @@ -52,7 +52,7 @@ protected Result createTestInstance() { predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong())); } } - actualClasses.add(new 
ActualClass(classNames.get(i), predictedClasses, randomNonNegativeLong())); + actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong())); } return new Result(actualClasses, randomNonNegativeLong()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index d04c0e3412cb7..3c69d71c2e7ad 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -144,6 +144,7 @@ public void process(Aggregations aggs) { List actualClasses = new ArrayList<>(filtersAgg.getBuckets().size()); for (Filters.Bucket bucket : filtersAgg.getBuckets()) { String actualClass = bucket.getKeyAsString(); + long actualClassDocCount = bucket.getDocCount(); Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); List predictedClasses = new ArrayList<>(); long otherPredictedClassCount = 0; @@ -157,7 +158,7 @@ public void process(Aggregations aggs) { } } predictedClasses.sort(comparing(PredictedClass::getPredictedClass)); - actualClasses.add(new ActualClass(actualClass, predictedClasses, otherPredictedClassCount)); + actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassCount)); } result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)); } @@ -278,6 +279,7 @@ public int hashCode() { public static class ActualClass implements ToXContentObject, Writeable { private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); + private static final ParseField ACTUAL_CLASS_DOC_COUNT = new 
ParseField("actual_class_doc_count"); private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes"); private static final ParseField OTHER_PREDICTED_CLASS_COUNT = new ParseField("other_predicted_class_count"); @@ -286,26 +288,31 @@ public static class ActualClass implements ToXContentObject, Writeable { new ConstructingObjectParser<>( "multiclass_confusion_matrix_actual_class", true, - a -> new ActualClass((String) a[0], (List) a[1], (long) a[2])); + a -> new ActualClass((String) a[0], (long) a[1], (List) a[2], (long) a[3])); static { PARSER.declareString(constructorArg(), ACTUAL_CLASS); + PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_COUNT); } private final String actualClass; + private final long actualClassDocCount; private final List predictedClasses; private final long otherPredictedClassCount; - public ActualClass(String actualClass, List predictedClasses, long otherPredictedClassCount) { + public ActualClass( + String actualClass, long actualClassDocCount, List predictedClasses, long otherPredictedClassCount) { this.actualClass = actualClass; + this.actualClassDocCount = actualClassDocCount; this.predictedClasses = Collections.unmodifiableList(predictedClasses); this.otherPredictedClassCount = otherPredictedClassCount; } public ActualClass(StreamInput in) throws IOException { this.actualClass = in.readString(); + this.actualClassDocCount = in.readLong(); this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new)); this.otherPredictedClassCount = in.readLong(); } @@ -313,6 +320,7 @@ public ActualClass(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(actualClass); + out.writeLong(actualClassDocCount); out.writeList(predictedClasses); out.writeLong(otherPredictedClassCount); } 
@@ -321,6 +329,7 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); + builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); builder.field(OTHER_PREDICTED_CLASS_COUNT.getPreferredName(), otherPredictedClassCount); builder.endObject(); @@ -333,13 +342,14 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; ActualClass that = (ActualClass) o; return Objects.equals(this.actualClass, that.actualClass) + && this.actualClassDocCount == that.actualClassDocCount && Objects.equals(this.predictedClasses, that.predictedClasses) && this.otherPredictedClassCount == that.otherPredictedClassCount; } @Override public int hashCode() { - return Objects.hash(actualClass, predictedClasses, otherPredictedClassCount); + return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassCount); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java index adddbad9f2497..478ce98473fbc 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java @@ -32,7 +32,7 @@ public static Result createRandom() { predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong())); } } - actualClasses.add(new ActualClass(classNames.get(i), predictedClasses, 
randomNonNegativeLong())); + actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong())); } return new Result(actualClasses, randomNonNegativeLong()); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index 1e7012e3ecdce..0b4f724549e1a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -91,11 +91,13 @@ public void testEvaluate() { List.of( mockFiltersBucket( "dog", + 30, new Aggregations(List.of(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket( "cat", + 70, new Aggregations(List.of(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), @@ -111,8 +113,8 @@ public void testEvaluate() { result.getConfusionMatrix(), equalTo( List.of( - new ActualClass("dog", List.of(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0), - new ActualClass("cat", List.of(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 0)))); + new ActualClass("dog", 30, List.of(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0), + new ActualClass("cat", 70, List.of(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 0)))); assertThat(result.getOtherActualClassCount(), equalTo(0L)); } @@ -129,11 +131,13 @@ public void 
testEvaluate_OtherClassesCountGreaterThanZero() { List.of( mockFiltersBucket( "dog", + 30, new Aggregations(List.of(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket( "cat", + 85, new Aggregations(List.of(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))), @@ -149,8 +153,8 @@ public void testEvaluate_OtherClassesCountGreaterThanZero() { result.getConfusionMatrix(), equalTo( List.of( - new ActualClass("dog", List.of(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0), - new ActualClass("cat", List.of(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15)))); + new ActualClass("dog", 30, List.of(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0), + new ActualClass("cat", 85, List.of(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15)))); assertThat(result.getOtherActualClassCount(), equalTo(3L)); } @@ -162,9 +166,9 @@ private static Terms mockTerms(String name, List buckets, long sum return aggregation; } - private static Terms.Bucket mockTermsBucket(String actualClass, Aggregations subAggs) { + private static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) { Terms.Bucket bucket = mock(Terms.Bucket.class); - when(bucket.getKeyAsString()).thenReturn(actualClass); + when(bucket.getKeyAsString()).thenReturn(key); when(bucket.getAggregations()).thenReturn(subAggs); return bucket; } @@ -176,16 +180,15 @@ private static Filters mockFilters(String name, List buckets) { return aggregation; } - private static Filters.Bucket mockFiltersBucket(String actualClass, Aggregations subAggs) { - Filters.Bucket bucket = mock(Filters.Bucket.class); - when(bucket.getKeyAsString()).thenReturn(actualClass); + private static Filters.Bucket 
mockFiltersBucket(String key, long docCount, Aggregations subAggs) { + Filters.Bucket bucket = mockFiltersBucket(key, docCount); when(bucket.getAggregations()).thenReturn(subAggs); return bucket; } - private static Filters.Bucket mockFiltersBucket(String predictedClass, long docCount) { + private static Filters.Bucket mockFiltersBucket(String key, long docCount) { Filters.Bucket bucket = mock(Filters.Bucket.class); - when(bucket.getKeyAsString()).thenReturn(predictedClass); + when(bucket.getKeyAsString()).thenReturn(key); when(bucket.getDocCount()).thenReturn(docCount); return bucket; } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 1ed11ab59ef00..196ca87fb1213 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -56,6 +56,7 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() { confusionMatrixResult.getConfusionMatrix(), equalTo(List.of( new ActualClass("ant", + 15, List.of( new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), @@ -64,6 +65,7 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() { new PredictedClass("mouse", 5L)), 0), new ActualClass("cat", + 15, List.of( new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), @@ -72,6 +74,7 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() { new PredictedClass("mouse", 2L)), 0), new ActualClass("dog", + 15, List.of( new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), @@ -80,6 +83,7 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() { new PredictedClass("mouse", 3L)), 
0), new ActualClass("fox", + 15, List.of( new PredictedClass("ant", 5L), new PredictedClass("cat", 3L), @@ -88,6 +92,7 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() { new PredictedClass("mouse", 4L)), 0), new ActualClass("mouse", + 15, List.of( new PredictedClass("ant", 2L), new PredictedClass("cat", 5L), @@ -116,6 +121,7 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefau confusionMatrixResult.getConfusionMatrix(), equalTo(List.of( new ActualClass("ant", + 15, List.of( new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), @@ -124,6 +130,7 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefau new PredictedClass("mouse", 5L)), 0), new ActualClass("cat", + 15, List.of( new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), @@ -132,6 +139,7 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefau new PredictedClass("mouse", 2L)), 0), new ActualClass("dog", + 15, List.of( new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), @@ -140,6 +148,7 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefau new PredictedClass("mouse", 3L)), 0), new ActualClass("fox", + 15, List.of( new PredictedClass("ant", 5L), new PredictedClass("cat", 3L), @@ -148,6 +157,7 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefau new PredictedClass("mouse", 4L)), 0), new ActualClass("mouse", + 15, List.of( new PredictedClass("ant", 2L), new PredictedClass("cat", 5L), @@ -176,12 +186,15 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserP confusionMatrixResult.getConfusionMatrix(), equalTo(List.of( new ActualClass("ant", + 15, List.of(new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), new PredictedClass("dog", 3L)), 7), new ActualClass("cat", + 15, List.of(new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), new PredictedClass("dog", 5L)), 6), new 
ActualClass("dog", + 15, List.of(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)), 8)))); assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L)); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index f5f05adbdf129..6e03108c7560b 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -618,11 +618,40 @@ setup: } } - - match: { classification.multiclass_confusion_matrix.confusion_matrix: [ - { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 }, { predicted_class: "mouse", count: 0 } ], other_predicted_class_count: 0 }, - { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 }, { predicted_class: "mouse", count: 0 } ], other_predicted_class_count: 0 }, - { actual_class: "mouse", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 0 }, { predicted_class: "mouse", count: 1 } ], other_predicted_class_count: 0 }]} - - match: { classification.multiclass_confusion_matrix.other_actual_class_count: 0 } + - match: + classification.multiclass_confusion_matrix: + confusion_matrix: + - actual_class: "cat" + actual_class_doc_count: 3 + predicted_classes: + - predicted_class: "cat" + count: 2 + - predicted_class: "dog" + count: 1 + - predicted_class: "mouse" + count: 0 + other_predicted_class_count: 0 + - actual_class: "dog" + actual_class_doc_count: 3 + predicted_classes: + - predicted_class: "cat" + count: 1 + - predicted_class: "dog" + count: 2 + - predicted_class: "mouse" + count: 0 + other_predicted_class_count: 0 + - actual_class: "mouse" + actual_class_doc_count: 2 + predicted_classes: + - 
predicted_class: "cat" + count: 1 + - predicted_class: "dog" + count: 0 + - predicted_class: "mouse" + count: 1 + other_predicted_class_count: 0 + other_actual_class_count: 0 --- "Test classification multiclass_confusion_matrix with explicit size": - do: @@ -639,10 +668,26 @@ setup: } } - - match: { classification.multiclass_confusion_matrix.confusion_matrix: [ - { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 } ], other_predicted_class_count: 0 }, - { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 } ], other_predicted_class_count: 0 }]} - - match: { classification.multiclass_confusion_matrix.other_actual_class_count: 1 } + - match: + classification.multiclass_confusion_matrix: + confusion_matrix: + - actual_class: "cat" + actual_class_doc_count: 3 + predicted_classes: + - predicted_class: "cat" + count: 2 + - predicted_class: "dog" + count: 1 + other_predicted_class_count: 0 + - actual_class: "dog" + actual_class_doc_count: 3 + predicted_classes: + - predicted_class: "cat" + count: 1 + - predicted_class: "dog" + count: 2 + other_predicted_class_count: 0 + other_actual_class_count: 1 --- "Test classification with null metrics": - do: @@ -658,11 +703,7 @@ setup: } } - - match: { classification.multiclass_confusion_matrix.confusion_matrix: [ - { actual_class: "cat", predicted_classes: [ { predicted_class: "cat", count: 2 }, { predicted_class: "dog", count: 1 }, { predicted_class: "mouse", count: 0 } ], other_predicted_class_count: 0 }, - { actual_class: "dog", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 2 }, { predicted_class: "mouse", count: 0 } ], other_predicted_class_count: 0 }, - { actual_class: "mouse", predicted_classes: [ { predicted_class: "cat", count: 1 }, { predicted_class: "dog", count: 0 }, { predicted_class: "mouse", count: 1 } ], other_predicted_class_count: 0 }]} - - 
match: { classification.multiclass_confusion_matrix.other_actual_class_count: 0 } + - is_true: classification.multiclass_confusion_matrix --- "Test classification given missing actual_field": - do: From 99a8c3222adc4de56ccf68f8009d1351a7bdc902 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 18 Oct 2019 10:42:14 +0200 Subject: [PATCH 5/6] Rename other_predicted_class_count to other_predicted_class_doc_count --- .../MulticlassConfusionMatrixMetric.java | 22 ++++++++----- .../MlClientDocumentationIT.java | 2 +- ...classConfusionMatrixMetricResultTests.java | 4 +-- .../MulticlassConfusionMatrix.java | 33 +++++++++++-------- .../MulticlassConfusionMatrixResultTests.java | 4 +-- .../test/ml/evaluate_data_frame.yml | 10 +++--- 6 files changed, 41 insertions(+), 34 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java index 292cf0af433ad..b8c525430f2ba 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java @@ -21,6 +21,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -164,7 +165,7 @@ public static class ActualClass implements ToXContentObject { private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); private static final 
ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes"); - private static final ParseField OTHER_PREDICTED_CLASS_COUNT = new ParseField("other_predicted_class_count"); + private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = @@ -177,20 +178,20 @@ public static class ActualClass implements ToXContentObject { PARSER.declareString(constructorArg(), ACTUAL_CLASS); PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); - PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_COUNT); + PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT); } private final String actualClass; private final long actualClassDocCount; private final List predictedClasses; - private final long otherPredictedClassCount; + private final long otherPredictedClassDocCount; public ActualClass( - String actualClass, long actualClassDocCount, List predictedClasses, long otherPredictedClassCount) { + String actualClass, long actualClassDocCount, List predictedClasses, long otherPredictedClassDocCount) { this.actualClass = actualClass; this.actualClassDocCount = actualClassDocCount; this.predictedClasses = Collections.unmodifiableList(predictedClasses); - this.otherPredictedClassCount = otherPredictedClassCount; + this.otherPredictedClassDocCount = otherPredictedClassDocCount; } @Override @@ -199,7 +200,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); - 
builder.field(OTHER_PREDICTED_CLASS_COUNT.getPreferredName(), otherPredictedClassCount); + builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount); builder.endObject(); return builder; } @@ -212,12 +213,17 @@ public boolean equals(Object o) { return Objects.equals(this.actualClass, that.actualClass) && this.actualClassDocCount == that.actualClassDocCount && Objects.equals(this.predictedClasses, that.predictedClasses) - && this.otherPredictedClassCount == that.otherPredictedClassCount; + && this.otherPredictedClassDocCount == that.otherPredictedClassDocCount; } @Override public int hashCode() { - return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassCount); + return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount); + } + + @Override + public String toString() { + return Strings.toString(this); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index c06b537d08927..45d08fdeef92a 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -3373,7 +3373,7 @@ public void testEvaluateDataFrame_Classification() throws Exception { 0), new ActualClass( "cat", - 4, + 5, List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1), new ActualClass( diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java index 
afc71cdd13976..25d145ecd7cf9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java @@ -48,9 +48,7 @@ protected Result createTestInstance() { for (int i = 0; i < numClasses; i++) { List predictedClasses = new ArrayList<>(numClasses); for (int j = 0; j < numClasses; j++) { - if (randomBoolean()) { - predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong())); - } + predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong())); } actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong())); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 3c69d71c2e7ad..f1a310967d7e3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -147,18 +147,18 @@ public void process(Aggregations aggs) { long actualClassDocCount = bucket.getDocCount(); Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); List predictedClasses = new ArrayList<>(); - long otherPredictedClassCount = 0; + long otherPredictedClassDocCount = 0; for (Filters.Bucket subBucket : subAgg.getBuckets()) { String predictedClass = subBucket.getKeyAsString(); long docCount = subBucket.getDocCount(); if (OTHER_BUCKET_KEY.equals(predictedClass)) { - 
otherPredictedClassCount = docCount; + otherPredictedClassDocCount = docCount; } else { predictedClasses.add(new PredictedClass(predictedClass, docCount)); } } predictedClasses.sort(comparing(PredictedClass::getPredictedClass)); - actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassCount)); + actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount)); } result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)); } @@ -214,8 +214,9 @@ public static Result fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - // Immutable + /** List of actual classes. */ private final List actualClasses; + /** Number of actual classes that were not included in the confusion matrix because there were too many of them. */ private final long otherActualClassCount; public Result(List actualClasses, long otherActualClassCount) { @@ -281,7 +282,7 @@ public static class ActualClass implements ToXContentObject, Writeable { private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes"); - private static final ParseField OTHER_PREDICTED_CLASS_COUNT = new ParseField("other_predicted_class_count"); + private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = @@ -294,27 +295,31 @@ public static class ActualClass implements ToXContentObject, Writeable { PARSER.declareString(constructorArg(), ACTUAL_CLASS); PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); - PARSER.declareLong(constructorArg(), 
OTHER_PREDICTED_CLASS_COUNT); + PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT); } + /** Name of the actual class. */ private final String actualClass; + /** Number of documents (examples) belonging to the {@code actualClass} class. */ private final long actualClassDocCount; + /** List of predicted classes. */ private final List predictedClasses; - private final long otherPredictedClassCount; + /** Number of documents that were not predicted as any of the {@code predictedClasses}. */ + private final long otherPredictedClassDocCount; public ActualClass( - String actualClass, long actualClassDocCount, List predictedClasses, long otherPredictedClassCount) { + String actualClass, long actualClassDocCount, List predictedClasses, long otherPredictedClassDocCount) { this.actualClass = actualClass; this.actualClassDocCount = actualClassDocCount; this.predictedClasses = Collections.unmodifiableList(predictedClasses); - this.otherPredictedClassCount = otherPredictedClassCount; + this.otherPredictedClassDocCount = otherPredictedClassDocCount; } public ActualClass(StreamInput in) throws IOException { this.actualClass = in.readString(); this.actualClassDocCount = in.readLong(); this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new)); - this.otherPredictedClassCount = in.readLong(); + this.otherPredictedClassDocCount = in.readLong(); } @Override @@ -322,7 +327,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(actualClass); out.writeLong(actualClassDocCount); out.writeList(predictedClasses); - out.writeLong(otherPredictedClassCount); + out.writeLong(otherPredictedClassDocCount); } @Override @@ -331,7 +336,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); - 
builder.field(OTHER_PREDICTED_CLASS_COUNT.getPreferredName(), otherPredictedClassCount); + builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount); builder.endObject(); return builder; } @@ -344,12 +349,12 @@ public boolean equals(Object o) { return Objects.equals(this.actualClass, that.actualClass) && this.actualClassDocCount == that.actualClassDocCount && Objects.equals(this.predictedClasses, that.predictedClasses) - && this.otherPredictedClassCount == that.otherPredictedClassCount; + && this.otherPredictedClassDocCount == that.otherPredictedClassDocCount; } @Override public int hashCode() { - return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassCount); + return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java index 478ce98473fbc..a4b39cbfbe7f0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java @@ -28,9 +28,7 @@ public static Result createRandom() { for (int i = 0; i < numClasses; i++) { List predictedClasses = new ArrayList<>(numClasses); for (int j = 0; j < numClasses; j++) { - if (randomBoolean()) { - predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong())); - } + predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong())); } actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, 
randomNonNegativeLong())); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 6e03108c7560b..f35346fc78582 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -630,7 +630,7 @@ setup: count: 1 - predicted_class: "mouse" count: 0 - other_predicted_class_count: 0 + other_predicted_class_doc_count: 0 - actual_class: "dog" actual_class_doc_count: 3 predicted_classes: @@ -640,7 +640,7 @@ setup: count: 2 - predicted_class: "mouse" count: 0 - other_predicted_class_count: 0 + other_predicted_class_doc_count: 0 - actual_class: "mouse" actual_class_doc_count: 2 predicted_classes: @@ -650,7 +650,7 @@ setup: count: 0 - predicted_class: "mouse" count: 1 - other_predicted_class_count: 0 + other_predicted_class_doc_count: 0 other_actual_class_count: 0 --- "Test classification multiclass_confusion_matrix with explicit size": @@ -678,7 +678,7 @@ setup: count: 2 - predicted_class: "dog" count: 1 - other_predicted_class_count: 0 + other_predicted_class_doc_count: 0 - actual_class: "dog" actual_class_doc_count: 3 predicted_classes: @@ -686,7 +686,7 @@ setup: count: 1 - predicted_class: "dog" count: 2 - other_predicted_class_count: 0 + other_predicted_class_doc_count: 0 other_actual_class_count: 1 --- "Test classification with null metrics": From 92a6dd4b3a8dc42fef8e3d31620ee7bd3c0fd74e Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Mon, 21 Oct 2019 09:57:44 +0200 Subject: [PATCH 6/6] Apply review comments --- .../MulticlassConfusionMatrixMetric.java | 84 +++++++++++-------- .../client/MachineLearningIT.java | 16 ++-- .../MlClientDocumentationIT.java | 12 +-- ...classConfusionMatrixMetricResultTests.java | 11 ++- .../MulticlassConfusionMatrix.java | 39 +++++---- .../MulticlassConfusionMatrixResultTests.java | 67 
+++++++++++++++ 6 files changed, 162 insertions(+), 67 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java index b8c525430f2ba..7199660e94d0c 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java @@ -32,7 +32,6 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; /** @@ -103,23 +102,22 @@ public static class Result implements EvaluationMetric.Result { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "multiclass_confusion_matrix_result", true, a -> new Result((List) a[0], (long) a[1])); + "multiclass_confusion_matrix_result", true, a -> new Result((List) a[0], (Long) a[1])); static { - PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX); - PARSER.declareLong(constructorArg(), OTHER_ACTUAL_CLASS_COUNT); + PARSER.declareObjectArray(optionalConstructorArg(), ActualClass.PARSER, CONFUSION_MATRIX); + PARSER.declareLong(optionalConstructorArg(), OTHER_ACTUAL_CLASS_COUNT); } public static Result fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - // Immutable private final List confusionMatrix; - private final long otherActualClassCount; + private final Long otherActualClassCount; - public Result(List confusionMatrix, long otherActualClassCount) { - this.confusionMatrix = 
Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)); + public Result(@Nullable List confusionMatrix, @Nullable Long otherActualClassCount) { + this.confusionMatrix = confusionMatrix != null ? Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)) : null; this.otherActualClassCount = otherActualClassCount; } @@ -132,15 +130,19 @@ public List getConfusionMatrix() { return confusionMatrix; } - public long getOtherActualClassCount() { + public Long getOtherActualClassCount() { return otherActualClassCount; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); - builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount); + if (confusionMatrix != null) { + builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); + } + if (otherActualClassCount != null) { + builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount); + } builder.endObject(); return builder; } @@ -151,7 +153,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; return Objects.equals(this.confusionMatrix, that.confusionMatrix) - && this.otherActualClassCount == that.otherActualClassCount; + && Objects.equals(this.otherActualClassCount, that.otherActualClassCount); } @Override @@ -172,35 +174,45 @@ public static class ActualClass implements ToXContentObject { new ConstructingObjectParser<>( "multiclass_confusion_matrix_actual_class", true, - a -> new ActualClass((String) a[0], (long) a[1], (List) a[2], (long) a[3])); + a -> new ActualClass((String) a[0], (Long) a[1], (List) a[2], (Long) a[3])); static { - PARSER.declareString(constructorArg(), ACTUAL_CLASS); - PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); - PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, 
PREDICTED_CLASSES); - PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT); + PARSER.declareString(optionalConstructorArg(), ACTUAL_CLASS); + PARSER.declareLong(optionalConstructorArg(), ACTUAL_CLASS_DOC_COUNT); + PARSER.declareObjectArray(optionalConstructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); + PARSER.declareLong(optionalConstructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT); } private final String actualClass; - private final long actualClassDocCount; + private final Long actualClassDocCount; private final List predictedClasses; - private final long otherPredictedClassDocCount; + private final Long otherPredictedClassDocCount; - public ActualClass( - String actualClass, long actualClassDocCount, List predictedClasses, long otherPredictedClassDocCount) { + public ActualClass(@Nullable String actualClass, + @Nullable Long actualClassDocCount, + @Nullable List predictedClasses, + @Nullable Long otherPredictedClassDocCount) { this.actualClass = actualClass; this.actualClassDocCount = actualClassDocCount; - this.predictedClasses = Collections.unmodifiableList(predictedClasses); + this.predictedClasses = predictedClasses != null ? 
Collections.unmodifiableList(predictedClasses) : null; this.otherPredictedClassDocCount = otherPredictedClassDocCount; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); - builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); - builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); - builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount); + if (actualClass != null) { + builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); + } + if (actualClassDocCount != null) { + builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); + } + if (predictedClasses != null) { + builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); + } + if (otherPredictedClassDocCount != null) { + builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount); + } builder.endObject(); return builder; } @@ -211,9 +223,9 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; ActualClass that = (ActualClass) o; return Objects.equals(this.actualClass, that.actualClass) - && this.actualClassDocCount == that.actualClassDocCount + && Objects.equals(this.actualClassDocCount, that.actualClassDocCount) && Objects.equals(this.predictedClasses, that.predictedClasses) - && this.otherPredictedClassDocCount == that.otherPredictedClassDocCount; + && Objects.equals(this.otherPredictedClassDocCount, that.otherPredictedClassDocCount); } @Override @@ -235,17 +247,17 @@ public static class PredictedClass implements ToXContentObject { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (long) a[1])); + 
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (Long) a[1])); static { - PARSER.declareString(constructorArg(), PREDICTED_CLASS); - PARSER.declareLong(constructorArg(), COUNT); + PARSER.declareString(optionalConstructorArg(), PREDICTED_CLASS); + PARSER.declareLong(optionalConstructorArg(), COUNT); } private final String predictedClass; private final Long count; - public PredictedClass(String predictedClass, Long count) { + public PredictedClass(@Nullable String predictedClass, @Nullable Long count) { this.predictedClass = predictedClass; this.count = count; } @@ -253,8 +265,12 @@ public PredictedClass(String predictedClass, Long count) { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass); - builder.field(COUNT.getPreferredName(), count); + if (predictedClass != null) { + builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass); + } + if (count != null) { + builder.field(COUNT.getPreferredName(), count); + } builder.endObject(); return builder; } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 26987ca7d161e..d48c07ede60e5 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1805,19 +1805,19 @@ public void testEvaluateDataFrame_Classification() throws IOException { List.of( new ActualClass( "ant", - 1, + 1L, List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), - 0), + 0L), new ActualClass( "cat", - 5, + 5L, List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), - 1), + 1L), new ActualClass( "dog", - 
4, + 4L, List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), - 0)))); + 0L)))); assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L)); } { // Explicit size provided for MulticlassConfusionMatrixMetric metric @@ -1839,8 +1839,8 @@ public void testEvaluateDataFrame_Classification() throws IOException { mcmResult.getConfusionMatrix(), equalTo( List.of( - new ActualClass("cat", 5, List.of(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1), - new ActualClass("dog", 4, List.of(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0) + new ActualClass("cat", 5L, List.of(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L), + new ActualClass("dog", 4L, List.of(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L) ))); assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L)); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 45d08fdeef92a..36947fd1b3b99 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -3368,19 +3368,19 @@ public void testEvaluateDataFrame_Classification() throws Exception { List.of( new ActualClass( "ant", - 1, + 1L, List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), - 0), + 0L), new ActualClass( "cat", - 5, + 5L, List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), - 1), + 1L), new ActualClass( "dog", - 4, + 4L, List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), - 0)))); + 0L)))); assertThat(otherClassesCount, equalTo(0L)); } } diff 
--git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java index 25d145ecd7cf9..55b74eb94ea21 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java @@ -48,11 +48,16 @@ protected Result createTestInstance() { for (int i = 0; i < numClasses; i++) { List predictedClasses = new ArrayList<>(numClasses); for (int j = 0; j < numClasses; j++) { - predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong())); + predictedClasses.add(new PredictedClass(classNames.get(j), randomBoolean() ? randomNonNegativeLong() : null)); } - actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong())); + actualClasses.add( + new ActualClass( + classNames.get(i), + randomBoolean() ? randomNonNegativeLong() : null, + predictedClasses, + randomBoolean() ? randomNonNegativeLong() : null)); } - return new Result(actualClasses, randomNonNegativeLong()); + return new Result(actualClasses, randomBoolean() ? 
randomNonNegativeLong() : null); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index f1a310967d7e3..9f0150c5b8fe6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -220,13 +220,13 @@ public static Result fromXContent(XContentParser parser) { private final long otherActualClassCount; public Result(List actualClasses, long otherActualClassCount) { - this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses)); - this.otherActualClassCount = otherActualClassCount; + this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, CONFUSION_MATRIX)); + this.otherActualClassCount = requireNonNegative(otherActualClassCount, OTHER_ACTUAL_CLASS_COUNT); } public Result(StreamInput in) throws IOException { this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new)); - this.otherActualClassCount = in.readLong(); + this.otherActualClassCount = in.readVLong(); } @Override @@ -250,7 +250,7 @@ public long getOtherActualClassCount() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeList(actualClasses); - out.writeLong(otherActualClassCount); + out.writeVLong(otherActualClassCount); } @Override @@ -309,25 +309,25 @@ public static class ActualClass implements ToXContentObject, Writeable { public ActualClass( String actualClass, long actualClassDocCount, List predictedClasses, long otherPredictedClassDocCount) { - this.actualClass = actualClass; - this.actualClassDocCount = actualClassDocCount; - 
this.predictedClasses = Collections.unmodifiableList(predictedClasses); - this.otherPredictedClassDocCount = otherPredictedClassDocCount; + this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS); + this.actualClassDocCount = requireNonNegative(actualClassDocCount, ACTUAL_CLASS_DOC_COUNT); + this.predictedClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(predictedClasses, PREDICTED_CLASSES)); + this.otherPredictedClassDocCount = requireNonNegative(otherPredictedClassDocCount, OTHER_PREDICTED_CLASS_DOC_COUNT); } public ActualClass(StreamInput in) throws IOException { this.actualClass = in.readString(); - this.actualClassDocCount = in.readLong(); + this.actualClassDocCount = in.readVLong(); this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new)); - this.otherPredictedClassDocCount = in.readLong(); + this.otherPredictedClassDocCount = in.readVLong(); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(actualClass); - out.writeLong(actualClassDocCount); + out.writeVLong(actualClassDocCount); out.writeList(predictedClasses); - out.writeLong(otherPredictedClassDocCount); + out.writeVLong(otherPredictedClassDocCount); } @Override @@ -377,13 +377,13 @@ public static class PredictedClass implements ToXContentObject, Writeable { private final long count; public PredictedClass(String predictedClass, long count) { - this.predictedClass = predictedClass; - this.count = count; + this.predictedClass = ExceptionsHelper.requireNonNull(predictedClass, PREDICTED_CLASS); + this.count = requireNonNegative(count, COUNT); } public PredictedClass(StreamInput in) throws IOException { this.predictedClass = in.readString(); - this.count = in.readLong(); + this.count = in.readVLong(); } public String getPredictedClass() { @@ -393,7 +393,7 @@ public String getPredictedClass() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(predictedClass); - 
out.writeLong(count); + out.writeVLong(count); } @Override @@ -419,4 +419,11 @@ public int hashCode() { return Objects.hash(predictedClass, count); } } + + private static long requireNonNegative(long value, ParseField field) { + if (value < 0) { + throw ExceptionsHelper.serverError("[" + field.getPreferredName() + "] must be >= 0, was: " + value); + } + return value; + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java index a4b39cbfbe7f0..a2c30eaeb4979 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; @@ -14,11 +15,14 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.Matchers.equalTo; + public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase { public static Result createRandom() { @@ -60,4 +64,67 @@ protected Predicate getRandomFieldsExcludeFilter() { // allow unknown fields in the root of the object only return field -> !field.isEmpty(); } + + public void testConstructor_ValidationFailures() { + { + IllegalArgumentException e = 
expectThrows(IllegalArgumentException.class, () -> new Result(null, 0)); + assertThat(e.getMessage(), equalTo("[confusion_matrix] must not be null.")); + } + { + ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new Result(Collections.emptyList(), -1)); + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), equalTo("[other_actual_class_count] must be >= 0, was: -1")); + } + { + IllegalArgumentException e = + expectThrows( + IllegalArgumentException.class, + () -> new Result(Collections.singletonList(new ActualClass(null, 0, Collections.emptyList(), 0)), 0)); + assertThat(e.getMessage(), equalTo("[actual_class] must not be null.")); + } + { + ElasticsearchException e = + expectThrows( + ElasticsearchException.class, + () -> new Result(Collections.singletonList(new ActualClass("actual_class", -1, Collections.emptyList(), 0)), 0)); + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), equalTo("[actual_class_doc_count] must be >= 0, was: -1")); + } + { + IllegalArgumentException e = + expectThrows( + IllegalArgumentException.class, + () -> new Result(Collections.singletonList(new ActualClass("actual_class", 0, null, 0)), 0)); + assertThat(e.getMessage(), equalTo("[predicted_classes] must not be null.")); + } + { + ElasticsearchException e = + expectThrows( + ElasticsearchException.class, + () -> new Result(Collections.singletonList(new ActualClass("actual_class", 0, Collections.emptyList(), -1)), 0)); + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), equalTo("[other_predicted_class_doc_count] must be >= 0, was: -1")); + } + { + IllegalArgumentException e = + expectThrows( + IllegalArgumentException.class, + () -> new Result( + Collections.singletonList( + new ActualClass("actual_class", 0, Collections.singletonList(new PredictedClass(null, 0)), 0)), + 0)); + assertThat(e.getMessage(), equalTo("[predicted_class] must not be null.")); + } + { + 
ElasticsearchException e = + expectThrows( + ElasticsearchException.class, + () -> new Result( + Collections.singletonList( + new ActualClass("actual_class", 0, Collections.singletonList(new PredictedClass("predicted_class", -1)), 0)), + 0)); + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), equalTo("[count] must be >= 0, was: -1")); + } + } }