diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index cd5c2abdf5627..0a37182aff922 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -22,6 +22,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -97,6 +98,10 @@ Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClass EvaluationMetric.class, new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)), MeanSquaredErrorMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, + new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)), + MeanSquaredLogarithmicErrorMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)), @@ -140,6 +145,10 @@ Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClass EvaluationMetric.Result.class, new ParseField(registeredMetricName(Regression.NAME, 
MeanSquaredErrorMetric.NAME)), MeanSquaredErrorMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, + new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)), + MeanSquaredLogarithmicErrorMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)), diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java index 5b961dacbcc52..0b795bf7422f7 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java @@ -40,16 +40,13 @@ public class MeanSquaredErrorMetric implements EvaluationMetric { public static final String NAME = "mean_squared_error"; - private static final ObjectParser PARSER = - new ObjectParser<>("mean_squared_error", true, MeanSquaredErrorMetric::new); + private static final ObjectParser PARSER = new ObjectParser<>(NAME, true, MeanSquaredErrorMetric::new); public static MeanSquaredErrorMetric fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - public MeanSquaredErrorMetric() { - - } + public MeanSquaredErrorMetric() {} @Override public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetric.java new file mode 100644 
index 0000000000000..66a0666c8ecce --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetric.java @@ -0,0 +1,142 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * Calculates the mean squared logarithmic error between two known numerical fields. 
+ * + * equation: msle = 1/n * Σ(log(y + offset) - log(y´ + offset))^2 + * where offset is used to make sure the argument to log function is always positive + */ +public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric { + + public static final String NAME = "mean_squared_logarithmic_error"; + + public static final ParseField OFFSET = new ParseField("offset"); + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, true, args -> new MeanSquaredLogarithmicErrorMetric((Double) args[0])); + + static { + PARSER.declareDouble(optionalConstructorArg(), OFFSET); + } + + public static MeanSquaredLogarithmicErrorMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final Double offset; + + public MeanSquaredLogarithmicErrorMetric(@Nullable Double offset) { + this.offset = offset; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (offset != null) { + builder.field(OFFSET.getPreferredName(), offset); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MeanSquaredLogarithmicErrorMetric that = (MeanSquaredLogarithmicErrorMetric) o; + return Objects.equals(this.offset, that.offset); + } + + @Override + public int hashCode() { + return Objects.hash(offset); + } + + public static class Result implements EvaluationMetric.Result { + + public static final ParseField ERROR = new ParseField("error"); + private final double error; + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new 
Result((double) args[0])); + + static { + PARSER.declareDouble(constructorArg(), ERROR); + } + + public Result(double error) { + this.error = error; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(ERROR.getPreferredName(), error); + builder.endObject(); + return builder; + } + + public double getError() { + return error; + } + + @Override + public String getMetricName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(that.error, this.error); + } + + @Override + public int hashCode() { + return Objects.hash(error); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetric.java index 968489a30389f..a38980c3e49bb 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetric.java @@ -42,16 +42,13 @@ public class RSquaredMetric implements EvaluationMetric { public static final String NAME = "r_squared"; - private static final ObjectParser PARSER = - new ObjectParser<>("r_squared", true, RSquaredMetric::new); + private static final ObjectParser PARSER = new ObjectParser<>(NAME, true, RSquaredMetric::new); public static RSquaredMetric fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - public RSquaredMetric() { - - } + public RSquaredMetric() {} @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 76363a13b4782..e58a35680dced 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -142,6 +142,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -1852,17 +1853,25 @@ public void testEvaluateDataFrame_Regression() throws IOException { new EvaluateDataFrameRequest( regressionIndex, null, - new Regression(actualRegression, predictedRegression, new MeanSquaredErrorMetric(), new RSquaredMetric())); + new Regression( + actualRegression, + predictedRegression, + new MeanSquaredErrorMetric(), new MeanSquaredLogarithmicErrorMetric(1.0), new RSquaredMetric())); EvaluateDataFrameResponse evaluateDataFrameResponse = execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME)); - assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(3)); MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME); 
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME)); assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9)); + MeanSquaredLogarithmicErrorMetric.Result msleResult = + evaluateDataFrameResponse.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); + assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME)); + assertThat(msleResult.getError(), closeTo(0.02759231770210426, 1e-9)); + RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME); assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME)); assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 4c3bb47d38617..8cadfdc56ea5e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -61,6 +61,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -701,7 +702,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(64, 
namedXContents.size()); + assertEquals(66, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -748,7 +749,7 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(TimeSyncConfig.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME)); - assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertThat(names, hasItems( registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME), @@ -762,8 +763,9 @@ public void testProvidedNamedXContents() { Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME), registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), + registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); - assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertThat(names, hasItems( registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME), @@ -777,6 +779,7 @@ public void testProvidedNamedXContents() { Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME), registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), 
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), + registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index a5fee222587ed..a1472e2e5ab91 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -161,6 +161,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; @@ -3570,7 +3571,8 @@ public void testEvaluateDataFrame_Regression() throws Exception { "predicted_value", // <3> // Evaluation metrics // <4> new MeanSquaredErrorMetric(), // <5> - new RSquaredMetric()); // <6> + new MeanSquaredLogarithmicErrorMetric(1.0), // <6> + new RSquaredMetric()); // <7> // 
end::evaluate-data-frame-evaluation-regression EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation); @@ -3580,11 +3582,16 @@ public void testEvaluateDataFrame_Regression() throws Exception { MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1> double meanSquaredError = meanSquaredErrorResult.getError(); // <2> - RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <3> - double rSquared = rSquaredResult.getValue(); // <4> + MeanSquaredLogarithmicErrorMetric.Result meanSquaredLogarithmicErrorResult = + response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3> + double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getError(); // <4> + + RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <5> + double rSquared = rSquaredResult.getValue(); // <6> // end::evaluate-data-frame-results-regression assertThat(meanSquaredError, closeTo(0.021, 1e-3)); + assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3)); assertThat(rSquared, closeTo(0.941, 1e-3)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetricResultTests.java new file mode 100644 index 0000000000000..4e96337503dbc --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetricResultTests.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class MeanSquaredLogarithmicErrorMetricResultTests extends AbstractXContentTestCase { + + public static MeanSquaredLogarithmicErrorMetric.Result randomResult() { + return new MeanSquaredLogarithmicErrorMetric.Result(randomDouble()); + } + + @Override + protected MeanSquaredLogarithmicErrorMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected MeanSquaredLogarithmicErrorMetric.Result doParseInstance(XContentParser parser) throws IOException { + return MeanSquaredLogarithmicErrorMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetricTests.java 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetricTests.java new file mode 100644 index 0000000000000..e51e8b7f165a3 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetricTests.java @@ -0,0 +1,49 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class MeanSquaredLogarithmicErrorMetricTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected MeanSquaredLogarithmicErrorMetric createTestInstance() { + return new MeanSquaredLogarithmicErrorMetric(randomBoolean() ? 
randomDouble() : null); + } + + @Override + protected MeanSquaredLogarithmicErrorMetric doParseInstance(XContentParser parser) throws IOException { + return MeanSquaredLogarithmicErrorMetric.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java index 5d2a614663d31..a4b862f0e2e0f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java @@ -41,6 +41,9 @@ public static Regression createRandom() { if (randomBoolean()) { metrics.add(new MeanSquaredErrorMetric()); } + if (randomBoolean()) { + metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance()); + } if (randomBoolean()) { metrics.add(new RSquaredMetric()); } diff --git a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc index 57a82d1c7132f..d1eb6ffadc9c5 100644 --- a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc +++ b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc @@ -68,7 +68,8 @@ include-tagged::{doc-tests-file}[{api}-evaluation-regression] <3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) value for the example. 
<4> The remaining parameters are the metrics to be calculated based on the two fields described above <5> https://en.wikipedia.org/wiki/Mean_squared_error[Mean squared error] -<6> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared] +<6> Mean squared logarithmic error +<7> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared] include::../execution.asciidoc[] @@ -123,5 +124,7 @@ include-tagged::{doc-tests-file}[{api}-results-regression] <1> Fetching mean squared error metric by name <2> Fetching the actual mean squared error value -<3> Fetching R squared metric by name -<4> Fetching the actual R squared value +<3> Fetching mean squared logarithmic error metric by name +<4> Fetching the actual mean squared logarithmic error value +<5> Fetching R squared metric by name +<6> Fetching the actual R squared value diff --git a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc index ad4f0467750fd..9e1ceffec137f 100644 --- a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc @@ -130,6 +130,10 @@ which outputs a prediction of values. (Optional, object) Average squared difference between the predicted values and the actual (`ground truth`) value. For more information, read https://en.wikipedia.org/wiki/Mean_squared_error[this wiki article]. + `mean_squared_logarithmic_error`::: + (Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual + (`ground truth`) value. + `r_squared`::: (Optional, object) Proportion of the variance in the dependent variable that is predictable from the independent variables. For more information, read https://en.wikipedia.org/wiki/Coefficient_of_determination[this wiki article]. 
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index 42e530a7a602d..315569c8cdb4c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; @@ -95,6 +96,9 @@ public List getNamedXContentParsers() { new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(registeredMetricName(Regression.NAME, MeanSquaredError.NAME)), MeanSquaredError::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)), + MeanSquaredLogarithmicError::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)), RSquared::fromXContent) @@ -144,6 +148,9 @@ public static List getNamedWriteables() { new NamedWriteableRegistry.Entry(EvaluationMetric.class, registeredMetricName(Regression.NAME, MeanSquaredError.NAME), MeanSquaredError::new), + 
new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME), + MeanSquaredLogarithmicError::new), new NamedWriteableRegistry.Entry(EvaluationMetric.class, registeredMetricName(Regression.NAME, RSquared.NAME), RSquared::new), @@ -175,6 +182,9 @@ public static List getNamedWriteables() { new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, registeredMetricName(Regression.NAME, MeanSquaredError.NAME), MeanSquaredError.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME), + MeanSquaredLogarithmicError.Result::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, registeredMetricName(Regression.NAME, RSquared.NAME), RSquared.Result::new) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java index 982c68dfb58af..1ad251039c44d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -40,7 +40,9 @@ public class MeanSquaredError implements EvaluationMetric { public static final ParseField NAME = new ParseField("mean_squared_error"); - private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;"; + private static final String PAINLESS_TEMPLATE = + "def diff = doc[''{0}''].value - doc[''{1}''].value;" + + "return diff * diff;"; private static final String AGG_NAME = "regression_" + NAME.getPreferredName(); private static String buildScript(Object...args) { @@ -141,6 +143,10 @@ public String getMetricName() { return 
NAME.getPreferredName(); } + public double getError() { + return error; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeDouble(error); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java new file mode 100644 index 0000000000000..5a0a796c7c70a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java @@ -0,0 +1,195 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.script.Script; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import 
org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; + +import java.io.IOException; +import java.text.MessageFormat; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + +/** + * Calculates the mean squared logarithmic error between two known numerical fields. + * + * equation: msle = 1/n * Σ(log(y + offset) - log(y´ + offset))^2 + * where offset is used to make sure the argument to log function is always positive + */ +public class MeanSquaredLogarithmicError implements EvaluationMetric { + + public static final ParseField NAME = new ParseField("mean_squared_logarithmic_error"); + + public static final ParseField OFFSET = new ParseField("offset"); + private static final double DEFAULT_OFFSET = 1.0; + + private static final String PAINLESS_TEMPLATE = + "def offset = {2};" + + "def diff = Math.log(doc[''{0}''].value + offset) - Math.log(doc[''{1}''].value + offset);" + + "return diff * diff;"; + private static final String AGG_NAME = "regression_" + NAME.getPreferredName(); + + /** + * Fills {@link #PAINLESS_TEMPLATE} with the given arguments. + * Every argument is converted with {@link String#valueOf(Object)} first: {@link MessageFormat} would otherwise render + * numbers through a locale number format (grouping separators, at most three fraction digits), turning an offset + * of 1234.5678 into "1,234.568", which is not a valid Painless literal. + */ + private static String buildScript(Object...args) { + return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(Arrays.stream(args).map(String::valueOf).toArray()); + } + + private static final ConstructingObjectParser<MeanSquaredLogarithmicError, Void> PARSER = + new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MeanSquaredLogarithmicError((Double) args[0])); + + static { + PARSER.declareDouble(optionalConstructorArg(), OFFSET); + } + + public static MeanSquaredLogarithmicError fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final double offset; + private EvaluationMetricResult result; + + public MeanSquaredLogarithmicError(StreamInput in) throws IOException { +
this.offset = in.readDouble(); + } + + /** + * @param offset value added to the actual and predicted values before the logarithm is taken; + * {@code null} selects the default offset of 1.0 + */ + public MeanSquaredLogarithmicError(@Nullable Double offset) { + this.offset = offset != null ? offset : DEFAULT_OFFSET; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + /** + * Returns a single avg aggregation over the per-document squared logarithmic difference, + * or no aggregations at all once the result has already been computed. + */ + @Override + public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters, + String actualField, + String predictedField) { + if (result != null) { + return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); + } + return Tuple.tuple( + Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, offset)))), + Collections.emptyList()); + } + + /** + * Reads the averaged value from the aggregation registered by {@code aggs}; + * a missing aggregation is reported as an error of 0.0. + */ + @Override + public void process(Aggregations aggs) { + NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME); + result = value == null ? new Result(0.0) : new Result(value.value()); + } + + @Override + public Optional<EvaluationMetricResult> getResult() { + return Optional.ofNullable(result); + } + + @Override + public String getWriteableName() { + return registeredMetricName(Regression.NAME, NAME); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(offset); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(OFFSET.getPreferredName(), offset); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MeanSquaredLogarithmicError that = (MeanSquaredLogarithmicError) o; + // Double.compare agrees with Double.hashCode for 0.0/-0.0 and NaN, unlike the == operator + return Double.compare(this.offset, that.offset) == 0; + } + + @Override + public int hashCode() { + return Double.hashCode(offset); + } + + public static class Result implements EvaluationMetricResult { + + private static final String ERROR = "error"; + private final double error; + + public Result(double error) { + this.error = error; + } + + public Result(StreamInput in) throws IOException { + this.error =
in.readDouble(); + } + + @Override + public String getWriteableName() { + return registeredMetricName(Regression.NAME, NAME); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + public double getError() { + return error; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(error); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ERROR, error); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result other = (Result)o; + // Double.compare agrees with the boxed hash below for 0.0/-0.0 and treats NaN results as equal + return Double.compare(error, other.error) == 0; + } + + @Override + public int hashCode() { + return Objects.hashCode(error); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java index ac142631ab968..0ef32f2c04e7a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java @@ -45,7 +45,9 @@ public class RSquared implements EvaluationMetric { public static final ParseField NAME = new ParseField("r_squared"); - private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;"; + private static final String PAINLESS_TEMPLATE = + "def diff = doc[''{0}''].value - doc[''{1}''].value;" + + "return diff * diff;"; private static final String SS_RES = "residual_sum_of_squares"; private static String buildScript(Object...
args) { @@ -156,6 +158,10 @@ public String getMetricName() { return NAME.getPreferredName(); } + public double getValue() { + return value; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeDouble(value); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java index 437734ac763be..7c8860770236e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PrecisionResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import java.util.List; @@ -37,6 +38,7 @@ protected Response createTestInstance() { RecallResultTests.createRandom(), MulticlassConfusionMatrixResultTests.createRandom(), new MeanSquaredError.Result(randomDouble()), + new MeanSquaredLogarithmicError.Result(randomDouble()), new RSquared.Result(randomDouble())); return new Response(evaluationName, randomSubsetOf(metrics)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorTests.java new file mode 100644 index 0000000000000..a45bb72fd3ced --- /dev/null +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; +import static org.hamcrest.Matchers.equalTo; + +public class MeanSquaredLogarithmicErrorTests extends AbstractSerializingTestCase<MeanSquaredLogarithmicError> { + + @Override + protected MeanSquaredLogarithmicError doParseInstance(XContentParser parser) throws IOException { + return MeanSquaredLogarithmicError.fromXContent(parser); + } + + @Override + protected MeanSquaredLogarithmicError createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader<MeanSquaredLogarithmicError> instanceReader() { + return MeanSquaredLogarithmicError::new; + } + + public static MeanSquaredLogarithmicError createRandom() { + return new MeanSquaredLogarithmicError(randomBoolean() ?
randomDoubleBetween(0.0, 1000.0, false) : null); + } + + public void testEvaluate() { + Aggregations aggs = new Aggregations(Arrays.asList( + mockSingleValue("regression_mean_squared_logarithmic_error", 0.8123), + mockSingleValue("some_other_single_metric_agg", 0.2377) + )); + + MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError((Double) null); + msle.process(aggs); + + EvaluationMetricResult result = msle.getResult().get(); + String expected = "{\"error\":0.8123}"; + assertThat(Strings.toString(result), equalTo(expected)); + } + + public void testEvaluate_GivenMissingAggs() { + Aggregations aggs = new Aggregations(Collections.singletonList( + mockSingleValue("some_other_single_metric_agg", 0.2377) + )); + + MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError((Double) null); + msle.process(aggs); + + EvaluationMetricResult result = msle.getResult().get(); + assertThat(result, equalTo(new MeanSquaredLogarithmicError.Result(0.0))); + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java new file mode 100644 index 0000000000000..21f869ee9527f --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java @@ -0,0 +1,155 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; +import org.junit.After; +import org.junit.Before; + +import java.util.List; + +import static java.util.stream.Collectors.toList; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { + + private static final String HOUSES_DATA_INDEX = "test-evaluate-houses-index"; + + private static final String PRICE_FIELD = 
"price"; + private static final String PRICE_PREDICTION_FIELD = "price_prediction"; + + @Before + public void setup() { + createHousesIndex(HOUSES_DATA_INDEX); + indexHousesData(HOUSES_DATA_INDEX); + } + + @After + public void cleanup() { + cleanUp(); + } + + public void testEvaluate_DefaultMetrics() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame(HOUSES_DATA_INDEX, new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, null)); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName())); + assertThat( + evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), + contains(MeanSquaredError.NAME.getPreferredName(), RSquared.NAME.getPreferredName())); + } + + public void testEvaluate_AllMetrics() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + HOUSES_DATA_INDEX, + new Regression( + PRICE_FIELD, + PRICE_PREDICTION_FIELD, + List.of(new MeanSquaredError(), new MeanSquaredLogarithmicError((Double) null), new RSquared()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName())); + assertThat( + evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), + contains( + MeanSquaredError.NAME.getPreferredName(), + MeanSquaredLogarithmicError.NAME.getPreferredName(), + RSquared.NAME.getPreferredName())); + } + + public void testEvaluate_MeanSquaredError() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame(HOUSES_DATA_INDEX, new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, List.of(new MeanSquaredError()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + MeanSquaredError.Result mseResult = (MeanSquaredError.Result) 
evaluateDataFrameResponse.getMetrics().get(0); + assertThat(mseResult.getMetricName(), equalTo(MeanSquaredError.NAME.getPreferredName())); + assertThat(mseResult.getError(), equalTo(1000000.0)); + } + + public void testEvaluate_MeanSquaredLogarithmicError() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + HOUSES_DATA_INDEX, + new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, List.of(new MeanSquaredLogarithmicError((Double) null)))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + MeanSquaredLogarithmicError.Result msleResult = (MeanSquaredLogarithmicError.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicError.NAME.getPreferredName())); + assertThat(msleResult.getError(), closeTo(Math.pow(Math.log(1001), 2), 10E-6)); + } + + public void testEvaluate_RSquared() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame(HOUSES_DATA_INDEX, new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, List.of(new RSquared()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + RSquared.Result rSquaredResult = (RSquared.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(rSquaredResult.getMetricName(), equalTo(RSquared.NAME.getPreferredName())); + assertThat(rSquaredResult.getValue(), equalTo(0.0)); + } + + private static void createHousesIndex(String indexName) { + client().admin().indices().prepareCreate(indexName) + .setMapping( + PRICE_FIELD, "type=double", + PRICE_PREDICTION_FIELD, "type=double") + .get(); + } + + private static void indexHousesData(String indexName) { + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + 
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < 100; i++) { + bulkRequestBuilder.add( + new IndexRequest(indexName) + .source( + PRICE_FIELD, 1000, + PRICE_PREDICTION_FIELD, 0)); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 2b16b79ac84b4..5ed85f3f6a287 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -847,6 +847,26 @@ setup: } - match: { regression.mean_squared_error.error: 28.67749840974834 } + - is_false: regression.mean_squared_logarithmic_error + - is_false: regression.r_squared.value +--- +"Test regression mean_squared_logarithmic_error": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "regression": { + "actual_field": "regression_field_act", + "predicted_field": "regression_field_pred", + "metrics": { "mean_squared_logarithmic_error": { "offset": 6.0 } } + } + } + } + + - match: { regression.mean_squared_logarithmic_error.error: 0.08680568028334916 } + - is_false: regression.mean_squared_error - is_false: regression.r_squared.value --- "Test regression r_squared": @@ -865,6 +885,7 @@ setup: } - match: { regression.r_squared.value: 0.8551031778603486 } - is_false: regression.mean_squared_error + - is_false: regression.mean_squared_logarithmic_error --- "Test regression with null metrics": - do: @@ -882,6 +903,7 @@ setup: - match: { regression.mean_squared_error.error: 28.67749840974834 } - match: { regression.r_squared.value: 0.8551031778603486 } + - is_false: regression.mean_squared_logarithmic_error --- "Test regression given
missing actual_field": - do: