Skip to content

Commit 818de81

Browse files
committed
Fixing build issues - need to keep numClasses in model
1 parent 1abfee0 commit 818de81

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ class GBTClassifier @Since("1.4.0") (
174174

175175
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
176176
$(seed))
177-
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures, numClasses)
177+
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
178178
instr.logSuccess(m)
179179
m
180180
}
@@ -209,7 +209,8 @@ class GBTClassificationModel private[ml](
209209
@Since("1.6.0") override val uid: String,
210210
private val _trees: Array[DecisionTreeRegressionModel],
211211
private val _treeWeights: Array[Double],
212-
@Since("1.6.0") override val numFeatures: Int)
212+
@Since("1.6.0") override val numFeatures: Int,
213+
@Since("2.2.0") override val numClasses: Int)
213214
extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
214215
with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
215216
with MLWritable with Serializable {
@@ -218,6 +219,20 @@ class GBTClassificationModel private[ml](
218219
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
219220
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
220221

222+
/**
223+
* Construct a GBTClassificationModel
224+
*
225+
* @param _trees Decision trees in the ensemble.
226+
* @param _treeWeights Weights for the decision trees in the ensemble.
227+
* @param numFeatures The number of features.
228+
*/
229+
private[ml] def this(
230+
uid: String,
231+
_trees: Array[DecisionTreeRegressionModel],
232+
_treeWeights: Array[Double],
233+
numFeatures: Int) =
234+
this(uid, _trees, _treeWeights, numFeatures, 2)
235+
221236
/**
222237
* Construct a GBTClassificationModel
223238
*
@@ -226,7 +241,7 @@ class GBTClassificationModel private[ml](
226241
*/
227242
@Since("1.6.0")
228243
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
229-
this(uid, _trees, _treeWeights, -1)
244+
this(uid, _trees, _treeWeights, -1, 2)
230245

231246
@Since("1.4.0")
232247
override def trees: Array[DecisionTreeRegressionModel] = _trees
@@ -279,7 +294,7 @@ class GBTClassificationModel private[ml](
279294

280295
@Since("1.4.0")
281296
override def copy(extra: ParamMap): GBTClassificationModel = {
282-
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
297+
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
283298
extra).setParent(parent)
284299
}
285300

@@ -377,14 +392,15 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
377392
oldModel: OldGBTModel,
378393
parent: GBTClassifier,
379394
categoricalFeatures: Map[Int, Int],
380-
numFeatures: Int = -1): GBTClassificationModel = {
395+
numFeatures: Int = -1,
396+
numClasses: Int = 2): GBTClassificationModel = {
381397
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
382398
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
383399
val newTrees = oldModel.trees.map { tree =>
384400
// parent for each tree is null since there is no good way to set this.
385401
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
386402
}
387403
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
388-
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
404+
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses)
389405
}
390406
}

0 commit comments

Comments
 (0)