@@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
4545 BOOST_TEST_REQUIRE (errors[0 ] == " Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes]." );
4646}
4747
48- BOOST_AUTO_TEST_CASE (testWriteOneRow) {
48+ template <typename T>
49+ void testWriteOneRow (const std::string& dependentVariableField,
50+ const std::string& predictionFieldType,
51+ T (rapidjson::Value::*extract)() const ,
52+ const std::vector<T>& expectedPredictions) {
4953 // Prepare input data frame
50- const TStrVec columnNames{" x1" , " x2" , " x3" , " x4" , " x5" , " x5_prediction" };
51- const TStrVec categoricalColumns{" x1" , " x2" , " x5" };
54+ const std::string predictionField = dependentVariableField + " _prediction" ;
55+ const TStrVec columnNames{" x1" , " x2" , " x3" , " x4" , " x5" , predictionField};
56+ const TStrVec categoricalColumns{" x1" , " x2" , " x3" , " x4" , " x5" };
5257 const TStrVecVec rows{{" a" , " b" , " 1.0" , " 1.0" , " cat" , " -1.0" },
53- {" a" , " b" , " 2 .0" , " 2 .0" , " cat" , " -0.5" },
54- {" a" , " b" , " 5.0" , " 5 .0" , " dog" , " -0.1" },
55- {" c" , " d" , " 5.0" , " 5 .0" , " dog" , " 1.0" },
56- {" e" , " f" , " 5.0" , " 5 .0" , " dog" , " 1.5" }};
58+ {" a" , " b" , " 1 .0" , " 1 .0" , " cat" , " -0.5" },
59+ {" a" , " b" , " 5.0" , " 0 .0" , " dog" , " -0.1" },
60+ {" c" , " d" , " 5.0" , " 0 .0" , " dog" , " 1.0" },
61+ {" e" , " f" , " 5.0" , " 0 .0" , " dog" , " 1.5" }};
5762 std::unique_ptr<core::CDataFrame> frame =
5863 core::makeMainStorageDataFrame (columnNames.size ()).first ;
5964 frame->columnNames (columnNames);
@@ -67,10 +72,21 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
6772
6873 // Create classification analysis runner object
6974 const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec (
70- " classification" , " x5 " , rows.size (), columnNames. size (), 13000000 , 0 , 0 ,
71- categoricalColumns)};
75+ " classification" , dependentVariableField , rows.size (),
76+ columnNames. size (), 13000000 , 0 , 0 , categoricalColumns)};
7277 rapidjson::Document jsonParameters;
73- jsonParameters.Parse (" {\" dependent_variable\" : \" x5\" }" );
78+ if (predictionFieldType.empty ()) {
79+ jsonParameters.Parse (" {\" dependent_variable\" : \" " + dependentVariableField + " \" }" );
80+ } else {
81+ jsonParameters.Parse (" {"
82+ " \" dependent_variable\" : \" " +
83+ dependentVariableField +
84+ " \" ,"
85+ " \" prediction_field_type\" : \" " +
86+ predictionFieldType +
87+ " \" "
88+ " }" );
89+ }
7490 const auto parameters{
7591 api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader ().read (jsonParameters)};
7692 api::CDataFrameTrainBoostedTreeClassifierRunner runner (*spec, parameters);
@@ -83,10 +99,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
8399
84100 frame->readRows (1 , [&](TRowItr beginRows, TRowItr endRows) {
85101 const auto columnHoldingDependentVariable{
86- std::find (columnNames.begin (), columnNames.end (), " x5 " ) -
102+ std::find (columnNames.begin (), columnNames.end (), dependentVariableField ) -
87103 columnNames.begin ()};
88104 const auto columnHoldingPrediction{
89- std::find (columnNames.begin (), columnNames.end (), " x5_prediction " ) -
105+ std::find (columnNames.begin (), columnNames.end (), predictionField ) -
90106 columnNames.begin ()};
91107 for (auto row = beginRows; row != endRows; ++row) {
92108 runner.writeOneRow (*frame, columnHoldingDependentVariable,
@@ -95,17 +111,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
95111 });
96112 }
97113 // Verify results
98- const TStrVec expectedPredictions{" cat" , " cat" , " cat" , " dog" , " dog" };
99114 rapidjson::Document arrayDoc;
100115 arrayDoc.Parse <rapidjson::kParseDefaultFlags >(output.str ().c_str ());
101116 BOOST_TEST_REQUIRE (arrayDoc.IsArray ());
102117 BOOST_TEST_REQUIRE (arrayDoc.Size () == rows.size ());
118+ BOOST_TEST_REQUIRE (arrayDoc.Size () == expectedPredictions.size ());
103119 for (std::size_t i = 0 ; i < arrayDoc.Size (); ++i) {
104120 BOOST_TEST_CONTEXT (" Result for row " << i) {
105121 const rapidjson::Value& object = arrayDoc[rapidjson::SizeType (i)];
106122 BOOST_TEST_REQUIRE (object.IsObject ());
107- BOOST_TEST_REQUIRE (object.HasMember (" x5_prediction " ));
108- BOOST_TEST_REQUIRE (object[" x5_prediction " ]. GetString () ==
123+ BOOST_TEST_REQUIRE (object.HasMember (predictionField ));
124+ BOOST_TEST_REQUIRE (( object[predictionField].*extract) () ==
109125 expectedPredictions[i]);
110126 BOOST_TEST_REQUIRE (object.HasMember (" prediction_probability" ));
111127 BOOST_TEST_REQUIRE (object[" prediction_probability" ].GetDouble () > 0.5 );
@@ -115,4 +131,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
115131 }
116132}
117133
134+ BOOST_AUTO_TEST_CASE (testWriteOneRowPredictionFieldTypeIsInt) {
135+ testWriteOneRow (" x3" , " int" , &rapidjson::Value::GetInt, {1 , 1 , 1 , 5 , 5 });
136+ }
137+
138+ BOOST_AUTO_TEST_CASE (testWriteOneRowPredictionFieldTypeIsBool) {
139+ testWriteOneRow (" x4" , " bool" , &rapidjson::Value::GetBool,
140+ {true , true , true , false , false });
141+ }
142+
143+ BOOST_AUTO_TEST_CASE (testWriteOneRowPredictionFieldTypeIsString) {
144+ testWriteOneRow (" x5" , " string" , &rapidjson::Value::GetString,
145+ {" cat" , " cat" , " cat" , " dog" , " dog" });
146+ }
147+
148+ BOOST_AUTO_TEST_CASE (testWriteOneRowPredictionFieldTypeIsMissing) {
149+ testWriteOneRow (" x5" , " " , &rapidjson::Value::GetString,
150+ {" cat" , " cat" , " cat" , " dog" , " dog" });
151+ }
152+
118153BOOST_AUTO_TEST_SUITE_END ()
0 commit comments