Skip to content

Commit 8a758db

Browse files
committed
Merge branch 'decisiontree-bugfix' into decisiontree-python-new
2 parents 5fe44ed + 2283df8 commit 8a758db

File tree

7 files changed

+188
-80
lines changed

7 files changed

+188
-80
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 73 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import scopt.OptionParser
2121

2222
import org.apache.spark.{SparkConf, SparkContext}
2323
import org.apache.spark.SparkContext._
24-
import org.apache.spark.mllib.linalg.Vector
2524
import org.apache.spark.mllib.regression.LabeledPoint
2625
import org.apache.spark.mllib.tree.{DecisionTree, impurity}
2726
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
@@ -36,6 +35,9 @@ import org.apache.spark.rdd.RDD
3635
* ./bin/spark-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
3736
* }}}
3837
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
38+
*
39+
* Note: This script treats all features as real-valued (not categorical).
40+
* To include categorical features, modify categoricalFeaturesInfo.
3941
*/
4042
object DecisionTreeRunner {
4143

@@ -48,11 +50,13 @@ object DecisionTreeRunner {
4850

4951
case class Params(
5052
input: String = null,
53+
dataFormat: String = null,
5154
algo: Algo = Classification,
5255
numClassesForClassification: Int = 2,
53-
maxDepth: Int = 5,
56+
maxDepth: Int = 4,
5457
impurity: ImpurityType = Gini,
55-
maxBins: Int = 100)
58+
maxBins: Int = 100,
59+
fracTest: Double = 0.2)
5660

5761
def main(args: Array[String]) {
5862
val defaultParams = Params()
@@ -69,25 +73,32 @@ object DecisionTreeRunner {
6973
opt[Int]("maxDepth")
7074
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
7175
.action((x, c) => c.copy(maxDepth = x))
72-
opt[Int]("numClassesForClassification")
73-
.text(s"number of classes for classification, "
74-
+ s"default: ${defaultParams.numClassesForClassification}")
75-
.action((x, c) => c.copy(numClassesForClassification = x))
7676
opt[Int]("maxBins")
7777
.text(s"max number of bins, default: ${defaultParams.maxBins}")
7878
.action((x, c) => c.copy(maxBins = x))
79+
opt[Double]("fracTest")
80+
.text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
81+
.action((x, c) => c.copy(fracTest = x))
7982
arg[String]("<input>")
8083
.text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
8184
.required()
8285
.action((x, c) => c.copy(input = x))
86+
arg[String]("<dataFormat>")
87+
.text("data format: dense/libsvm")
88+
.required()
89+
.action((x, c) => c.copy(dataFormat = x))
8390
checkConfig { params =>
84-
if (params.algo == Classification &&
85-
(params.impurity == Gini || params.impurity == Entropy)) {
86-
success
87-
} else if (params.algo == Regression && params.impurity == Variance) {
88-
success
91+
if (params.fracTest < 0 || params.fracTest > 1) {
92+
failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
8993
} else {
90-
failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
94+
if (params.algo == Classification &&
95+
(params.impurity == Gini || params.impurity == Entropy)) {
96+
success
97+
} else if (params.algo == Regression && params.impurity == Variance) {
98+
success
99+
} else {
100+
failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
101+
}
91102
}
92103
}
93104
}
@@ -100,16 +111,57 @@ object DecisionTreeRunner {
100111
}
101112

102113
def run(params: Params) {
114+
103115
val conf = new SparkConf().setAppName("DecisionTreeRunner")
104116
val sc = new SparkContext(conf)
105117

106118
// Load training data and cache it.
107-
val examples = MLUtils.loadLabeledPoints(sc, params.input).cache()
119+
val origExamples = params.dataFormat match {
120+
case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
121+
case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache()
122+
}
123+
// For classification, re-index classes if needed.
124+
val (examples, numClasses) = params.algo match {
125+
case Classification => {
126+
// classCounts: class --> # examples in class
127+
val classCounts = origExamples.map(_.label).countByValue
128+
val numClasses = classCounts.size
129+
// classIndex: class --> index in 0,...,numClasses-1
130+
val classIndex = {
131+
if (classCounts.keySet != Set[Double](0.0, 1.0)) {
132+
classCounts.keys.toList.sorted.zipWithIndex.toMap
133+
} else {
134+
Map[Double, Int]()
135+
}
136+
}
137+
val examples = {
138+
if (classIndex.isEmpty) {
139+
origExamples
140+
} else {
141+
origExamples.map(lp => LabeledPoint(classIndex(lp.label), lp.features))
142+
}
143+
}
144+
println(s"numClasses = $numClasses.")
145+
println(s"Per-class example fractions, counts:")
146+
println(s"Class\tFrac\tCount")
147+
classCounts.keys.toList.sorted.foreach(c => {
148+
val frac = classCounts(c) / (0.0 + examples.count())
149+
println(s"$c\t$frac\t${classCounts(c)}")
150+
})
151+
(examples, numClasses)
152+
}
153+
case Regression => {
154+
(origExamples, 0)
155+
}
156+
case _ => {
157+
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
158+
}
159+
}
108160

109-
val splits = examples.randomSplit(Array(0.8, 0.2))
161+
// Split into training, test.
162+
val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
110163
val training = splits(0).cache()
111164
val test = splits(1).cache()
112-
113165
val numTraining = training.count()
114166
val numTest = test.count()
115167

@@ -129,17 +181,19 @@ object DecisionTreeRunner {
129181
impurity = impurityCalculator,
130182
maxDepth = params.maxDepth,
131183
maxBins = params.maxBins,
132-
numClassesForClassification = params.numClassesForClassification)
184+
numClassesForClassification = numClasses)
133185
val model = DecisionTree.train(training, strategy)
134186

187+
println(model)
188+
135189
if (params.algo == Classification) {
136190
val accuracy = accuracyScore(model, test)
137-
println(s"Test accuracy = $accuracy.")
191+
println(s"Test accuracy = $accuracy")
138192
}
139193

140194
if (params.algo == Regression) {
141195
val mse = meanSquaredError(model, test)
142-
println(s"Test mean squared error = $mse.")
196+
println(s"Test mean squared error = $mse")
143197
}
144198

145199
sc.stop()

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 49 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -598,9 +598,12 @@ object DecisionTree extends Serializable with Logging {
598598
// Find feature bins for all nodes at a level.
599599
val binMappedRDD = input.map(x => findBinsForLevel(x))
600600

601-
def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int,
602-
label: Double, featureIndex: Int) = {
603-
601+
def updateBinForOrderedFeature(
602+
arr: Array[Double],
603+
agg: Array[Double],
604+
nodeIndex: Int,
605+
label: Double,
606+
featureIndex: Int) = {
604607
// Find the bin index for this feature.
605608
val arrShift = 1 + numFeatures * nodeIndex
606609
val arrIndex = arrShift + featureIndex
@@ -612,27 +615,31 @@ object DecisionTree extends Serializable with Logging {
612615
agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
613616
}
614617

615-
def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double],
616-
label: Double, agg: Array[Double], rightChildShift: Int) = {
618+
def updateBinForUnorderedFeature(
619+
nodeIndex: Int,
620+
featureIndex: Int,
621+
arr: Array[Double],
622+
label: Double,
623+
agg: Array[Double],
624+
rightChildShift: Int) = {
617625
// Find the bin index for this feature.
618-
val arrShift = 1 + numFeatures * nodeIndex
619-
val arrIndex = arrShift + featureIndex
626+
val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
627+
val featureValue = arr(arrIndex).toInt
620628
// Update the left or right count for one bin.
621-
val aggShift = numClasses * numBins * numFeatures * nodeIndex
622-
val aggIndex
623-
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
629+
val aggShift =
630+
numClasses * numBins * numFeatures * nodeIndex +
631+
numClasses * numBins * featureIndex +
632+
label.toInt
624633
// Find all matching bins and increment their values
625634
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
626635
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
627636
var binIndex = 0
628637
while (binIndex < numCategoricalBins) {
629-
val labelInt = label.toInt
630-
if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) {
631-
agg(aggIndex + binIndex)
632-
= agg(aggIndex + binIndex) + 1
638+
val aggIndex = aggShift + binIndex * numClasses
639+
if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
640+
agg(aggIndex) += 1
633641
} else {
634-
agg(rightChildShift + aggIndex + binIndex)
635-
= agg(rightChildShift + aggIndex + binIndex) + 1
642+
agg(rightChildShift + aggIndex) += 1
636643
}
637644
binIndex += 1
638645
}
@@ -815,20 +822,10 @@ object DecisionTree extends Serializable with Logging {
815822
topImpurity: Double): InformationGainStats = {
816823
strategy.algo match {
817824
case Classification =>
818-
var classIndex = 0
819-
val leftCounts: Array[Double] = new Array[Double](numClasses)
820-
val rightCounts: Array[Double] = new Array[Double](numClasses)
821-
var leftTotalCount = 0.0
822-
var rightTotalCount = 0.0
823-
while (classIndex < numClasses) {
824-
val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex)
825-
val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex)
826-
leftCounts(classIndex) = leftClassCount
827-
leftTotalCount += leftClassCount
828-
rightCounts(classIndex) = rightClassCount
829-
rightTotalCount += rightClassCount
830-
classIndex += 1
831-
}
825+
val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex)
826+
val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex)
827+
var leftTotalCount = leftCounts.sum
828+
var rightTotalCount = rightCounts.sum
832829

833830
val impurity = {
834831
if (level > 0) {
@@ -845,33 +842,15 @@ object DecisionTree extends Serializable with Logging {
845842
}
846843
}
847844

848-
if (leftTotalCount == 0) {
849-
return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1)
850-
}
851-
if (rightTotalCount == 0) {
852-
return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1)
853-
}
854-
855-
val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
856-
val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount)
857-
858-
val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount)
859-
val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount)
860-
861-
val gain = {
862-
if (level > 0) {
863-
impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
864-
} else {
865-
impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
866-
}
867-
}
868-
869845
val totalCount = leftTotalCount + rightTotalCount
846+
if (totalCount == 0) {
847+
// Return arbitrary prediction.
848+
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
849+
}
870850

871851
// Sum of count for each label
872-
val leftRightCounts: Array[Double]
873-
= leftCounts.zip(rightCounts)
874-
.map{case (leftCount, rightCount) => leftCount + rightCount}
852+
val leftRightCounts: Array[Double] = leftCounts.zip(rightCounts).map {
853+
case (leftCount, rightCount) => leftCount + rightCount }
875854

876855
def indexOfLargestArrayElement(array: Array[Double]): Int = {
877856
val result = array.foldLeft(-1, Double.MinValue, 0) {
@@ -885,6 +864,22 @@ object DecisionTree extends Serializable with Logging {
885864
val predict = indexOfLargestArrayElement(leftRightCounts)
886865
val prob = leftRightCounts(predict) / totalCount
887866

867+
val leftImpurity = if (leftTotalCount == 0) {
868+
topImpurity
869+
} else {
870+
strategy.impurity.calculate(leftCounts, leftTotalCount)
871+
}
872+
val rightImpurity = if (rightTotalCount == 0) {
873+
topImpurity
874+
} else {
875+
strategy.impurity.calculate(rightCounts, rightTotalCount)
876+
}
877+
878+
val leftWeight = leftTotalCount / totalCount
879+
val rightWeight = rightTotalCount / totalCount
880+
881+
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
882+
888883
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
889884
case Regression =>
890885
val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,13 @@ object Entropy extends Impurity {
3434
* information calculation for multiclass classification
3535
* @param counts Array[Double] with counts for each label
3636
* @param totalCount sum of counts for all labels
37-
* @return information value
37+
* @return information value, or 0 if totalCount = 0
3838
*/
3939
@DeveloperApi
4040
override def calculate(counts: Array[Double], totalCount: Double): Double = {
41+
if (totalCount == 0) {
42+
return 0
43+
}
4144
val numClasses = counts.length
4245
var impurity = 0.0
4346
var classIndex = 0
@@ -58,6 +61,7 @@ object Entropy extends Impurity {
5861
* @param count number of instances
5962
* @param sum sum of labels
6063
* @param sumSquares summation of squares of the labels
64+
* @return information value, or 0 if count = 0
6165
*/
6266
@DeveloperApi
6367
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,13 @@ object Gini extends Impurity {
3333
* information calculation for multiclass classification
3434
* @param counts Array[Double] with counts for each label
3535
* @param totalCount sum of counts for all labels
36-
* @return information value
36+
* @return information value, or 0 if totalCount = 0
3737
*/
3838
@DeveloperApi
3939
override def calculate(counts: Array[Double], totalCount: Double): Double = {
40+
if (totalCount == 0) {
41+
return 0
42+
}
4043
val numClasses = counts.length
4144
var impurity = 1.0
4245
var classIndex = 0
@@ -54,6 +57,7 @@ object Gini extends Impurity {
5457
* @param count number of instances
5558
* @param sum sum of labels
5659
* @param sumSquares summation of squares of the labels
60+
* @return information value, or 0 if count = 0
5761
*/
5862
@DeveloperApi
5963
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ trait Impurity extends Serializable {
3131
* information calculation for multiclass classification
3232
* @param counts Array[Double] with counts for each label
3333
* @param totalCount sum of counts for all labels
34-
* @return information value
34+
* @return information value, or 0 if totalCount = 0
3535
*/
3636
@DeveloperApi
3737
def calculate(counts: Array[Double], totalCount: Double): Double
@@ -42,7 +42,7 @@ trait Impurity extends Serializable {
4242
* @param count number of instances
4343
* @param sum sum of labels
4444
* @param sumSquares summation of squares of the labels
45-
* @return information value
45+
* @return information value, or 0 if count = 0
4646
*/
4747
@DeveloperApi
4848
def calculate(count: Double, sum: Double, sumSquares: Double): Double

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ object Variance extends Impurity {
3131
* information calculation for multiclass classification
3232
* @param counts Array[Double] with counts for each label
3333
* @param totalCount sum of counts for all labels
34-
* @return information value
34+
* @return information value, or 0 if totalCount = 0
3535
*/
3636
@DeveloperApi
3737
override def calculate(counts: Array[Double], totalCount: Double): Double =
@@ -43,9 +43,13 @@ object Variance extends Impurity {
4343
* @param count number of instances
4444
* @param sum sum of labels
4545
* @param sumSquares summation of squares of the labels
46+
* @return information value, or 0 if count = 0
4647
*/
4748
@DeveloperApi
4849
override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
50+
if (count == 0) {
51+
return 0
52+
}
4953
val squaredLoss = sumSquares - (sum * sum) / count
5054
squaredLoss / count
5155
}

0 commit comments

Comments
 (0)