Skip to content

Commit c82d9c5

Browse files
authored
[ML] Adds support for regression.mean_squared_error to eval API (#44140) (#44218)
* [ML] Adds support for regression.mean_squared_error to eval API * addressing PR comments * fixing tests
1 parent 1636701 commit c82d9c5

File tree

17 files changed

+1069
-20
lines changed

17 files changed

+1069
-20
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
*/
1919
package org.elasticsearch.client.ml.dataframe.evaluation;
2020

21+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
22+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
2123
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
2224
import org.elasticsearch.common.ParseField;
2325
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -38,19 +40,24 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
3840
// Evaluations
3941
new NamedXContentRegistry.Entry(
4042
Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent),
43+
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
4144
// Evaluation metrics
4245
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
4346
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
4447
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
4548
new NamedXContentRegistry.Entry(
4649
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
50+
new NamedXContentRegistry.Entry(
51+
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
4752
// Evaluation metrics results
4853
new NamedXContentRegistry.Entry(
4954
EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
5055
new NamedXContentRegistry.Entry(
5156
EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
5257
new NamedXContentRegistry.Entry(
5358
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
59+
new NamedXContentRegistry.Entry(
60+
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
5461
new NamedXContentRegistry.Entry(
5562
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent));
5663
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
20+
21+
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
22+
import org.elasticsearch.common.ParseField;
23+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
24+
import org.elasticsearch.common.xcontent.ObjectParser;
25+
import org.elasticsearch.common.xcontent.ToXContent;
26+
import org.elasticsearch.common.xcontent.XContentBuilder;
27+
import org.elasticsearch.common.xcontent.XContentParser;
28+
29+
import java.io.IOException;
30+
import java.util.Objects;
31+
32+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
33+
34+
/**
35+
* Calculates the mean squared error between two known numerical fields.
36+
*
37+
* equation: mse = 1/n * Σ(y - y´)^2
38+
*/
39+
public class MeanSquaredErrorMetric implements EvaluationMetric {
40+
41+
public static final String NAME = "mean_squared_error";
42+
43+
private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER =
44+
new ObjectParser<>("mean_squared_error", true, MeanSquaredErrorMetric::new);
45+
46+
public static MeanSquaredErrorMetric fromXContent(XContentParser parser) {
47+
return PARSER.apply(parser, null);
48+
}
49+
50+
public MeanSquaredErrorMetric() {
51+
52+
}
53+
54+
@Override
55+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
56+
builder.startObject();
57+
builder.endObject();
58+
return builder;
59+
}
60+
61+
@Override
62+
public boolean equals(Object o) {
63+
if (this == o) return true;
64+
if (o == null || getClass() != o.getClass()) return false;
65+
return true;
66+
}
67+
68+
@Override
69+
public int hashCode() {
70+
// create static hash code from name as there are currently no unique fields per class instance
71+
return Objects.hashCode(NAME);
72+
}
73+
74+
@Override
75+
public String getName() {
76+
return NAME;
77+
}
78+
79+
public static class Result implements EvaluationMetric.Result {
80+
81+
public static final ParseField ERROR = new ParseField("error");
82+
private final double error;
83+
84+
public static Result fromXContent(XContentParser parser) {
85+
return PARSER.apply(parser, null);
86+
}
87+
88+
private static final ConstructingObjectParser<Result, Void> PARSER =
89+
new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));
90+
91+
static {
92+
PARSER.declareDouble(constructorArg(), ERROR);
93+
}
94+
95+
public Result(double error) {
96+
this.error = error;
97+
}
98+
99+
@Override
100+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
101+
builder.startObject();
102+
builder.field(ERROR.getPreferredName(), error);
103+
builder.endObject();
104+
return builder;
105+
}
106+
107+
public double getError() {
108+
return error;
109+
}
110+
111+
@Override
112+
public String getMetricName() {
113+
return NAME;
114+
}
115+
116+
@Override
117+
public boolean equals(Object o) {
118+
if (this == o) return true;
119+
if (o == null || getClass() != o.getClass()) return false;
120+
Result that = (Result) o;
121+
return Objects.equals(that.error, this.error);
122+
}
123+
124+
@Override
125+
public int hashCode() {
126+
return Objects.hash(error);
127+
}
128+
}
129+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
20+
21+
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
22+
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
23+
import org.elasticsearch.common.Nullable;
24+
import org.elasticsearch.common.ParseField;
25+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
26+
import org.elasticsearch.common.xcontent.ToXContent;
27+
import org.elasticsearch.common.xcontent.XContentBuilder;
28+
import org.elasticsearch.common.xcontent.XContentParser;
29+
30+
import java.io.IOException;
31+
import java.util.Arrays;
32+
import java.util.List;
33+
import java.util.Objects;
34+
35+
/**
36+
* Evaluation of regression results.
37+
*/
38+
public class Regression implements Evaluation {
39+
40+
public static final String NAME = "regression";
41+
42+
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
43+
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
44+
private static final ParseField METRICS = new ParseField("metrics");
45+
46+
@SuppressWarnings("unchecked")
47+
public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
48+
NAME, true, a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
49+
50+
static {
51+
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
52+
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
53+
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
54+
(p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
55+
}
56+
57+
public static Regression fromXContent(XContentParser parser) {
58+
return PARSER.apply(parser, null);
59+
}
60+
61+
/**
62+
* The field containing the actual value
63+
* The value of this field is assumed to be numeric
64+
*/
65+
private final String actualField;
66+
67+
/**
68+
* The field containing the predicted value
69+
* The value of this field is assumed to be numeric
70+
*/
71+
private final String predictedField;
72+
73+
/**
74+
* The list of metrics to calculate
75+
*/
76+
private final List<EvaluationMetric> metrics;
77+
78+
public Regression(String actualField, String predictedField) {
79+
this(actualField, predictedField, (List<EvaluationMetric>)null);
80+
}
81+
82+
public Regression(String actualField, String predictedField, EvaluationMetric... metrics) {
83+
this(actualField, predictedField, Arrays.asList(metrics));
84+
}
85+
86+
public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
87+
this.actualField = actualField;
88+
this.predictedField = predictedField;
89+
this.metrics = metrics;
90+
}
91+
92+
@Override
93+
public String getName() {
94+
return NAME;
95+
}
96+
97+
@Override
98+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
99+
builder.startObject();
100+
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
101+
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
102+
103+
if (metrics != null) {
104+
builder.startObject(METRICS.getPreferredName());
105+
for (EvaluationMetric metric : metrics) {
106+
builder.field(metric.getName(), metric);
107+
}
108+
builder.endObject();
109+
}
110+
111+
builder.endObject();
112+
return builder;
113+
}
114+
115+
@Override
116+
public boolean equals(Object o) {
117+
if (this == o) return true;
118+
if (o == null || getClass() != o.getClass()) return false;
119+
Regression that = (Regression) o;
120+
return Objects.equals(that.actualField, this.actualField)
121+
&& Objects.equals(that.predictedField, this.predictedField)
122+
&& Objects.equals(that.metrics, this.metrics);
123+
}
124+
125+
@Override
126+
public int hashCode() {
127+
return Objects.hash(actualField, predictedField, metrics);
128+
}
129+
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@
123123
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
124124
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
125125
import org.elasticsearch.client.ml.dataframe.QueryConfig;
126+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
127+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
126128
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
127129
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
128130
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
@@ -1578,6 +1580,33 @@ public void testEvaluateDataFrame() throws IOException {
15781580
assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0));
15791581
assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0));
15801582
assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0));
1583+
1584+
String regressionIndex = "evaluate-regression-test-index";
1585+
createIndex(regressionIndex, mappingForRegression());
1586+
BulkRequest regressionBulk = new BulkRequest()
1587+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
1588+
.add(docForRegression(regressionIndex, 0.3, 0.1)) // #0
1589+
.add(docForRegression(regressionIndex, 0.3, 0.2)) // #1
1590+
.add(docForRegression(regressionIndex, 0.3, 0.3)) // #2
1591+
.add(docForRegression(regressionIndex, 0.3, 0.4)) // #3
1592+
.add(docForRegression(regressionIndex, 0.3, 0.7)) // #4
1593+
.add(docForRegression(regressionIndex, 0.5, 0.2)) // #5
1594+
.add(docForRegression(regressionIndex, 0.5, 0.3)) // #6
1595+
.add(docForRegression(regressionIndex, 0.5, 0.4)) // #7
1596+
.add(docForRegression(regressionIndex, 0.5, 0.8)) // #8
1597+
.add(docForRegression(regressionIndex, 0.5, 0.9)); // #9
1598+
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
1599+
1600+
evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, new Regression(actualRegression, probabilityRegression));
1601+
1602+
evaluateDataFrameResponse =
1603+
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
1604+
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
1605+
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
1606+
1607+
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
1608+
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
1609+
assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9));
15811610
}
15821611

15831612
private static XContentBuilder defaultMappingForTest() throws IOException {
@@ -1615,6 +1644,28 @@ private static IndexRequest docForClassification(String indexName, boolean isTru
16151644
.source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p);
16161645
}
16171646

1647+
private static final String actualRegression = "regression_actual";
1648+
private static final String probabilityRegression = "regression_prob";
1649+
1650+
private static XContentBuilder mappingForRegression() throws IOException {
1651+
return XContentFactory.jsonBuilder().startObject()
1652+
.startObject("properties")
1653+
.startObject(actualRegression)
1654+
.field("type", "double")
1655+
.endObject()
1656+
.startObject(probabilityRegression)
1657+
.field("type", "double")
1658+
.endObject()
1659+
.endObject()
1660+
.endObject();
1661+
}
1662+
1663+
private static IndexRequest docForRegression(String indexName, double act, double p) {
1664+
return new IndexRequest()
1665+
.index(indexName)
1666+
.source(XContentType.JSON, actualRegression, act, probabilityRegression, p);
1667+
}
1668+
16181669
private void createIndex(String indexName, XContentBuilder mapping) throws IOException {
16191670
highLevelClient().indices().create(new CreateIndexRequest(indexName).mapping(mapping), RequestOptions.DEFAULT);
16201671
}

0 commit comments

Comments
 (0)