99
1010#include < core/CDataSearcher.h>
1111
12+ #include < api/CDataFrameAnalysisConfigReader.h>
1213#include < api/CDataFrameAnalysisRunner.h>
1314#include < api/CDataFrameAnalysisSpecification.h>
1415#include < api/ImportExport.h>
@@ -25,11 +26,11 @@ class CBoostedTreeFactory;
2526namespace api {
2627
2728// ! \brief Runs boosted tree regression on a core::CDataFrame.
28- class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRunner {
29+ class API_EXPORT CDataFrameBoostedTreeRunner : public CDataFrameAnalysisRunner {
2930public:
3031 // ! This is not intended to be called directly: use CDataFrameBoostedTreeRunnerFactory.
3132 CDataFrameBoostedTreeRunner (const CDataFrameAnalysisSpecification& spec,
32- const rapidjson::Value& jsonParameters );
33+ const CDataFrameAnalysisConfigReader::CParameters& parameters );
3334
3435 // ! This is not intended to be called directly: use CDataFrameBoostedTreeRunnerFactory.
3536 CDataFrameBoostedTreeRunner (const CDataFrameAnalysisSpecification& spec);
@@ -39,13 +40,20 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
3940 // ! \return The number of columns this adds to the data frame.
4041 std::size_t numberExtraColumns () const override ;
4142
42- // ! Write the prediction for \p row to \p writer.
43- void writeOneRow (const TStrVec& featureNames,
44- TRowRef row,
45- core::CRapidJsonConcurrentLineWriter& writer) const override ;
43+ protected:
44+ using TBoostedTreeUPtr = std::unique_ptr<maths::CBoostedTree>;
45+
46+ protected:
47+ // ! Parameter reader handling parameters that are shared by subclasses.
48+ static CDataFrameAnalysisConfigReader getParameterReader ();
49+ // ! Name of dependent variable field.
50+ const std::string& dependentVariableFieldName () const ;
51+ // ! Name of prediction field.
52+ const std::string& predictionFieldName () const ;
53+ // ! Underlying boosted tree.
54+ const maths::CBoostedTree& boostedTree () const ;
4655
4756private:
48- using TBoostedTreeUPtr = std::unique_ptr<maths::CBoostedTree>;
4957 using TBoostedTreeFactoryUPtr = std::unique_ptr<maths::CBoostedTreeFactory>;
5058 using TDataSearcherUPtr = CDataFrameAnalysisSpecification::TDataSearcherUPtr;
5159 using TMemoryEstimator = std::function<void (std::int64_t )>;
@@ -71,20 +79,6 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
7179 TBoostedTreeUPtr m_BoostedTree;
7280 std::atomic<std::int64_t > m_Memory;
7381};
74-
75- // ! \brief Makes a core::CDataFrame boosted tree regression runner.
76- class API_EXPORT CDataFrameBoostedTreeRunnerFactory final : public CDataFrameAnalysisRunnerFactory {
77- public:
78- const std::string& name () const override ;
79-
80- private:
81- static const std::string NAME;
82-
83- private:
84- TRunnerUPtr makeImpl (const CDataFrameAnalysisSpecification& spec) const override ;
85- TRunnerUPtr makeImpl (const CDataFrameAnalysisSpecification& spec,
86- const rapidjson::Value& params) const override ;
87- };
8882}
8983}
9084
0 commit comments