From 0f01e34510f83d98f26e2b54f80d019ad2b47390 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 13 Dec 2019 13:22:52 +0100 Subject: [PATCH 1/7] Fix accuracy metric --- .../classification/AccuracyMetric.java | 97 ++++---- .../client/MachineLearningIT.java | 14 +- .../AccuracyMetricResultTests.java | 8 +- .../evaluation/EvaluationMetric.java | 2 +- .../evaluation/classification/Accuracy.java | 218 +++++++++++------- .../MulticlassConfusionMatrix.java | 60 +++-- .../classification/AccuracyResultTests.java | 18 +- .../classification/AccuracyTests.java | 80 +++---- .../MulticlassConfusionMatrixTests.java | 19 +- .../ClassificationEvaluationIT.java | 66 ++++-- .../ml/integration/ClassificationIT.java | 10 +- .../test/ml/evaluate_data_frame.yml | 17 +- 12 files changed, 353 insertions(+), 256 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java index 4db165be06caa..151783499e46b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java @@ -20,6 +20,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; @@ -35,10 +36,25 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; /** - * {@link AccuracyMetric} is a metric that answers the question: - * "What fraction of examples have been classified correctly by the classifier?" + * {@link AccuracyMetric} is a metric that answers the following two questions: * - * equation: accuracy = 1/n * Σ(y == y´) + * 1. What is the fraction of documents for which predicted class equals the actual class? + * + * equation: overall_accuracy = 1/n * Σ(y == y') + * where: n = total number of documents + * y = document's actual class + * y' = document's predicted class + * + * 2. For any given class X, what is the fraction of documents for which either + * a) both actual and predicted class are equal to X (true positives) + * or + * b) both actual and predicted class are not equal to X (true negatives) + * + * equation: accuracy(X) = 1/n * (TP(X) + TN(X)) + * where: X = class being examined + * n = total number of documents + * TP(X) = number of true positives wrt X + * TN(X) = number of true negatives wrt X */ public class AccuracyMetric implements EvaluationMetric { @@ -78,15 +94,15 @@ public int hashCode() { public static class Result implements EvaluationMetric.Result { - private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes"); + private static final ParseField CLASSES = new ParseField("classes"); private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); + new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); static { - PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES); + PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY); } @@ -94,13 +110,13 @@ public static Result fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - /** List of actual classes. */ - private final List actualClasses; - /** Fraction of documents predicted correctly. */ + /** List of per-class results. */ + private final List classes; + /** Fraction of documents for which predicted class equals the actual class. */ private final double overallAccuracy; - public Result(List actualClasses, double overallAccuracy) { - this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses)); + public Result(List classes, double overallAccuracy) { + this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes)); this.overallAccuracy = overallAccuracy; } @@ -109,8 +125,8 @@ public String getMetricName() { return NAME; } - public List getActualClasses() { - return actualClasses; + public List getClasses() { + return classes; } public double getOverallAccuracy() { @@ -120,7 +136,7 @@ public double getOverallAccuracy() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses); + builder.field(CLASSES.getPreferredName(), classes); builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy); builder.endObject(); return builder; @@ -131,52 +147,42 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; - return Objects.equals(this.actualClasses, that.actualClasses) + return Objects.equals(this.classes, that.classes) && this.overallAccuracy == that.overallAccuracy; } @Override public int hashCode() { - return Objects.hash(actualClasses, overallAccuracy); + return Objects.hash(classes, overallAccuracy); } } - public static class ActualClass implements ToXContentObject { + public static class PerClassResult implements ToXContentObject { - private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); - private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); + private static final ParseField CLASS_NAME = new ParseField("class_name"); private static final ParseField ACCURACY = new ParseField("accuracy"); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2])); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); static { - PARSER.declareString(constructorArg(), ACTUAL_CLASS); - PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); + PARSER.declareString(constructorArg(), CLASS_NAME); PARSER.declareDouble(constructorArg(), ACCURACY); } - /** Name of the actual class. */ - private final String actualClass; - /** Number of documents (examples) belonging to the {code actualClass} class. */ - private final long actualClassDocCount; - /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */ + /** Name of the class. */ + private final String className; + /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */ private final double accuracy; - public ActualClass( - String actualClass, long actualClassDocCount, double accuracy) { - this.actualClass = Objects.requireNonNull(actualClass); - this.actualClassDocCount = actualClassDocCount; + public PerClassResult(String className, double accuracy) { + this.className = Objects.requireNonNull(className); this.accuracy = accuracy; } - public String getActualClass() { - return actualClass; - } - - public long getActualClassDocCount() { - return actualClassDocCount; + public String getClassName() { + return className; } public double getAccuracy() { @@ -186,8 +192,7 @@ public double getAccuracy() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); - builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); + builder.field(CLASS_NAME.getPreferredName(), className); builder.field(ACCURACY.getPreferredName(), accuracy); builder.endObject(); return builder; @@ -197,15 +202,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - ActualClass that = (ActualClass) o; - return Objects.equals(this.actualClass, that.actualClass) - && this.actualClassDocCount == that.actualClassDocCount + PerClassResult that = (PerClassResult) o; + return Objects.equals(this.className, that.className) && this.accuracy == that.accuracy; } @Override public int hashCode() { - return Objects.hash(actualClass, actualClassDocCount, accuracy); + return Objects.hash(className, accuracy); + } + + @Override + public String toString() { + return Strings.toString(this); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 4aee48dff13a5..443c337d089fa 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1819,15 +1819,15 @@ public void testEvaluateDataFrame_Classification() throws IOException { AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME); assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); assertThat( - accuracyResult.getActualClasses(), + accuracyResult.getClasses(), equalTo( List.of( - // 3 out of 5 examples labeled as "cat" were classified correctly - new AccuracyMetric.ActualClass("cat", 5, 0.6), - // 3 out of 4 examples labeled as "dog" were classified correctly - new AccuracyMetric.ActualClass("dog", 4, 0.75), - // no examples labeled as "ant" were classified correctly - new AccuracyMetric.ActualClass("ant", 1, 0.0)))); + // 9 out of 10 examples were classified correctly + new AccuracyMetric.PerClassResult("ant", 0.9), + // 6 out of 10 examples were classified correctly + new AccuracyMetric.PerClassResult("cat", 0.6), + // 8 out of 10 examples were classified correctly + new AccuracyMetric.PerClassResult("dog", 0.8)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly } { // Precision diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java index df48ef3123dd1..8758cea86c451 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java @@ -19,7 +19,7 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; @@ -41,13 +41,13 @@ protected NamedXContentRegistry xContentRegistry() { public static Result randomResult() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - List actualClasses = new ArrayList<>(numClasses); + List classes = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { double accuracy = randomDoubleBetween(0.0, 1.0, true); - actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy)); + classes.add(new PerClassResult(classNames.get(i), accuracy)); } double overallAccuracy = randomDoubleBetween(0.0, 1.0, true); - return new Result(actualClasses, overallAccuracy); + return new Result(classes, overallAccuracy); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java index 36bf7634cb43f..8a106175ace91 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java @@ -44,5 +44,5 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable { * Gets the evaluation result for this metric. * @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise */ - Optional getResult(); + Optional getResult(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 8e7b8b6066932..4e799648f99d1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -20,7 +21,6 @@ import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; -import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; @@ -39,22 +39,36 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; /** - * {@link Accuracy} is a metric that answers the question: - * "What fraction of examples have been classified correctly by the classifier?" + * {@link Accuracy} is a metric that answers the following two questions: * - * equation: accuracy = 1/n * Σ(y == y´) + * 1. What is the fraction of documents for which predicted class equals the actual class? + * + * equation: overall_accuracy = 1/n * Σ(y == y') + * where: n = total number of documents + * y = document's actual class + * y' = document's predicted class + * + * 2. For any given class X, what is the fraction of documents for which either + * a) both actual and predicted class are equal to X (true positives) + * or + * b) both actual and predicted class are not equal to X (true negatives) + * + * equation: accuracy(X) = 1/n * (TP(X) + TN(X)) + * where: X = class being examined + * n = total number of documents + * TP(X) = number of true positives wrt X + * TN(X) = number of true negatives wrt X */ public class Accuracy implements EvaluationMetric { public static final ParseField NAME = new ParseField("accuracy"); + static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy"; + private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; - private static final String CLASSES_AGG_NAME = "classification_classes"; - private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy"; - private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy"; - private static String buildScript(Object...args) { - return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args); + private static Script buildScript(Object...args) { + return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args)); } private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new); @@ -63,11 +77,28 @@ public static Accuracy fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private EvaluationMetricResult result; + private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; - public Accuracy() {} + private final int maxClassesCardinality; + private final MulticlassConfusionMatrix matrix; + private String actualField; + private Double overallAccuracy; + private Result result; - public Accuracy(StreamInput in) throws IOException {} + public Accuracy() { + this((Integer) null); + } + + // Visible for testing + public Accuracy(@Nullable Integer maxClassesCardinality) { + this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY; + this.matrix = new MulticlassConfusionMatrix(this.maxClassesCardinality, NAME.getPreferredName() + "_"); + } + + public Accuracy(StreamInput in) throws IOException { + this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY; + this.matrix = new MulticlassConfusionMatrix(in); + } @Override public String getWriteableName() { @@ -81,43 +112,79 @@ public String getName() { @Override public final Tuple, List> aggs(String actualField, String predictedField) { - if (result != null) { - return Tuple.tuple(List.of(), List.of()); + // Store given {@code actualField} for the purpose of generating error message in {@code process}. + this.actualField = actualField; + List aggs = new ArrayList<>(); + List pipelineAggs = new ArrayList<>(); + if (overallAccuracy == null) { + aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField))); + } + if (result == null) { + Tuple, List> matrixAggs = matrix.aggs(actualField, predictedField); + aggs.addAll(matrixAggs.v1()); + pipelineAggs.addAll(matrixAggs.v2()); } - Script accuracyScript = new Script(buildScript(actualField, predictedField)); - return Tuple.tuple( - List.of( - AggregationBuilders.terms(CLASSES_AGG_NAME) - .field(actualField) - .subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)), - AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)), - List.of()); + return Tuple.tuple(aggs, pipelineAggs); } @Override public void process(Aggregations aggs) { - if (result != null) { - return; + if (overallAccuracy == null && aggs.get(OVERALL_ACCURACY_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { + NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME); + overallAccuracy = overallAccuracyAgg.value(); } - Terms classesAgg = aggs.get(CLASSES_AGG_NAME); - NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME); - List actualClasses = new ArrayList<>(classesAgg.getBuckets().size()); - for (Terms.Bucket bucket : classesAgg.getBuckets()) { - String actualClass = bucket.getKeyAsString(); - long actualClassDocCount = bucket.getDocCount(); - NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME); - actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value())); + matrix.process(aggs); + if (result == null && matrix.getResult().isPresent()) { + if (matrix.getResult().get().getOtherActualClassCount() > 0) { + // This means there were more than {@code maxClassesCardinality} buckets. + // We cannot calculate per-class accuracy accurately, so we fail. + throw ExceptionsHelper.badRequestException( + "Cannot calculate per-class accuracy. Cardinality of field [{}] is too high", actualField); + } + result = new Result(computePerClassAccuracy(matrix.getResult().get()), overallAccuracy); } - result = new Result(actualClasses, overallAccuracyAgg.value()); } @Override - public Optional getResult() { + public Optional getResult() { return Optional.ofNullable(result); } + /** + * Computes the per-class accuracy results based on multiclass confusion matrix's result. + * Time complexity of this method is linear wrt multiclass confusion matrix size, so O(n^2) where n is the matrix dimension. + * This method is visible for testing only. + */ + static List computePerClassAccuracy(MulticlassConfusionMatrix.Result matrixResult) { + assert matrixResult.getOtherActualClassCount() == 0; + // Number of actual classes taken into account + int n = matrixResult.getConfusionMatrix().size(); + // Total number of documents taken into account + long totalDocCount = + matrixResult.getConfusionMatrix().stream().mapToLong(MulticlassConfusionMatrix.ActualClass::getActualClassDocCount).sum(); + List classes = new ArrayList<>(n); + for (int i = 0; i < n; ++i) { + String className = matrixResult.getConfusionMatrix().get(i).getActualClass(); + // Start with the assumption that all the docs were predicted correctly. + long correctDocCount = totalDocCount; + for (int j = 0; j < n; ++j) { + if (i != j) { + // Subtract errors (false negatives) + correctDocCount -= matrixResult.getConfusionMatrix().get(i).getPredictedClasses().get(j).getCount(); + // Subtract errors (false positives) + correctDocCount -= matrixResult.getConfusionMatrix().get(j).getPredictedClasses().get(i).getCount(); + } + } + // Subtract errors (false negatives) for classes other than explicitly listed in confusion matrix + correctDocCount -= matrixResult.getConfusionMatrix().get(i).getOtherPredictedClassDocCount(); + classes.add(new PerClassResult(className, ((double)correctDocCount) / totalDocCount)); + } + return classes; + } + @Override public void writeTo(StreamOutput out) throws IOException { + matrix.writeTo(out); } @Override @@ -131,25 +198,26 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - return true; + Accuracy that = (Accuracy) o; + return Objects.equals(this.matrix, that.matrix); } @Override public int hashCode() { - return Objects.hashCode(NAME.getPreferredName()); + return Objects.hash(matrix); } public static class Result implements EvaluationMetricResult { - private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes"); + private static final ParseField CLASSES = new ParseField("classes"); private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); + new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); static { - PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES); + PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY); } @@ -157,18 +225,18 @@ public static Result fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - /** List of actual classes. */ - private final List actualClasses; - /** Fraction of documents predicted correctly. */ + /** List of per-class results. */ + private final List classes; + /** Fraction of documents for which predicted class equals the actual class. */ private final double overallAccuracy; - public Result(List actualClasses, double overallAccuracy) { - this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES)); + public Result(List classes, double overallAccuracy) { + this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES)); this.overallAccuracy = overallAccuracy; } public Result(StreamInput in) throws IOException { - this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new)); + this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new)); this.overallAccuracy = in.readDouble(); } @@ -182,8 +250,8 @@ public String getMetricName() { return NAME.getPreferredName(); } - public List getActualClasses() { - return actualClasses; + public List getClasses() { + return classes; } public double getOverallAccuracy() { @@ -192,14 +260,14 @@ public double getOverallAccuracy() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeList(actualClasses); + out.writeList(classes); out.writeDouble(overallAccuracy); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses); + builder.field(CLASSES.getPreferredName(), classes); builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy); builder.endObject(); return builder; @@ -210,54 +278,47 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; - return Objects.equals(this.actualClasses, that.actualClasses) + return Objects.equals(this.classes, that.classes) && this.overallAccuracy == that.overallAccuracy; } @Override public int hashCode() { - return Objects.hash(actualClasses, overallAccuracy); + return Objects.hash(classes, overallAccuracy); } } - public static class ActualClass implements ToXContentObject, Writeable { + public static class PerClassResult implements ToXContentObject, Writeable { - private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); - private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); + private static final ParseField CLASS_NAME = new ParseField("class_name"); private static final ParseField ACCURACY = new ParseField("accuracy"); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2])); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); static { - PARSER.declareString(constructorArg(), ACTUAL_CLASS); - PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); + PARSER.declareString(constructorArg(), CLASS_NAME); PARSER.declareDouble(constructorArg(), ACCURACY); } - /** Name of the actual class. */ - private final String actualClass; - /** Number of documents (examples) belonging to the {code actualClass} class. */ - private final long actualClassDocCount; - /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */ + /** Name of the class. */ + private final String className; + /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */ private final double accuracy; - public ActualClass( - String actualClass, long actualClassDocCount, double accuracy) { - this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS); - this.actualClassDocCount = actualClassDocCount; + public PerClassResult(String className, double accuracy) { + this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME); this.accuracy = accuracy; } - public ActualClass(StreamInput in) throws IOException { - this.actualClass = in.readString(); - this.actualClassDocCount = in.readVLong(); + public PerClassResult(StreamInput in) throws IOException { + this.className = in.readString(); this.accuracy = in.readDouble(); } - public String getActualClass() { - return actualClass; + public String getClassName() { + return className; } public double getAccuracy() { @@ -266,16 +327,14 @@ public double getAccuracy() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(actualClass); - out.writeVLong(actualClassDocCount); + out.writeString(className); out.writeDouble(accuracy); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); - builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); + builder.field(CLASS_NAME.getPreferredName(), className); builder.field(ACCURACY.getPreferredName(), accuracy); builder.endObject(); return builder; @@ -285,15 +344,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - ActualClass that = (ActualClass) o; - return Objects.equals(this.actualClass, that.actualClass) - && this.actualClassDocCount == that.actualClassDocCount + PerClassResult that = (PerClassResult) o; + return Objects.equals(this.className, that.className) && this.accuracy == that.accuracy; } @Override public int hashCode() { - return Objects.hash(actualClass, actualClassDocCount, accuracy); + return Objects.hash(className, accuracy); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 7b9d524abf6f7..3ed9fe44d81bc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; @@ -52,13 +53,16 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { public static final ParseField NAME = new ParseField("multiclass_confusion_matrix"); public static final ParseField SIZE = new ParseField("size"); + public static final ParseField AGG_NAME_PREFIX = new ParseField("agg_name_prefix"); private static final ConstructingObjectParser PARSER = createParser(); private static ConstructingObjectParser createParser() { ConstructingObjectParser parser = - new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0])); + new ConstructingObjectParser<>( + NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0], (String) args[1])); parser.declareInt(optionalConstructorArg(), SIZE); + parser.declareString(optionalConstructorArg(), AGG_NAME_PREFIX); return parser; } @@ -71,26 +75,34 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) { private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; private static final String OTHER_BUCKET_KEY = "_other_"; + private static final String DEFAULT_AGG_NAME_PREFIX = ""; private static final int DEFAULT_SIZE = 10; private static final int MAX_SIZE = 1000; private final int size; + private final String aggNamePrefix; private List topActualClassNames; private Result result; public MulticlassConfusionMatrix() { - this((Integer) null); + this(null, null); } - public MulticlassConfusionMatrix(@Nullable Integer size) { + public MulticlassConfusionMatrix(@Nullable Integer size, @Nullable String aggNamePrefix) { if (size != null && (size <= 0 || size > MAX_SIZE)) { throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE); } this.size = size != null ? size : DEFAULT_SIZE; + this.aggNamePrefix = aggNamePrefix != null ? aggNamePrefix : DEFAULT_AGG_NAME_PREFIX; } public MulticlassConfusionMatrix(StreamInput in) throws IOException { this.size = in.readVInt(); + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + this.aggNamePrefix = in.readString(); + } else { + this.aggNamePrefix = DEFAULT_AGG_NAME_PREFIX; + } } @Override @@ -112,7 +124,7 @@ public final Tuple, List> a if (topActualClassNames == null) { // This is step 1 return Tuple.tuple( List.of( - AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) + AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) .field(actualField) .order(List.of(BucketOrder.count(false), BucketOrder.key(true))) .size(size)), @@ -129,10 +141,10 @@ public final Tuple, List> a .toArray(KeyedFilter[]::new); return Tuple.tuple( List.of( - AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS) + AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)) .field(actualField), - AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual) - .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted) + AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual) + .subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted) .otherBucket(true) .otherBucketKey(OTHER_BUCKET_KEY))), List.of()); @@ -142,18 +154,18 @@ public final Tuple, List> a @Override public void process(Aggregations aggs) { - if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) { - Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS); + if (topActualClassNames == null && aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) != null) { + Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)); topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()); } - if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) { - Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS); - Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); + if (result == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) { + Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)); + Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)); List actualClasses = new ArrayList<>(filtersAgg.getBuckets().size()); for (Filters.Bucket bucket : filtersAgg.getBuckets()) { String actualClass = bucket.getKeyAsString(); long actualClassDocCount = bucket.getDocCount(); - Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); + Filters subAgg = bucket.getAggregations().get(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS)); List predictedClasses = new ArrayList<>(); long otherPredictedClassDocCount = 0; for (Filters.Bucket subBucket : subAgg.getBuckets()) { @@ -172,14 +184,21 @@ public void process(Aggregations aggs) { } } + private String aggName(String aggNameWithoutPrefix) { + return aggNamePrefix + aggNameWithoutPrefix; + } + @Override - public Optional getResult() { + public Optional getResult() { return Optional.ofNullable(result); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeVInt(size); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeString(aggNamePrefix); + } } @Override @@ -195,12 +214,13 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o; - return Objects.equals(this.size, that.size); + return this.size == that.size + && Objects.equals(this.aggNamePrefix, that.aggNamePrefix); } @Override public int hashCode() { - return Objects.hash(size); + return Objects.hash(size, aggNamePrefix); } public static class Result implements EvaluationMetricResult { @@ -334,6 +354,10 @@ public String getActualClass() { return actualClass; } + public long getActualClassDocCount() { + return actualClassDocCount; + } + public List getPredictedClasses() { return predictedClasses; } @@ -410,6 +434,10 @@ public String getPredictedClass() { return predictedClass; } + public long getCount() { + return count; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(predictedClass); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java index 8fb4c6c02408d..c57cf6e2b01c2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java @@ -8,9 +8,9 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; import java.util.ArrayList; import java.util.List; @@ -22,23 +22,23 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase public static Result createRandom() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - List actualClasses = new ArrayList<>(numClasses); + List classes = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { double accuracy = randomDoubleBetween(0.0, 1.0, true); - actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy)); + classes.add(new PerClassResult(classNames.get(i), accuracy)); } double overallAccuracy = randomDoubleBetween(0.0, 1.0, true); - return new Result(actualClasses, overallAccuracy); + return new Result(classes, overallAccuracy); } @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); + protected Result createTestInstance() { + return createRandom(); } @Override - protected Result createTestInstance() { - return createRandom(); + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java index 1809f0e735125..f548fdbfd4c11 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java @@ -7,15 +7,11 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; -import java.util.Arrays; import java.util.List; -import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; -import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; import static org.hamcrest.Matchers.equalTo; public class AccuracyTests extends AbstractSerializingTestCase { @@ -44,54 +40,36 @@ public static Accuracy createRandom() { return new Accuracy(); } - public void testProcess() { - Aggregations aggs = new Aggregations(Arrays.asList( - mockTerms("classification_classes"), - mockSingleValue("classification_overall_accuracy", 0.8123), - mockSingleValue("some_other_single_metric_agg", 0.2377) - )); - - Accuracy accuracy = new Accuracy(); - accuracy.process(aggs); - - assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(List.of(), 0.8123))); - } - - public void testProcess_GivenMissingAgg() { - { - Aggregations aggs = new Aggregations(Arrays.asList( - mockTerms("classification_classes"), - mockSingleValue("some_other_single_metric_agg", 0.2377) - )); - Accuracy accuracy = new Accuracy(); - expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); - } - { - Aggregations aggs = new Aggregations(Arrays.asList( - mockSingleValue("classification_overall_accuracy", 0.8123), - mockSingleValue("some_other_single_metric_agg", 0.2377) - )); - Accuracy accuracy = new Accuracy(); - expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); - } + public void testComputePerClassAccuracy() { + assertThat( + Accuracy.computePerClassAccuracy( + new MulticlassConfusionMatrix.Result( + List.of( + new MulticlassConfusionMatrix.ActualClass("A", 14, List.of( + new MulticlassConfusionMatrix.PredictedClass("A", 1), + new MulticlassConfusionMatrix.PredictedClass("B", 6), + new MulticlassConfusionMatrix.PredictedClass("C", 4) + ), 3L), + new MulticlassConfusionMatrix.ActualClass("B", 20, List.of( + new MulticlassConfusionMatrix.PredictedClass("A", 5), + new MulticlassConfusionMatrix.PredictedClass("B", 3), + new MulticlassConfusionMatrix.PredictedClass("C", 9) + ), 3L), + new MulticlassConfusionMatrix.ActualClass("C", 17, List.of( + new MulticlassConfusionMatrix.PredictedClass("A", 8), + new MulticlassConfusionMatrix.PredictedClass("B", 2), + new MulticlassConfusionMatrix.PredictedClass("C", 7) + ), 0L)), + 0)), + equalTo( + List.of( + new Accuracy.PerClassResult("A", 25.0 / 51), // 13 false positives, 13 false negatives + new Accuracy.PerClassResult("B", 26.0 / 51), // 8 false positives, 17 false negatives + new Accuracy.PerClassResult("C", 28.0 / 51))) // 13 false positives, 10 false negatives + ); } - public void testProcess_GivenAggOfWrongType() { - { - Aggregations aggs = new Aggregations(Arrays.asList( - mockTerms("classification_classes"), - mockTerms("classification_overall_accuracy") - )); - Accuracy accuracy = new Accuracy(); - expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); - } - { - Aggregations aggs = new Aggregations(Arrays.asList( - mockSingleValue("classification_classes", 1.0), - mockSingleValue("classification_overall_accuracy", 0.8123) - )); - Accuracy accuracy = new Accuracy(); - expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); - } + public void testComputePerClassAccuracy_OtherActualClassCountIsNonZero() { + expectThrows(AssertionError.class, () -> Accuracy.computePerClassAccuracy(new MulticlassConfusionMatrix.Result(List.of(), 1))); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index f145a06c3c894..6713974040c66 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -54,20 +54,23 @@ protected boolean supportsUnknownFields() { public static MulticlassConfusionMatrix createRandom() { Integer size = randomBoolean() ? null : randomIntBetween(1, 1000); - return new MulticlassConfusionMatrix(size); + return new MulticlassConfusionMatrix(size, null); } public void testConstructor_SizeValidationFailures() { { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1)); + ElasticsearchStatusException e = + expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1, null)); assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); } { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0)); + ElasticsearchStatusException e = + expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0, null)); assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); } { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001)); + ElasticsearchStatusException e = + expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001, null)); assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); } } @@ -104,11 +107,11 @@ public void testEvaluate() { List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L))); - MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); + MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); - MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); + MulticlassConfusionMatrix.Result result = confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); assertThat( result.getConfusionMatrix(), @@ -144,11 +147,11 @@ public void testEvaluate_OtherClassesCountGreaterThanZero() { List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))), mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L))); - MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); + MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); - MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); + MulticlassConfusionMatrix.Result result = confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); assertThat( result.getConfusionMatrix(), diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 437b2ddbf5180..0e32bae3d1d2f 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -54,8 +54,23 @@ public void testEvaluate_DefaultMetrics() { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); assertThat( - evaluateDataFrameResponse.getMetrics().get(0).getMetricName(), - equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + evaluateDataFrameResponse.getMetrics().get(0).getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + } + + public void testEvaluate_AllMetrics() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, + new Classification( + ANIMAL_NAME_FIELD, + ANIMAL_NAME_PREDICTION_FIELD, + List.of(new Accuracy(), new MulticlassConfusionMatrix()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(2)); + assertThat(evaluateDataFrameResponse.getMetrics().get(0).getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); + assertThat( + evaluateDataFrameResponse.getMetrics().get(1).getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); } public void testEvaluate_Accuracy_KeywordField() { @@ -69,14 +84,14 @@ public void testEvaluate_Accuracy_KeywordField() { Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat( - accuracyResult.getActualClasses(), + accuracyResult.getClasses(), equalTo( List.of( - new Accuracy.ActualClass("ant", 15, 1.0 / 15), - new Accuracy.ActualClass("cat", 15, 1.0 / 15), - new Accuracy.ActualClass("dog", 15, 1.0 / 15), - new Accuracy.ActualClass("fox", 15, 1.0 / 15), - new Accuracy.ActualClass("mouse", 15, 1.0 / 15)))); + new Accuracy.PerClassResult("ant", 47.0 / 75), + new Accuracy.PerClassResult("cat", 47.0 / 75), + new Accuracy.PerClassResult("dog", 47.0 / 75), + new Accuracy.PerClassResult("fox", 47.0 / 75), + new Accuracy.PerClassResult("mouse", 47.0 / 75)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75)); } @@ -91,13 +106,14 @@ public void testEvaluate_Accuracy_IntegerField() { Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat( - accuracyResult.getActualClasses(), - equalTo(List.of( - new Accuracy.ActualClass("1", 15, 1.0 / 15), - new Accuracy.ActualClass("2", 15, 2.0 / 15), - new Accuracy.ActualClass("3", 15, 3.0 / 15), - new Accuracy.ActualClass("4", 15, 4.0 / 15), - new Accuracy.ActualClass("5", 15, 5.0 / 15)))); + accuracyResult.getClasses(), + equalTo( + List.of( + new Accuracy.PerClassResult("1", 57.0 / 75), + new Accuracy.PerClassResult("2", 54.0 / 75), + new Accuracy.PerClassResult("3", 51.0 / 75), + new Accuracy.PerClassResult("4", 48.0 / 75), + new Accuracy.PerClassResult("5", 45.0 / 75)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75)); } @@ -112,13 +128,23 @@ public void testEvaluate_Accuracy_BooleanField() { Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat( - accuracyResult.getActualClasses(), - equalTo(List.of( - new Accuracy.ActualClass("true", 45, 27.0 / 45), - new Accuracy.ActualClass("false", 30, 18.0 / 30)))); + accuracyResult.getClasses(), + equalTo( + List.of( + new Accuracy.PerClassResult("false", 18.0 / 30), + new Accuracy.PerClassResult("true", 27.0 / 45)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); } + public void testEvaluate_Accuracy_CardinalityTooHigh() { + ElasticsearchStatusException e = + expectThrows( + ElasticsearchStatusException.class, + () -> evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Accuracy(4))))); + assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); + } + public void testEvaluate_Precision() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( @@ -250,7 +276,7 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix(3)))); + new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix(3, null)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 87fa5c30b0755..966718fb55d1e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -461,12 +461,10 @@ private void assertEvaluation(String dependentVariable, List dependentVar { // Accuracy Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); - List actualClasses = accuracyResult.getActualClasses(); - assertThat( - actualClasses.stream().map(Accuracy.ActualClass::getActualClass).collect(toList()), - equalTo(dependentVariableValuesAsStrings)); - actualClasses.forEach( - actualClass -> assertThat(actualClass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)))); + for (Accuracy.PerClassResult klass : accuracyResult.getClasses()) { + assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings))); + assertThat(klass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); + } } { // MulticlassConfusionMatrix diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 95a7ef4e33218..2b16b79ac84b4 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -620,16 +620,13 @@ setup: - match: classification.accuracy: - actual_classes: - - actual_class: "cat" - actual_class_doc_count: 3 - accuracy: 0.6666666666666666 # 2 out of 3 - - actual_class: "dog" - actual_class_doc_count: 3 - accuracy: 0.6666666666666666 # 2 out of 3 - - actual_class: "mouse" - actual_class_doc_count: 2 - accuracy: 0.5 # 1 out of 2 + classes: + - class_name: "cat" + accuracy: 0.625 # 5 out of 8 + - class_name: "dog" + accuracy: 0.75 # 6 out of 8 + - class_name: "mouse" + accuracy: 0.875 # 7 out of 8 overall_accuracy: 0.625 # 5 out of 8 --- "Test classification precision": From 00e590e5fcff2eddddb2cada5648d2fca6d17171 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 19 Dec 2019 16:58:24 +0100 Subject: [PATCH 2/7] Fix testEvaluate_AllMetrics --- .../classification/AccuracyResultTests.java | 8 ++++---- .../ClassificationEvaluationIT.java | 18 ++++++++++++------ 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java index c57cf6e2b01c2..176aa6e9a309b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java @@ -32,13 +32,13 @@ public static Result createRandom() { } @Override - protected Result createTestInstance() { - return createRandom(); + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); } @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); + protected Result createTestInstance() { + return createRandom(); } @Override diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 0e32bae3d1d2f..be1961c36d4e0 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; @@ -21,6 +22,8 @@ import java.util.List; +import static java.util.stream.Collectors.toList; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -52,9 +55,9 @@ public void testEvaluate_DefaultMetrics() { evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); assertThat( - evaluateDataFrameResponse.getMetrics().get(0).getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), + contains(MulticlassConfusionMatrix.NAME.getPreferredName())); } public void testEvaluate_AllMetrics() { @@ -64,13 +67,16 @@ public void testEvaluate_AllMetrics() { new Classification( ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, - List.of(new Accuracy(), new MulticlassConfusionMatrix()))); + List.of(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(2)); - assertThat(evaluateDataFrameResponse.getMetrics().get(0).getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat( - evaluateDataFrameResponse.getMetrics().get(1).getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), + contains( + Accuracy.NAME.getPreferredName(), + MulticlassConfusionMatrix.NAME.getPreferredName(), + Precision.NAME.getPreferredName(), + Recall.NAME.getPreferredName())); } public void testEvaluate_Accuracy_KeywordField() { From 28cac9a45adcf1082072046e1ca88762198b80e0 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 19 Dec 2019 17:42:02 +0100 Subject: [PATCH 3/7] Serialize maxClassesCardinality internal parameter --- .../core/ml/dataframe/evaluation/classification/Accuracy.java | 3 ++- .../core/ml/dataframe/evaluation/classification/Precision.java | 1 + .../core/ml/dataframe/evaluation/classification/Recall.java | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 4e799648f99d1..f444f4253dc55 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -96,7 +96,7 @@ public Accuracy(@Nullable Integer maxClassesCardinality) { } public Accuracy(StreamInput in) throws IOException { - this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY; + this.maxClassesCardinality = in.readVInt(); this.matrix = new MulticlassConfusionMatrix(in); } @@ -184,6 +184,7 @@ static List computePerClassAccuracy(MulticlassConfusionMatrix.Re @Override public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(maxClassesCardinality); matrix.writeTo(out); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index dd04f23710118..3e29861c92e11 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -163,6 +163,7 @@ public Optional getResult() { @Override public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(maxClassesCardinality); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index 01bdbe6db230b..655855f38a12d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -137,6 +137,7 @@ public Optional getResult() { @Override public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(maxClassesCardinality); } @Override From 1b84d8991e194f6a4d4cd7672aeb2c2af5d539d4 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 19 Dec 2019 17:56:28 +0100 Subject: [PATCH 4/7] Revert "Serialize maxClassesCardinality internal parameter" This reverts commit 256d44d4525bcd3f130cf49690271bfd030db727. --- .../core/ml/dataframe/evaluation/classification/Accuracy.java | 3 +-- .../core/ml/dataframe/evaluation/classification/Precision.java | 1 - .../core/ml/dataframe/evaluation/classification/Recall.java | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index f444f4253dc55..4e799648f99d1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -96,7 +96,7 @@ public Accuracy(@Nullable Integer maxClassesCardinality) { } public Accuracy(StreamInput in) throws IOException { - this.maxClassesCardinality = in.readVInt(); + this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY; this.matrix = new MulticlassConfusionMatrix(in); } @@ -184,7 +184,6 @@ static List computePerClassAccuracy(MulticlassConfusionMatrix.Re @Override public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(maxClassesCardinality); matrix.writeTo(out); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index 3e29861c92e11..dd04f23710118 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -163,7 +163,6 @@ public Optional getResult() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(maxClassesCardinality); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index 655855f38a12d..01bdbe6db230b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -137,7 +137,6 @@ public Optional getResult() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(maxClassesCardinality); } @Override From c8b0259403a4e6c30fa98846990b79d66ff68703 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 20 Dec 2019 09:12:19 +0100 Subject: [PATCH 5/7] Use SetOnce<> wrapper for storing internal metrics state --- .../evaluation/classification/Accuracy.java | 25 ++++++++-------- .../MulticlassConfusionMatrix.java | 23 ++++++++------- .../evaluation/classification/Precision.java | 29 ++++++++++--------- .../evaluation/classification/Recall.java | 19 ++++++------ 4 files changed, 50 insertions(+), 46 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 4e799648f99d1..471714e4ede95 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; @@ -81,9 +82,9 @@ public static Accuracy fromXContent(XContentParser parser) { private final int maxClassesCardinality; private final MulticlassConfusionMatrix matrix; - private String actualField; - private Double overallAccuracy; - private Result result; + private final SetOnce actualField = new SetOnce<>(); + private final SetOnce overallAccuracy = new SetOnce<>(); + private final SetOnce result = new SetOnce<>(); public Accuracy() { this((Integer) null); @@ -113,13 +114,13 @@ public String getName() { @Override public final Tuple, List> aggs(String actualField, String predictedField) { // Store given {@code actualField} for the purpose of generating error message in {@code process}. - this.actualField = actualField; + this.actualField.trySet(actualField); List aggs = new ArrayList<>(); List pipelineAggs = new ArrayList<>(); - if (overallAccuracy == null) { + if (overallAccuracy.get() == null) { aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField))); } - if (result == null) { + if (result.get() == null) { Tuple, List> matrixAggs = matrix.aggs(actualField, predictedField); aggs.addAll(matrixAggs.v1()); pipelineAggs.addAll(matrixAggs.v2()); @@ -129,25 +130,25 @@ public final Tuple, List> a @Override public void process(Aggregations aggs) { - if (overallAccuracy == null && aggs.get(OVERALL_ACCURACY_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { + if (overallAccuracy.get() == null && aggs.get(OVERALL_ACCURACY_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME); - overallAccuracy = overallAccuracyAgg.value(); + overallAccuracy.set(overallAccuracyAgg.value()); } matrix.process(aggs); - if (result == null && matrix.getResult().isPresent()) { + if (result.get() == null && matrix.getResult().isPresent()) { if (matrix.getResult().get().getOtherActualClassCount() > 0) { // This means there were more than {@code maxClassesCardinality} buckets. // We cannot calculate per-class accuracy accurately, so we fail. throw ExceptionsHelper.badRequestException( - "Cannot calculate per-class accuracy. Cardinality of field [{}] is too high", actualField); + "Cannot calculate per-class accuracy. Cardinality of field [{}] is too high", actualField.get()); } - result = new Result(computePerClassAccuracy(matrix.getResult().get()), overallAccuracy); + result.set(new Result(computePerClassAccuracy(matrix.getResult().get()), overallAccuracy.get())); } } @Override public Optional getResult() { - return Optional.ofNullable(result); + return Optional.ofNullable(result.get()); } /** diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 3ed9fe44d81bc..3c4bf1f1cb5ca 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; @@ -81,8 +82,8 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) { private final int size; private final String aggNamePrefix; - private List topActualClassNames; - private Result result; + private final SetOnce> topActualClassNames = new SetOnce<>(); + private final SetOnce result = new SetOnce<>(); public MulticlassConfusionMatrix() { this(null, null); @@ -121,7 +122,7 @@ public int getSize() { @Override public final Tuple, List> aggs(String actualField, String predictedField) { - if (topActualClassNames == null) { // This is step 1 + if (topActualClassNames.get() == null) { // This is step 1 return Tuple.tuple( List.of( AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) @@ -130,13 +131,13 @@ public final Tuple, List> a .size(size)), List.of()); } - if (result == null) { // This is step 2 + if (result.get() == null) { // This is step 2 KeyedFilter[] keyedFiltersActual = - topActualClassNames.stream() + topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className))) .toArray(KeyedFilter[]::new); KeyedFilter[] keyedFiltersPredicted = - topActualClassNames.stream() + topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .toArray(KeyedFilter[]::new); return Tuple.tuple( @@ -154,11 +155,11 @@ public final Tuple, List> a @Override public void process(Aggregations aggs) { - if (topActualClassNames == null && aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) != null) { + if (topActualClassNames.get() == null && aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) != null) { Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)); - topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()); + topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList())); } - if (result == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) { + if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) { Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)); Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)); List actualClasses = new ArrayList<>(filtersAgg.getBuckets().size()); @@ -180,7 +181,7 @@ public void process(Aggregations aggs) { predictedClasses.sort(comparing(PredictedClass::getPredictedClass)); actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount)); } - result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)); + result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0))); } } @@ -190,7 +191,7 @@ private String aggName(String aggNameWithoutPrefix) { @Override public Optional getResult() { - return Optional.ofNullable(result); + return Optional.ofNullable(result.get()); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index dd04f23710118..c3da03f080be3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -76,9 +77,9 @@ public static Precision fromXContent(XContentParser parser) { private static final int MAX_CLASSES_CARDINALITY = 1000; - private String actualField; - private List topActualClassNames; - private EvaluationMetricResult result; + private final SetOnce actualField = new SetOnce<>(); + private final SetOnce> topActualClassNames = new SetOnce<>(); + private final SetOnce result = new SetOnce<>(); public Precision() {} @@ -97,8 +98,8 @@ public String getName() { @Override public final Tuple, List> aggs(String actualField, String predictedField) { // Store given {@code actualField} for the purpose of generating error message in {@code process}. - this.actualField = actualField; - if (topActualClassNames == null) { // This is step 1 + this.actualField.trySet(actualField); + if (topActualClassNames.get() == null) { // This is step 1 return Tuple.tuple( List.of( AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME) @@ -109,7 +110,7 @@ public final Tuple, List> a } if (result == null) { // This is step 2 KeyedFilter[] keyedFiltersPredicted = - topActualClassNames.stream() + topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .toArray(KeyedFilter[]::new); Script script = buildScript(actualField, predictedField); @@ -127,18 +128,18 @@ public final Tuple, List> a @Override public void process(Aggregations aggs) { - if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) { + if (topActualClassNames.get() == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) { Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME); if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) { // This means there were more than {@code maxClassesCardinality} buckets. // We cannot calculate average precision accurately, so we fail. throw ExceptionsHelper.badRequestException( - "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField); + "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField.get()); } - topActualClassNames = - topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()); + topActualClassNames.set( + topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList())); } - if (result == null && + if (result.get() == null && aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters && aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME); @@ -152,13 +153,13 @@ public void process(Aggregations aggs) { classes.add(new PerClassResult(className, precision)); } } - result = new Result(classes, avgPrecisionAgg.value()); + result.set(new Result(classes, avgPrecisionAgg.value())); } } @Override - public Optional getResult() { - return Optional.ofNullable(result); + public Optional getResult() { + return Optional.ofNullable(result.get()); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index 01bdbe6db230b..043a6f7db42b7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -70,8 +71,8 @@ public static Recall fromXContent(XContentParser parser) { private static final int MAX_CLASSES_CARDINALITY = 1000; - private String actualField; - private EvaluationMetricResult result; + private final SetOnce actualField = new SetOnce<>(); + private final SetOnce result = new SetOnce<>(); public Recall() {} @@ -90,8 +91,8 @@ public String getName() { @Override public final Tuple, List> aggs(String actualField, String predictedField) { // Store given {@code actualField} for the purpose of generating error message in {@code process}. - this.actualField = actualField; - if (result != null) { + this.actualField.trySet(actualField); + if (result.get() != null) { return Tuple.tuple(List.of(), List.of()); } Script script = buildScript(actualField, predictedField); @@ -109,7 +110,7 @@ public final Tuple, List> a @Override public void process(Aggregations aggs) { - if (result == null && + if (result.get() == null && aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms && aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME); @@ -117,7 +118,7 @@ public void process(Aggregations aggs) { // This means there were more than {@code maxClassesCardinality} buckets. // We cannot calculate average recall accurately, so we fail. throw ExceptionsHelper.badRequestException( - "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField); + "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField.get()); } NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME); List classes = new ArrayList<>(byActualClassAgg.getBuckets().size()); @@ -126,13 +127,13 @@ public void process(Aggregations aggs) { NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME); classes.add(new PerClassResult(className, recallAgg.value())); } - result = new Result(classes, avgRecallAgg.value()); + result.set(new Result(classes, avgRecallAgg.value())); } } @Override - public Optional getResult() { - return Optional.ofNullable(result); + public Optional getResult() { + return Optional.ofNullable(result.get()); } @Override From 44ea9a845e029fb6dec15743c38ae11c98676484 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 20 Dec 2019 10:15:14 +0100 Subject: [PATCH 6/7] Get rid of maxClassesCardinality internal parameter --- .../evaluation/classification/Accuracy.java | 13 +-- .../MulticlassConfusionMatrix.java | 8 +- .../evaluation/classification/Precision.java | 2 +- .../classification/AccuracyTests.java | 87 +++++++++++++++++++ .../MulticlassConfusionMatrixTests.java | 29 ++++--- .../ClassificationEvaluationIT.java | 9 -- 6 files changed, 109 insertions(+), 39 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 471714e4ede95..c6636329a65d9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; import org.apache.lucene.util.SetOnce; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -78,26 +77,18 @@ public static Accuracy fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; + private static final int MAX_CLASSES_CARDINALITY = 1000; - private final int maxClassesCardinality; private final MulticlassConfusionMatrix matrix; private final SetOnce actualField = new SetOnce<>(); private final SetOnce overallAccuracy = new SetOnce<>(); private final SetOnce result = new SetOnce<>(); public Accuracy() { - this((Integer) null); - } - - // Visible for testing - public Accuracy(@Nullable Integer maxClassesCardinality) { - this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY; - this.matrix = new MulticlassConfusionMatrix(this.maxClassesCardinality, NAME.getPreferredName() + "_"); + this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_"); } public Accuracy(StreamInput in) throws IOException { - this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY; this.matrix = new MulticlassConfusionMatrix(in); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 3c4bf1f1cb5ca..e5a4de1605da0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -71,10 +71,10 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; - private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; - private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; - private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; + static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; + static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; + static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; + static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; private static final String OTHER_BUCKET_KEY = "_other_"; private static final String DEFAULT_AGG_NAME_PREFIX = ""; private static final int DEFAULT_SIZE = 10; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index c3da03f080be3..87b45949b85ba 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -108,7 +108,7 @@ public final Tuple, List> a .size(MAX_CLASSES_CARDINALITY)), List.of()); } - if (result == null) { // This is step 2 + if (result.get() == null) { // This is step 2 KeyedFilter[] keyedFiltersPredicted = topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java index f548fdbfd4c11..cac591a17d303 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java @@ -5,13 +5,26 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; import java.io.IOException; import java.util.List; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; public class AccuracyTests extends AbstractSerializingTestCase { @@ -40,6 +53,80 @@ public static Accuracy createRandom() { return new Accuracy(); } + public void testProcess() { + Aggregations aggs = new Aggregations(List.of( + mockTerms( + "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, + List.of( + mockTermsBucket("dog", new Aggregations(List.of())), + mockTermsBucket("cat", new Aggregations(List.of()))), + 100L), + mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, + List.of( + mockFiltersBucket( + "dog", + 30, + new Aggregations(List.of(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), + mockFiltersBucket( + "cat", + 70, + new Aggregations(List.of(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), + mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L), + mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); + + Accuracy accuracy = new Accuracy(); + accuracy.process(aggs); + + assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty())); + + Result result = accuracy.getResult().get(); + assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); + assertThat( + result.getClasses(), + equalTo( + List.of( + new PerClassResult("dog", 0.5), + new PerClassResult("cat", 0.5)))); + assertThat(result.getOverallAccuracy(), equalTo(0.5)); + } + + public void testProcess_GivenCardinalityTooHigh() { + Aggregations aggs = new Aggregations(List.of( + mockTerms( + "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, + List.of( + mockTermsBucket("dog", new Aggregations(List.of())), + mockTermsBucket("cat", new Aggregations(List.of()))), + 100L), + mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, + List.of( + mockFiltersBucket( + "dog", + 30, + new Aggregations(List.of(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), + mockFiltersBucket( + "cat", + 70, + new Aggregations(List.of(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), + mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L), + mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); + + Accuracy accuracy = new Accuracy(); + accuracy.aggs("foo", "bar"); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs)); + assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); + } + public void testComputePerClassAccuracy() { assertThat( Accuracy.computePerClassAccuracy( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index 6713974040c66..8c02a3c2c6fc3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result; import java.io.IOException; import java.util.List; @@ -85,34 +86,34 @@ public void testAggs() { public void testEvaluate() { Aggregations aggs = new Aggregations(List.of( mockTerms( - "multiclass_confusion_matrix_step_1_by_actual_class", + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, List.of( mockTermsBucket("dog", new Aggregations(List.of())), mockTermsBucket("cat", new Aggregations(List.of()))), 0L), mockFilters( - "multiclass_confusion_matrix_step_2_by_actual_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, List.of( mockFiltersBucket( "dog", 30, new Aggregations(List.of(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket( "cat", 70, new Aggregations(List.of(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), - mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L))); + mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L))); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); - MulticlassConfusionMatrix.Result result = confusionMatrix.getResult().get(); - assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); + Result result = confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( result.getConfusionMatrix(), equalTo( @@ -125,34 +126,34 @@ public void testEvaluate() { public void testEvaluate_OtherClassesCountGreaterThanZero() { Aggregations aggs = new Aggregations(List.of( mockTerms( - "multiclass_confusion_matrix_step_1_by_actual_class", + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, List.of( mockTermsBucket("dog", new Aggregations(List.of())), mockTermsBucket("cat", new Aggregations(List.of()))), 100L), mockFilters( - "multiclass_confusion_matrix_step_2_by_actual_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, List.of( mockFiltersBucket( "dog", 30, new Aggregations(List.of(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket( "cat", 85, new Aggregations(List.of(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))), - mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L))); + mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L))); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); - MulticlassConfusionMatrix.Result result = confusionMatrix.getResult().get(); - assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); + Result result = confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( result.getConfusionMatrix(), equalTo( diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index be1961c36d4e0..da5439f6298dc 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -142,15 +142,6 @@ public void testEvaluate_Accuracy_BooleanField() { assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); } - public void testEvaluate_Accuracy_CardinalityTooHigh() { - ElasticsearchStatusException e = - expectThrows( - ElasticsearchStatusException.class, - () -> evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Accuracy(4))))); - assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); - } - public void testEvaluate_Precision() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( From 3f9a60a3970249c2c30f5b3cd3e5f8376637e045 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 20 Dec 2019 11:47:33 +0100 Subject: [PATCH 7/7] Fix member variable reference in comment --- .../core/ml/dataframe/evaluation/classification/Precision.java | 2 +- .../core/ml/dataframe/evaluation/classification/Recall.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index 87b45949b85ba..73c7723c86b96 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -131,7 +131,7 @@ public void process(Aggregations aggs) { if (topActualClassNames.get() == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) { Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME); if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) { - // This means there were more than {@code maxClassesCardinality} buckets. + // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets. // We cannot calculate average precision accurately, so we fail. throw ExceptionsHelper.badRequestException( "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField.get()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index 043a6f7db42b7..0358820cc509c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -115,7 +115,7 @@ public void process(Aggregations aggs) { aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME); if (byActualClassAgg.getSumOfOtherDocCounts() > 0) { - // This means there were more than {@code maxClassesCardinality} buckets. + // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets. // We cannot calculate average recall accurately, so we fail. throw ExceptionsHelper.badRequestException( "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField.get());