diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 088e90c78a..855d045fcc 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -47,6 +47,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[ === Bug Fixes * Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].) +* Prevent prediction_field_name clashing with other fields in ml results. +(See {ml-pull}861[#861].) == {es} version 7.5.0 diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index b6c619e9bd..f4609d595b 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -68,6 +68,13 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier this->dependentVariableFieldName()) == categoricalFieldNames.end()) { HANDLE_FATAL(<< "Input error: trying to perform classification with numeric target."); } + const std::set predictionFieldNameBlacklist{ + IS_TRAINING_FIELD_NAME, PREDICTION_PROBABILITY_FIELD_NAME, TOP_CLASSES_FIELD_NAME}; + if (predictionFieldNameBlacklist.count(this->predictionFieldName()) > 0) { + HANDLE_FATAL(<< "Input error: prediction_field_name must not be equal to any of " + << core::CContainerPrinter::print(predictionFieldNameBlacklist) + << "."); + } } CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifierRunner( diff --git a/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc b/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc index e9fd1c0ae6..ab053b11a5 100644 --- a/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc @@ -52,6 +52,12 @@ CDataFrameTrainBoostedTreeRegressionRunner::CDataFrameTrainBoostedTreeRegression this->dependentVariableFieldName()) != categoricalFieldNames.end()) { HANDLE_FATAL(<< "Input error: trying to perform regression with categorical target."); } + const std::set predictionFieldNameBlacklist{IS_TRAINING_FIELD_NAME}; + if (predictionFieldNameBlacklist.count(this->predictionFieldName()) > 0) { + HANDLE_FATAL(<< "Input error: prediction_field_name must not be equal to any of " + << core::CContainerPrinter::print(predictionFieldNameBlacklist) + << "."); + } } CDataFrameTrainBoostedTreeRegressionRunner::CDataFrameTrainBoostedTreeRegressionRunner( diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc index 4b99401c07..dc60bf8bbf 100644 --- a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc +++ b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc @@ -25,6 +25,26 @@ using TStrVec = std::vector; using TStrVecVec = std::vector; } +BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) { + TStrVec errors; + auto errorHandler = [&errors](std::string error) { errors.push_back(error); }; + core::CLogger::CScopeSetFatalErrorHandler scope{errorHandler}; + + const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec( + "classification", "dep_var", 5, 6, 13000000, 0, 0, {"dep_var"})}; + rapidjson::Document jsonParameters; + jsonParameters.Parse("{" + " \"dependent_variable\": \"dep_var\"," + " \"prediction_field_name\": \"is_training\"" + "}"); + const auto parameters{ + api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)}; + api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters); + + BOOST_TEST_REQUIRE(errors.size() == 1); + BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes]."); +} + BOOST_AUTO_TEST_CASE(testWriteOneRow) { // Prepare input data frame const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"}; diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeRegressionRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeRegressionRunnerTest.cc new file mode 100644 index 0000000000..10856d9f2f --- /dev/null +++ b/lib/api/unittest/CDataFrameTrainBoostedTreeRegressionRunnerTest.cc @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +#include + +#include +#include + +#include + +#include + +#include +#include + +BOOST_AUTO_TEST_SUITE(CDataFrameTrainBoostedTreeRegressionRunnerTest) + +using namespace ml; +namespace { +using TStrVec = std::vector; +} + +BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) { + TStrVec errors; + auto errorHandler = [&errors](std::string error) { errors.push_back(error); }; + core::CLogger::CScopeSetFatalErrorHandler scope{errorHandler}; + + const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec( + "regression", "dep_var", 5, 6, 13000000, 0, 0)}; + rapidjson::Document jsonParameters; + jsonParameters.Parse("{" + " \"dependent_variable\": \"dep_var\"," + " \"prediction_field_name\": \"is_training\"" + "}"); + const auto parameters{ + api::CDataFrameTrainBoostedTreeRegressionRunner::parameterReader().read(jsonParameters)}; + api::CDataFrameTrainBoostedTreeRegressionRunner runner(*spec, parameters); + + BOOST_TEST_REQUIRE(errors.size() == 1); + BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training]."); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/lib/api/unittest/Makefile b/lib/api/unittest/Makefile index 1a99a4af8a..3803f7da92 100644 --- a/lib/api/unittest/Makefile +++ b/lib/api/unittest/Makefile @@ -31,6 +31,7 @@ SRCS=\ CDataFrameAnalyzerOutlierTest.cc \ CDataFrameAnalyzerTrainingTest.cc \ CDataFrameTrainBoostedTreeClassifierRunnerTest.cc \ + CDataFrameTrainBoostedTreeRegressionRunnerTest.cc \ CDataFrameMockAnalysisRunner.cc \ CDetectionRulesJsonParserTest.cc \ CFieldConfigTest.cc \