@@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
4545 BOOST_TEST_REQUIRE (errors[0 ] == " Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes]." );
4646}
4747
48- BOOST_AUTO_TEST_CASE (testWriteOneRow) {
48+ template <typename T>
49+ void testWriteOneRow (const std::string& dependentVariableField,
50+ const std::string& dependentVariableType,
51+ T (rapidjson::Value::*extract)() const ,
52+ const std::vector<T>& expectedPredictions) {
4953 // Prepare input data frame
50- const TStrVec columnNames{" x1" , " x2" , " x3" , " x4" , " x5" , " x5_prediction" };
51- const TStrVec categoricalColumns{" x1" , " x2" , " x5" };
54+ const std::string predictionField = dependentVariableField + " _prediction" ;
55+ const TStrVec columnNames{" x1" , " x2" , " x3" , " x4" , " x5" , predictionField};
56+ const TStrVec categoricalColumns{" x1" , " x2" , " x3" , " x4" , " x5" };
5257 const TStrVecVec rows{{" a" , " b" , " 1.0" , " 1.0" , " cat" , " -1.0" },
53- {" a" , " b" , " 2 .0" , " 2 .0" , " cat" , " -0.5" },
54- {" a" , " b" , " 5.0" , " 5 .0" , " dog" , " -0.1" },
55- {" c" , " d" , " 5.0" , " 5 .0" , " dog" , " 1.0" },
56- {" e" , " f" , " 5.0" , " 5 .0" , " dog" , " 1.5" }};
58+ {" a" , " b" , " 1 .0" , " 1 .0" , " cat" , " -0.5" },
59+ {" a" , " b" , " 5.0" , " 0 .0" , " dog" , " -0.1" },
60+ {" c" , " d" , " 5.0" , " 0 .0" , " dog" , " 1.0" },
61+ {" e" , " f" , " 5.0" , " 0 .0" , " dog" , " 1.5" }};
5762 std::unique_ptr<core::CDataFrame> frame =
5863 core::makeMainStorageDataFrame (columnNames.size ()).first ;
5964 frame->columnNames (columnNames);
@@ -67,10 +72,13 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
6772
6873 // Create classification analysis runner object
6974 const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec (
70- " classification" , " x5 " , rows.size (), columnNames. size (), 13000000 , 0 , 0 ,
71- categoricalColumns)};
75+ " classification" , dependentVariableField , rows.size (),
76+ columnNames. size (), 13000000 , 0 , 0 , categoricalColumns)};
7277 rapidjson::Document jsonParameters;
73- jsonParameters.Parse (" {\" dependent_variable\" : \" x5\" }" );
78+ jsonParameters.Parse (" {"
79+ " \" dependent_variable\" : \" " + dependentVariableField + " \" ,"
80+ " \" dependent_variable_type\" : \" " + dependentVariableType + " \" "
81+ " }" );
7482 const auto parameters{
7583 api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader ().read (jsonParameters)};
7684 api::CDataFrameTrainBoostedTreeClassifierRunner runner (*spec, parameters);
@@ -83,10 +91,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
8391
8492 frame->readRows (1 , [&](TRowItr beginRows, TRowItr endRows) {
8593 const auto columnHoldingDependentVariable{
86- std::find (columnNames.begin (), columnNames.end (), " x5 " ) -
94+ std::find (columnNames.begin (), columnNames.end (), dependentVariableField ) -
8795 columnNames.begin ()};
8896 const auto columnHoldingPrediction{
89- std::find (columnNames.begin (), columnNames.end (), " x5_prediction " ) -
97+ std::find (columnNames.begin (), columnNames.end (), predictionField ) -
9098 columnNames.begin ()};
9199 for (auto row = beginRows; row != endRows; ++row) {
92100 runner.writeOneRow (*frame, columnHoldingDependentVariable,
@@ -95,17 +103,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
95103 });
96104 }
97105 // Verify results
98- const TStrVec expectedPredictions{" cat" , " cat" , " cat" , " dog" , " dog" };
99106 rapidjson::Document arrayDoc;
100107 arrayDoc.Parse <rapidjson::kParseDefaultFlags >(output.str ().c_str ());
101108 BOOST_TEST_REQUIRE (arrayDoc.IsArray ());
102109 BOOST_TEST_REQUIRE (arrayDoc.Size () == rows.size ());
110+ BOOST_TEST_REQUIRE (arrayDoc.Size () == expectedPredictions.size ());
103111 for (std::size_t i = 0 ; i < arrayDoc.Size (); ++i) {
104112 BOOST_TEST_CONTEXT (" Result for row " << i) {
105113 const rapidjson::Value& object = arrayDoc[rapidjson::SizeType (i)];
106114 BOOST_TEST_REQUIRE (object.IsObject ());
107- BOOST_TEST_REQUIRE (object.HasMember (" x5_prediction " ));
108- BOOST_TEST_REQUIRE (object[" x5_prediction " ]. GetString () ==
115+ BOOST_TEST_REQUIRE (object.HasMember (predictionField ));
116+ BOOST_TEST_REQUIRE (( object[predictionField].*extract) () ==
109117 expectedPredictions[i]);
110118 BOOST_TEST_REQUIRE (object.HasMember (" prediction_probability" ));
111119 BOOST_TEST_REQUIRE (object[" prediction_probability" ].GetDouble () > 0.5 );
@@ -115,4 +123,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
115123 }
116124}
117125
126+ BOOST_AUTO_TEST_CASE (testWriteOneRow_DependentVariableIsInt) {
127+ testWriteOneRow (" x3" , " int" , &rapidjson::Value::GetInt, {1 , 1 , 1 , 5 , 5 });
128+ }
129+
130+ BOOST_AUTO_TEST_CASE (testWriteOneRow_DependentVariableIsBool) {
131+ testWriteOneRow (" x4" , " bool" , &rapidjson::Value::GetBool,
132+ {true , true , true , false , false });
133+ }
134+
135+ BOOST_AUTO_TEST_CASE (testWriteOneRow_DependentVariableIsString) {
136+ testWriteOneRow (" x5" , " string" , &rapidjson::Value::GetString,
137+ {" cat" , " cat" , " cat" , " dog" , " dog" });
138+ }
139+
140+ BOOST_AUTO_TEST_CASE (testWriteOneRow_DependentVariableTypeMissing) {
141+ testWriteOneRow (" x5" , " " , &rapidjson::Value::GetString,
142+ {" cat" , " cat" , " cat" , " dog" , " dog" });
143+ }
144+
118145BOOST_AUTO_TEST_SUITE_END ()
0 commit comments