Skip to content

Commit 1abfee0

Browse files
committed
Updated based on comments from jkbradley
1 parent f2e041d commit 1abfee0

File tree

5 files changed

+15
-37
lines changed

5 files changed

+15
-37
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,6 @@ class DecisionTreeClassificationModel private[ml] (
177177
/**
178178
* Construct a decision tree classification model.
179179
* @param rootNode Root node of tree, with other nodes attached.
180-
* @param numFeatures The number of features.
181-
* @param numClasses The number of classes to predict.
182180
*/
183181
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
184182
this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class GBTClassifier @Since("1.4.0") (
170170
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
171171
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
172172
instr.logNumFeatures(numFeatures)
173-
instr.logNumClasses(2)
173+
instr.logNumClasses(numClasses)
174174

175175
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
176176
$(seed))
@@ -209,8 +209,7 @@ class GBTClassificationModel private[ml](
209209
@Since("1.6.0") override val uid: String,
210210
private val _trees: Array[DecisionTreeRegressionModel],
211211
private val _treeWeights: Array[Double],
212-
@Since("1.6.0") override val numFeatures: Int,
213-
@Since("2.2.0") override val numClasses: Int)
212+
@Since("1.6.0") override val numFeatures: Int)
214213
extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
215214
with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
216215
with MLWritable with Serializable {
@@ -219,20 +218,6 @@ class GBTClassificationModel private[ml](
219218
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
220219
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
221220

222-
/**
223-
* Construct a GBTClassificationModel
224-
*
225-
* @param _trees Decision trees in the ensemble.
226-
* @param _treeWeights Weights for the decision trees in the ensemble.
227-
* @param numFeatures The number of features.
228-
*/
229-
private[ml] def this(
230-
uid: String,
231-
_trees: Array[DecisionTreeRegressionModel],
232-
_treeWeights: Array[Double],
233-
numFeatures: Int) =
234-
this(uid, _trees, _treeWeights, numFeatures, 2)
235-
236221
/**
237222
* Construct a GBTClassificationModel
238223
*
@@ -241,7 +226,7 @@ class GBTClassificationModel private[ml](
241226
*/
242227
@Since("1.6.0")
243228
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
244-
this(uid, _trees, _treeWeights, -1, 2)
229+
this(uid, _trees, _treeWeights, -1)
245230

246231
@Since("1.4.0")
247232
override def trees: Array[DecisionTreeRegressionModel] = _trees
@@ -294,7 +279,7 @@ class GBTClassificationModel private[ml](
294279

295280
@Since("1.4.0")
296281
override def copy(extra: ParamMap): GBTClassificationModel = {
297-
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
282+
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
298283
extra).setParent(parent)
299284
}
300285

@@ -339,7 +324,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
339324

340325
private val numFeaturesKey: String = "numFeatures"
341326
private val numTreesKey: String = "numTrees"
342-
private val numClassesKey: String = "numClasses"
343327

344328
@Since("2.0.0")
345329
override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
@@ -354,8 +338,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
354338

355339
val extraMetadata: JObject = Map(
356340
numFeaturesKey -> instance.numFeatures,
357-
numTreesKey -> instance.getNumTrees,
358-
numClassesKey -> instance.numClasses)
341+
numTreesKey -> instance.getNumTrees)
359342
EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
360343
}
361344
}
@@ -372,7 +355,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
372355
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
373356
val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
374357
val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
375-
val numClasses = (metadata.metadata \ numClassesKey).extract[Int]
376358

377359
val trees: Array[DecisionTreeRegressionModel] = treesData.map {
378360
case (treeMetadata, root) =>
@@ -384,7 +366,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
384366
require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
385367
s" trees based on metadata but found ${trees.length} trees.")
386368
val model = new GBTClassificationModel(metadata.uid,
387-
trees, treeWeights, numFeatures, numClasses)
369+
trees, treeWeights, numFeatures)
388370
DefaultParamsReader.getAndSetParams(model, metadata)
389371
model
390372
}
@@ -395,15 +377,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
395377
oldModel: OldGBTModel,
396378
parent: GBTClassifier,
397379
categoricalFeatures: Map[Int, Int],
398-
numFeatures: Int = -1,
399-
numClasses: Int = 2): GBTClassificationModel = {
380+
numFeatures: Int = -1): GBTClassificationModel = {
400381
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
401382
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
402383
val newTrees = oldModel.trees.map { tree =>
403384
// parent for each tree is null since there is no good way to set this.
404385
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
405386
}
406387
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
407-
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses)
388+
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
408389
}
409390
}

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ object LogLoss extends ClassificationLoss {
5252
2.0 * MLUtils.log1pExp(-margin)
5353
}
5454

55-
override private[spark] def computeProbability(prediction: Double): Double = {
56-
// The probability can be calculated as:
57-
// p+(x) = 1 / (1 + e^(-2 * F(x)))
58-
1.0 / (1.0 + math.exp(-2.0 * prediction))
55+
/**
56+
* Returns the estimated probability of a label of 1.0.
57+
*/
58+
override private[spark] def computeProbability(margin: Double): Double = {
59+
1.0 / (1.0 + math.exp(-2.0 * margin))
5960
}
6061
}

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ trait Loss extends Serializable {
7070
private[spark] trait ClassificationLoss extends Loss {
7171
/**
7272
* Computes the class probability given the margin.
73-
* @param prediction The margin.
74-
* @return The class probability from the margin.
7573
*/
76-
private[spark] def computeProbability(prediction: Double): Double
74+
private[spark] def computeProbability(margin: Double): Double
7775
}

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
115115
val gbt2 = new GBTClassifier
116116
val threshold = Array(0.3, 0.7)
117117
gbt2.setThresholds(threshold)
118-
assert(gbt2.getThresholds.zip(threshold).forall { case(t1, t2) => t1 === t2 })
118+
assert(gbt2.getThresholds === threshold)
119119
}
120120

121121
test("thresholds prediction") {

0 commit comments

Comments
 (0)