Skip to content

Commit 3657ce5

Browse files
authored
[ML] Allow unbounded num_top_classes in classification analysis (#1526) (#1529)
1 parent 93a2446 commit 3657ce5

File tree

4 files changed

+21
-4
lines changed

4 files changed

+21
-4
lines changed

include/api/CDataFrameAnalysisConfigReader.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ class API_EXPORT CDataFrameAnalysisConfigReader {
8484
bool fallback(bool value) const;
8585
//! Get an unsigned integer parameter.
8686
std::size_t fallback(std::size_t value) const;
87+
//! Get a signed integer parameter.
88+
std::ptrdiff_t fallback(std::ptrdiff_t value) const;
8789
//! Get a floating point parameter.
8890
double fallback(double value) const;
8991
//! Get a string parameter.

include/api/CDataFrameTrainBoostedTreeClassifierRunner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
8686
core::CRapidJsonConcurrentLineWriter& writer) const;
8787

8888
private:
89-
std::size_t m_NumTopClasses;
89+
std::ptrdiff_t m_NumTopClasses;
9090
EPredictionFieldType m_PredictionFieldType;
9191
mutable CInferenceModelMetadata m_InferenceModelMetadata;
9292
};

lib/api/CDataFrameAnalysisConfigReader.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,17 @@ std::size_t CDataFrameAnalysisConfigReader::CParameter::fallback(std::size_t val
112112
return m_Value->GetUint64();
113113
}
114114

115+
// Get this parameter as a signed integer, or the supplied default.
//
// Returns `value` unchanged when no value is held (m_Value is null). When a
// value is held but does not report itself as a 64-bit signed integer,
// handleFatal() is called and `value` is returned. Otherwise the held value
// is returned via GetInt64().
//
// NOTE(review): m_Value is presumably a rapidjson value pointer, going by the
// IsInt64()/GetInt64() calls - confirm against the class declaration.
std::ptrdiff_t CDataFrameAnalysisConfigReader::CParameter::fallback(std::ptrdiff_t value) const {
116+
// No value held for this parameter: use the caller's default.
if (m_Value == nullptr) {
117+
return value;
118+
}
119+
// A value is held but it is not representable as a signed 64-bit integer.
if (m_Value->IsInt64() == false) {
120+
// NOTE(review): handleFatal() is defined elsewhere; whether it returns
// normally is not visible here. The return below covers the case where it does.
this->handleFatal();
121+
return value;
122+
}
123+
// GetInt64() yields a 64-bit value which is converted to std::ptrdiff_t on return.
return m_Value->GetInt64();
124+
}
125+
115126
double CDataFrameAnalysisConfigReader::CParameter::fallback(double value) const {
116127
if (m_Value == nullptr) {
117128
return value;

lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
8181
: CDataFrameTrainBoostedTreeRunner{
8282
spec, parameters, loss(parameters[NUM_CLASSES].as<std::size_t>())} {
8383

84-
m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
84+
m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::ptrdiff_t{0});
8585
m_PredictionFieldType =
8686
parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString);
8787
this->boostedTreeFactory().classAssignmentObjective(
@@ -138,14 +138,18 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
138138
writer.Key(IS_TRAINING_FIELD_NAME);
139139
writer.Bool(maths::CDataFrameUtils::isMissing(actualClassId) == false);
140140

141-
if (m_NumTopClasses > 0) {
141+
if (m_NumTopClasses != 0) {
142142
TSizeVec classIds(scores.size());
143143
std::iota(classIds.begin(), classIds.end(), 0);
144144
std::sort(classIds.begin(), classIds.end(),
145145
[&scores](std::size_t lhs, std::size_t rhs) {
146146
return scores[lhs] > scores[rhs];
147147
});
148-
classIds.resize(std::min(classIds.size(), m_NumTopClasses));
148+
// -1 is a special value meaning "output all the classes"
149+
classIds.resize(m_NumTopClasses == -1
150+
? classIds.size()
151+
: std::min(classIds.size(),
152+
static_cast<std::size_t>(m_NumTopClasses)));
149153
writer.Key(TOP_CLASSES_FIELD_NAME);
150154
writer.StartArray();
151155
for (std::size_t i : classIds) {

0 commit comments

Comments
 (0)