@@ -174,7 +174,7 @@ class GBTClassifier @Since("1.4.0") (
174174
175175 val (baseLearners, learnerWeights) = GradientBoostedTrees .run(oldDataset, boostingStrategy,
176176 $(seed))
177- val m = new GBTClassificationModel (uid, baseLearners, learnerWeights, numFeatures, numClasses )
177+ val m = new GBTClassificationModel (uid, baseLearners, learnerWeights, numFeatures)
178178 instr.logSuccess(m)
179179 m
180180 }
@@ -209,7 +209,8 @@ class GBTClassificationModel private[ml](
209209 @ Since (" 1.6.0" ) override val uid : String ,
210210 private val _trees : Array [DecisionTreeRegressionModel ],
211211 private val _treeWeights : Array [Double ],
212- @ Since (" 1.6.0" ) override val numFeatures : Int )
212+ @ Since (" 1.6.0" ) override val numFeatures : Int ,
213+ @ Since (" 2.2.0" ) override val numClasses : Int )
213214 extends ProbabilisticClassificationModel [Vector , GBTClassificationModel ]
214215 with GBTClassifierParams with TreeEnsembleModel [DecisionTreeRegressionModel ]
215216 with MLWritable with Serializable {
@@ -218,6 +219,20 @@ class GBTClassificationModel private[ml](
218219 require(_trees.length == _treeWeights.length, " GBTClassificationModel given trees, treeWeights" +
219220 s " of non-matching lengths ( ${_trees.length}, ${_treeWeights.length}, respectively). " )
220221
222+ /**
223+ * Construct a GBTClassificationModel
224+ *
225+ * @param _trees Decision trees in the ensemble.
226+ * @param _treeWeights Weights for the decision trees in the ensemble.
227+ * @param numFeatures The number of features.
228+ */
229+ private [ml] def this (
230+ uid : String ,
231+ _trees : Array [DecisionTreeRegressionModel ],
232+ _treeWeights : Array [Double ],
233+ numFeatures : Int ) =
234+ this (uid, _trees, _treeWeights, numFeatures, 2 )
235+
221236 /**
222237 * Construct a GBTClassificationModel
223238 *
@@ -226,7 +241,7 @@ class GBTClassificationModel private[ml](
226241 */
227242 @ Since (" 1.6.0" )
228243 def this (uid : String , _trees : Array [DecisionTreeRegressionModel ], _treeWeights : Array [Double ]) =
229- this (uid, _trees, _treeWeights, - 1 )
244+ this (uid, _trees, _treeWeights, - 1 , 2 )
230245
231246 @ Since (" 1.4.0" )
232247 override def trees : Array [DecisionTreeRegressionModel ] = _trees
@@ -279,7 +294,7 @@ class GBTClassificationModel private[ml](
279294
280295 @ Since (" 1.4.0" )
281296 override def copy (extra : ParamMap ): GBTClassificationModel = {
282- copyValues(new GBTClassificationModel (uid, _trees, _treeWeights, numFeatures),
297+ copyValues(new GBTClassificationModel (uid, _trees, _treeWeights, numFeatures, numClasses ),
283298 extra).setParent(parent)
284299 }
285300
@@ -377,14 +392,15 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
377392 oldModel : OldGBTModel ,
378393 parent : GBTClassifier ,
379394 categoricalFeatures : Map [Int , Int ],
380- numFeatures : Int = - 1 ): GBTClassificationModel = {
395+ numFeatures : Int = - 1 ,
396+ numClasses : Int = 2 ): GBTClassificationModel = {
381397 require(oldModel.algo == OldAlgo .Classification , " Cannot convert GradientBoostedTreesModel" +
382398 s " with algo= ${oldModel.algo} (old API) to GBTClassificationModel (new API). " )
383399 val newTrees = oldModel.trees.map { tree =>
384400 // parent for each tree is null since there is no good way to set this.
385401 DecisionTreeRegressionModel .fromOld(tree, null , categoricalFeatures)
386402 }
387403 val uid = if (parent != null ) parent.uid else Identifiable .randomUID(" gbtc" )
388- new GBTClassificationModel (uid, newTrees, oldModel.treeWeights, numFeatures)
404+ new GBTClassificationModel (uid, newTrees, oldModel.treeWeights, numFeatures, numClasses )
389405 }
390406}
0 commit comments