Skip to content

Commit 367a135

Browse files
[7.x][ML] Extend default evaluation metrics to all available (#63939) (#63965)
This commit extends the set of default metrics for the data frame analytics evaluation API to all available metrics. The motivation is that if the user skips setting an explicit set of metrics, they get most of the evaluation offering. Backport of #63939
1 parent b6db094 commit 367a135

File tree

9 files changed

+56
-5
lines changed

9 files changed

+56
-5
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public Classification(String actualField,
9393
}
9494

9595
private static List<EvaluationMetric> defaultMetrics() {
96-
return Arrays.asList(new MulticlassConfusionMatrix());
96+
return Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall());
9797
}
9898

9999
public Classification(StreamInput in) throws IOException {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ public Huber(StreamInput in) throws IOException {
8080
this.delta = in.readDouble();
8181
}
8282

83+
public Huber() {
84+
this(DEFAULT_DELTA);
85+
}
86+
8387
public Huber(@Nullable Double delta) {
8488
this.delta = delta != null ? delta : DEFAULT_DELTA;
8589
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public Regression(String actualField, String predictedField, @Nullable List<Eval
7676
}
7777

7878
private static List<EvaluationMetric> defaultMetrics() {
79-
return Arrays.asList(new MeanSquaredError(), new RSquared());
79+
return Arrays.asList(new MeanSquaredError(), new RSquared(), new Huber());
8080
}
8181

8282
public Regression(StreamInput in) throws IOException {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
4242
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
43+
import static org.hamcrest.Matchers.containsInAnyOrder;
4344
import static org.hamcrest.Matchers.equalTo;
4445
import static org.hamcrest.Matchers.greaterThan;
4546
import static org.hamcrest.Matchers.is;
@@ -110,6 +111,14 @@ public void testConstructor_GivenEmptyMetrics() {
110111
assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics"));
111112
}
112113

114+
public void testConstructor_GivenDefaultMetrics() {
115+
Classification classification = new Classification("actual", "predicted", null, null);
116+
117+
List<EvaluationMetric> metrics = classification.getMetrics();
118+
119+
assertThat(metrics, containsInAnyOrder(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()));
120+
}
121+
113122
public void testGetFields() {
114123
Classification evaluation = new Classification("foo", "bar", "results", null);
115124
EvaluationFields fields = evaluation.getFields();

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Collections;
2626
import java.util.List;
2727

28+
import static org.hamcrest.Matchers.containsInAnyOrder;
2829
import static org.hamcrest.Matchers.equalTo;
2930
import static org.hamcrest.Matchers.greaterThan;
3031
import static org.hamcrest.Matchers.is;
@@ -89,6 +90,17 @@ public void testConstructor_GivenEmptyMetrics() {
8990
assertThat(e.getMessage(), equalTo("[outlier_detection] must have one or more metrics"));
9091
}
9192

93+
public void testConstructor_GivenDefaultMetrics() {
94+
OutlierDetection outlierDetection = new OutlierDetection("actual", "predicted", null);
95+
96+
List<EvaluationMetric> metrics = outlierDetection.getMetrics();
97+
98+
assertThat(metrics, containsInAnyOrder(new AucRoc(false),
99+
new Precision(Arrays.asList(0.25, 0.5, 0.75)),
100+
new Recall(Arrays.asList(0.25, 0.5, 0.75)),
101+
new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75))));
102+
}
103+
92104
public void testGetFields() {
93105
OutlierDetection evaluation = new OutlierDetection("foo", "bar", null);
94106
EvaluationFields fields = evaluation.getFields();

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Collections;
2626
import java.util.List;
2727

28+
import static org.hamcrest.Matchers.containsInAnyOrder;
2829
import static org.hamcrest.Matchers.equalTo;
2930
import static org.hamcrest.Matchers.greaterThan;
3031
import static org.hamcrest.Matchers.is;
@@ -76,6 +77,14 @@ public void testConstructor_GivenEmptyMetrics() {
7677
assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics"));
7778
}
7879

80+
public void testConstructor_GivenDefaultMetrics() {
81+
Regression regression = new Regression("actual", "predicted", null);
82+
83+
List<EvaluationMetric> metrics = regression.getMetrics();
84+
85+
assertThat(metrics, containsInAnyOrder(new Huber(), new MeanSquaredError(), new RSquared()));
86+
}
87+
7988
public void testGetFields() {
8089
Regression evaluation = new Regression("foo", "bar", null);
8190
EvaluationFields fields = evaluation.getFields();

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import static java.util.stream.Collectors.toList;
3434
import static org.hamcrest.Matchers.closeTo;
3535
import static org.hamcrest.Matchers.contains;
36+
import static org.hamcrest.Matchers.containsInAnyOrder;
3637
import static org.hamcrest.Matchers.containsString;
3738
import static org.hamcrest.Matchers.empty;
3839
import static org.hamcrest.Matchers.equalTo;
@@ -82,7 +83,13 @@ public void testEvaluate_DefaultMetrics() {
8283
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
8384
assertThat(
8485
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
85-
contains(MulticlassConfusionMatrix.NAME.getPreferredName()));
86+
containsInAnyOrder(
87+
MulticlassConfusionMatrix.NAME.getPreferredName(),
88+
Accuracy.NAME.getPreferredName(),
89+
Precision.NAME.getPreferredName(),
90+
Recall.NAME.getPreferredName()
91+
)
92+
);
8693
}
8794

8895
public void testEvaluate_AllMetrics() {

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import static java.util.stream.Collectors.toList;
2626
import static org.hamcrest.Matchers.closeTo;
2727
import static org.hamcrest.Matchers.contains;
28+
import static org.hamcrest.Matchers.containsInAnyOrder;
2829
import static org.hamcrest.Matchers.equalTo;
2930
import static org.hamcrest.Matchers.hasSize;
3031

@@ -53,7 +54,12 @@ public void testEvaluate_DefaultMetrics() {
5354
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
5455
assertThat(
5556
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
56-
contains(MeanSquaredError.NAME.getPreferredName(), RSquared.NAME.getPreferredName()));
57+
containsInAnyOrder(
58+
MeanSquaredError.NAME.getPreferredName(),
59+
RSquared.NAME.getPreferredName(),
60+
Huber.NAME.getPreferredName()
61+
)
62+
);
5763
}
5864

5965
public void testEvaluate_AllMetrics() {

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,10 @@ setup:
938938
}
939939
940940
- is_true: classification.multiclass_confusion_matrix
941+
- is_true: classification.accuracy
942+
- is_true: classification.precision
943+
- is_true: classification.recall
944+
- is_false: classification.auc_roc
941945
---
942946
"Test classification given missing actual_field":
943947
- do:
@@ -1104,8 +1108,8 @@ setup:
11041108
11051109
- match: { regression.mse.value: 28.67749840974834 }
11061110
- match: { regression.r_squared.value: 0.8551031778603486 }
1111+
- match: { regression.huber.value: 1.9205280586939963 }
11071112
- is_false: regression.msle.value
1108-
- is_false: regression.huber.value
11091113
---
11101114
"Test regression given missing actual_field":
11111115
- do:

0 commit comments

Comments
 (0)