@@ -170,7 +170,7 @@ class GBTClassifier @Since("1.4.0") (
170170 maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
171171 seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
172172 instr.logNumFeatures(numFeatures)
173- instr.logNumClasses(2 )
173+ instr.logNumClasses(numClasses )
174174
175175 val (baseLearners, learnerWeights) = GradientBoostedTrees .run(oldDataset, boostingStrategy,
176176 $(seed))
@@ -209,8 +209,7 @@ class GBTClassificationModel private[ml](
209209 @ Since (" 1.6.0" ) override val uid : String ,
210210 private val _trees : Array [DecisionTreeRegressionModel ],
211211 private val _treeWeights : Array [Double ],
212- @ Since (" 1.6.0" ) override val numFeatures : Int ,
213- @ Since (" 2.2.0" ) override val numClasses : Int )
212+ @ Since (" 1.6.0" ) override val numFeatures : Int )
214213 extends ProbabilisticClassificationModel [Vector , GBTClassificationModel ]
215214 with GBTClassifierParams with TreeEnsembleModel [DecisionTreeRegressionModel ]
216215 with MLWritable with Serializable {
@@ -219,20 +218,6 @@ class GBTClassificationModel private[ml](
219218 require(_trees.length == _treeWeights.length, " GBTClassificationModel given trees, treeWeights" +
220219 s " of non-matching lengths ( ${_trees.length}, ${_treeWeights.length}, respectively). " )
221220
222- /**
223- * Construct a GBTClassificationModel
224- *
225- * @param _trees Decision trees in the ensemble.
226- * @param _treeWeights Weights for the decision trees in the ensemble.
227- * @param numFeatures The number of features.
228- */
229- private [ml] def this (
230- uid : String ,
231- _trees : Array [DecisionTreeRegressionModel ],
232- _treeWeights : Array [Double ],
233- numFeatures : Int ) =
234- this (uid, _trees, _treeWeights, numFeatures, 2 )
235-
236221 /**
237222 * Construct a GBTClassificationModel
238223 *
@@ -241,7 +226,7 @@ class GBTClassificationModel private[ml](
241226 */
242227 @ Since (" 1.6.0" )
243228 def this (uid : String , _trees : Array [DecisionTreeRegressionModel ], _treeWeights : Array [Double ]) =
244- this (uid, _trees, _treeWeights, - 1 , 2 )
229+ this (uid, _trees, _treeWeights, - 1 )
245230
246231 @ Since (" 1.4.0" )
247232 override def trees : Array [DecisionTreeRegressionModel ] = _trees
@@ -294,7 +279,7 @@ class GBTClassificationModel private[ml](
294279
295280 @ Since (" 1.4.0" )
296281 override def copy (extra : ParamMap ): GBTClassificationModel = {
297- copyValues(new GBTClassificationModel (uid, _trees, _treeWeights, numFeatures, numClasses ),
282+ copyValues(new GBTClassificationModel (uid, _trees, _treeWeights, numFeatures),
298283 extra).setParent(parent)
299284 }
300285
@@ -339,7 +324,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
339324
340325 private val numFeaturesKey : String = " numFeatures"
341326 private val numTreesKey : String = " numTrees"
342- private val numClassesKey : String = " numClasses"
343327
344328 @ Since (" 2.0.0" )
345329 override def read : MLReader [GBTClassificationModel ] = new GBTClassificationModelReader
@@ -354,8 +338,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
354338
355339 val extraMetadata : JObject = Map (
356340 numFeaturesKey -> instance.numFeatures,
357- numTreesKey -> instance.getNumTrees,
358- numClassesKey -> instance.numClasses)
341+ numTreesKey -> instance.getNumTrees)
359342 EnsembleModelReadWrite .saveImpl(instance, path, sparkSession, extraMetadata)
360343 }
361344 }
@@ -372,7 +355,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
372355 EnsembleModelReadWrite .loadImpl(path, sparkSession, className, treeClassName)
373356 val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int ]
374357 val numTrees = (metadata.metadata \ numTreesKey).extract[Int ]
375- val numClasses = (metadata.metadata \ numClassesKey).extract[Int ]
376358
377359 val trees : Array [DecisionTreeRegressionModel ] = treesData.map {
378360 case (treeMetadata, root) =>
@@ -384,7 +366,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
384366 require(numTrees == trees.length, s " GBTClassificationModel.load expected $numTrees" +
385367 s " trees based on metadata but found ${trees.length} trees. " )
386368 val model = new GBTClassificationModel (metadata.uid,
387- trees, treeWeights, numFeatures, numClasses )
369+ trees, treeWeights, numFeatures)
388370 DefaultParamsReader .getAndSetParams(model, metadata)
389371 model
390372 }
@@ -395,15 +377,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
395377 oldModel : OldGBTModel ,
396378 parent : GBTClassifier ,
397379 categoricalFeatures : Map [Int , Int ],
398- numFeatures : Int = - 1 ,
399- numClasses : Int = 2 ): GBTClassificationModel = {
380+ numFeatures : Int = - 1 ): GBTClassificationModel = {
400381 require(oldModel.algo == OldAlgo .Classification , " Cannot convert GradientBoostedTreesModel" +
401382 s " with algo= ${oldModel.algo} (old API) to GBTClassificationModel (new API). " )
402383 val newTrees = oldModel.trees.map { tree =>
403384 // parent for each tree is null since there is no good way to set this.
404385 DecisionTreeRegressionModel .fromOld(tree, null , categoricalFeatures)
405386 }
406387 val uid = if (parent != null ) parent.uid else Identifiable .randomUID(" gbtc" )
407- new GBTClassificationModel (uid, newTrees, oldModel.treeWeights, numFeatures, numClasses )
388+ new GBTClassificationModel (uid, newTrees, oldModel.treeWeights, numFeatures)
408389 }
409390}
0 commit comments