Skip to content

Commit 1eb3b44

Browse files
committed
Emit predicted category using an appropriate JSON type.
1 parent 617e5b9 commit 1eb3b44

File tree

4 files changed

+69
-17
lines changed

4 files changed

+69
-17
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[
5252
(See {ml-pull}818[#818].)
5353
* Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].)
5454
* Reduce runtime of classification and regression. (See {ml-pull}863[#863].)
55+
* Emit the predicted value (the `prediction_field_name` field) in {ml} results using the type of the `dependent_variable`.
56+
(See {ml-pull}877[#877].)
5557

5658
=== Bug Fixes
5759
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)

include/api/CDataFrameTrainBoostedTreeClassifierRunner.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
4444
const TRowRef& row,
4545
core::CRapidJsonConcurrentLineWriter& writer) const;
4646

47+
//! Write the predicted category value as a JSON string, int or bool, as selected by the configured dependent variable type.
48+
void writePredictedCategoryValue(const std::string& categoryValue,
49+
core::CRapidJsonConcurrentLineWriter& writer) const;
50+
4751
//! \return A serialisable definition of the trained classification model.
4852
TInferenceModelDefinitionUPtr
4953
inferenceModelDefinition(const TStrVec& fieldNames,
@@ -55,6 +59,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
5559

5660
private:
5761
std::size_t m_NumTopClasses;
62+
std::string m_DependentVariableType;
5863
};
5964

6065
//! \brief Makes a core::CDataFrame boosted tree classification runner.

lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ using TSizeVec = std::vector<std::size_t>;
3232

3333
// Configuration
3434
const std::string NUM_TOP_CLASSES{"num_top_classes"};
35+
const std::string DEPENDENT_VARIABLE_TYPE{"dependent_variable_type"};
3536
const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"};
3637

3738
// Output
@@ -47,6 +48,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() {
4748
static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
4849
auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
4950
theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter);
51+
theReader.addParameter(DEPENDENT_VARIABLE_TYPE,
52+
CDataFrameAnalysisConfigReader::E_OptionalParameter);
5053
theReader.addParameter(BALANCED_CLASS_LOSS,
5154
CDataFrameAnalysisConfigReader::E_OptionalParameter);
5255
return theReader;
@@ -60,6 +63,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
6063
: CDataFrameTrainBoostedTreeRunner{spec, parameters} {
6164

6265
m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
66+
m_DependentVariableType =
67+
parameters[DEPENDENT_VARIABLE_TYPE].fallback(std::string("string"));
6368
this->boostedTreeFactory().balanceClassTrainingLoss(
6469
parameters[BALANCED_CLASS_LOSS].fallback(true));
6570

@@ -119,7 +124,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
119124

120125
writer.StartObject();
121126
writer.Key(this->predictionFieldName());
122-
writer.String(categoryValues[predictedCategoryId]);
127+
writePredictedCategoryValue(categoryValues[predictedCategoryId], writer);
123128
writer.Key(PREDICTION_PROBABILITY_FIELD_NAME);
124129
writer.Double(probabilityOfCategory[predictedCategoryId]);
125130
writer.Key(IS_TRAINING_FIELD_NAME);
@@ -135,7 +140,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
135140
for (std::size_t i = 0; i < std::min(categoryIds.size(), m_NumTopClasses); ++i) {
136141
writer.StartObject();
137142
writer.Key(CLASS_NAME_FIELD_NAME);
138-
writer.String(categoryValues[categoryIds[i]]);
143+
writePredictedCategoryValue(categoryValues[categoryIds[i]], writer);
139144
writer.Key(CLASS_PROBABILITY_FIELD_NAME);
140145
writer.Double(probabilityOfCategory[i]);
141146
writer.EndObject();
@@ -158,6 +163,19 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
158163
columnHoldingPrediction, row, writer);
159164
}
160165

166+
void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue(
167+
const std::string& categoryValue,
168+
core::CRapidJsonConcurrentLineWriter& writer) const {
169+
170+
if (m_DependentVariableType == "int") {
171+
writer.Int(std::stoi(categoryValue));
172+
} else if (m_DependentVariableType == "bool") {
173+
writer.Bool(std::stoi(categoryValue) == 1);
174+
} else {
175+
writer.String(categoryValue);
176+
}
177+
}
178+
161179
CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr
162180
CDataFrameTrainBoostedTreeClassifierRunner::chooseLossFunction(const core::CDataFrame& frame,
163181
std::size_t dependentVariableColumn) const {

lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
4545
BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes].");
4646
}
4747

48-
BOOST_AUTO_TEST_CASE(testWriteOneRow) {
48+
template<typename T>
49+
void testWriteOneRow(const std::string& dependentVariableField,
50+
const std::string& dependentVariableType,
51+
T (rapidjson::Value::*extract)() const,
52+
const std::vector<T>& expectedPredictions) {
4953
// Prepare input data frame
50-
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"};
51-
const TStrVec categoricalColumns{"x1", "x2", "x5"};
54+
const std::string predictionField = dependentVariableField + "_prediction";
55+
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", predictionField};
56+
const TStrVec categoricalColumns{"x1", "x2", "x3", "x4", "x5"};
5257
const TStrVecVec rows{{"a", "b", "1.0", "1.0", "cat", "-1.0"},
53-
{"a", "b", "2.0", "2.0", "cat", "-0.5"},
54-
{"a", "b", "5.0", "5.0", "dog", "-0.1"},
55-
{"c", "d", "5.0", "5.0", "dog", "1.0"},
56-
{"e", "f", "5.0", "5.0", "dog", "1.5"}};
58+
{"a", "b", "1.0", "1.0", "cat", "-0.5"},
59+
{"a", "b", "5.0", "0.0", "dog", "-0.1"},
60+
{"c", "d", "5.0", "0.0", "dog", "1.0"},
61+
{"e", "f", "5.0", "0.0", "dog", "1.5"}};
5762
std::unique_ptr<core::CDataFrame> frame =
5863
core::makeMainStorageDataFrame(columnNames.size()).first;
5964
frame->columnNames(columnNames);
@@ -67,10 +72,13 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
6772

6873
// Create classification analysis runner object
6974
const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
70-
"classification", "x5", rows.size(), columnNames.size(), 13000000, 0, 0,
71-
categoricalColumns)};
75+
"classification", dependentVariableField, rows.size(),
76+
columnNames.size(), 13000000, 0, 0, categoricalColumns)};
7277
rapidjson::Document jsonParameters;
73-
jsonParameters.Parse("{\"dependent_variable\": \"x5\"}");
78+
jsonParameters.Parse("{"
79+
" \"dependent_variable\": \"" + dependentVariableField + "\","
80+
" \"dependent_variable_type\": \"" + dependentVariableType + "\""
81+
"}");
7482
const auto parameters{
7583
api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)};
7684
api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters);
@@ -83,10 +91,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
8391

8492
frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) {
8593
const auto columnHoldingDependentVariable{
86-
std::find(columnNames.begin(), columnNames.end(), "x5") -
94+
std::find(columnNames.begin(), columnNames.end(), dependentVariableField) -
8795
columnNames.begin()};
8896
const auto columnHoldingPrediction{
89-
std::find(columnNames.begin(), columnNames.end(), "x5_prediction") -
97+
std::find(columnNames.begin(), columnNames.end(), predictionField) -
9098
columnNames.begin()};
9199
for (auto row = beginRows; row != endRows; ++row) {
92100
runner.writeOneRow(*frame, columnHoldingDependentVariable,
@@ -95,17 +103,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
95103
});
96104
}
97105
// Verify results
98-
const TStrVec expectedPredictions{"cat", "cat", "cat", "dog", "dog"};
99106
rapidjson::Document arrayDoc;
100107
arrayDoc.Parse<rapidjson::kParseDefaultFlags>(output.str().c_str());
101108
BOOST_TEST_REQUIRE(arrayDoc.IsArray());
102109
BOOST_TEST_REQUIRE(arrayDoc.Size() == rows.size());
110+
BOOST_TEST_REQUIRE(arrayDoc.Size() == expectedPredictions.size());
103111
for (std::size_t i = 0; i < arrayDoc.Size(); ++i) {
104112
BOOST_TEST_CONTEXT("Result for row " << i) {
105113
const rapidjson::Value& object = arrayDoc[rapidjson::SizeType(i)];
106114
BOOST_TEST_REQUIRE(object.IsObject());
107-
BOOST_TEST_REQUIRE(object.HasMember("x5_prediction"));
108-
BOOST_TEST_REQUIRE(object["x5_prediction"].GetString() ==
115+
BOOST_TEST_REQUIRE(object.HasMember(predictionField));
116+
BOOST_TEST_REQUIRE((object[predictionField].*extract)() ==
109117
expectedPredictions[i]);
110118
BOOST_TEST_REQUIRE(object.HasMember("prediction_probability"));
111119
BOOST_TEST_REQUIRE(object["prediction_probability"].GetDouble() > 0.5);
@@ -115,4 +123,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
115123
}
116124
}
117125

126+
// Runs the shared testWriteOneRow helper with "x3" as the dependent variable
// and "int" as the dependent variable type. The helper reads each prediction
// back through rapidjson::Value::GetInt and compares it with the expected
// integers listed here, one per row.
BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsInt) {
127+
testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5});
128+
}
129+
130+
BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsBool) {
131+
testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool,
132+
{true, true, true, false, false});
133+
}
134+
135+
// Runs the shared testWriteOneRow helper with "x5" as the dependent variable
// and "string" as the dependent variable type. The helper reads each
// prediction back through rapidjson::Value::GetString and compares it with
// the expected category labels listed here, one per row.
BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsString) {
136+
testWriteOneRow("x5", "string", &rapidjson::Value::GetString,
137+
{"cat", "cat", "cat", "dog", "dog"});
138+
}
139+
140+
BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableTypeMissing) {
141+
testWriteOneRow("x5", "", &rapidjson::Value::GetString,
142+
{"cat", "cat", "cat", "dog", "dog"});
143+
}
144+
118145
BOOST_AUTO_TEST_SUITE_END()

0 commit comments

Comments
 (0)