From 0552c7e798f5d62b74511372c0d38e08e50e6bac Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Thu, 28 Aug 2014 16:03:55 +0800 Subject: [PATCH 1/4] separate calculation of predict of node from calculation of info gain of splits --- .../spark/mllib/tree/DecisionTree.scala | 109 ++++++++++-------- .../tree/model/InformationGainStats.scala | 10 +- 2 files changed, 62 insertions(+), 57 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 5cdd258f6c20b..b826280f15b73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -187,15 +187,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * Extract the decision tree node information for the given tree level and node index */ private def extractNodeInfo( - nodeSplitStats: (Split, InformationGainStats), + nodeSplitStats: (Split, InformationGainStats, Predict), level: Int, index: Int, nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 + val predict = nodeSplitStats._3 val nodeIndex = (1 << level) - 1 + index val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) + val node = new Node(nodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node } @@ -207,7 +208,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo level: Int, index: Int, maxDepth: Int, - nodeSplitStats: (Split, InformationGainStats), + nodeSplitStats: (Split, InformationGainStats, Predict), parentImpurities: Array[Double]): Unit = { if (level >= maxDepth) { @@ -450,7 +451,7 @@ object DecisionTree extends Serializable with Logging { splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int, - timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { + timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, @@ -459,7 +460,7 @@ object DecisionTree extends Serializable with Logging { // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. val numGroups = 1 << level - maxLevelForSingleGroup logDebug("numGroups = " + numGroups) - var bestSplits = new Array[(Split, InformationGainStats)](0) + var bestSplits = new Array[(Split, InformationGainStats, Predict)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { @@ -497,7 +498,7 @@ object DecisionTree extends Serializable with Logging { bins: Array[Array[Bin]], timer: TimeTracker, numGroups: Int = 1, - groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { + groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = { /* * The high-level descriptions of the best split optimizations are noted here. @@ -599,14 +600,6 @@ object DecisionTree extends Serializable with Logging { } } - def nodeIndexToLevel(idx: Int): Int = { - if (idx == 0) { - 0 - } else { - math.floor(math.log(idx) / math.log(2)).toInt - } - } - // Used for treePointToNodeIndex val levelOffset = (1 << level) - 1 @@ -865,34 +858,9 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftTotalCount + rightTotalCount if (totalCount == 0) { // Return arbitrary prediction. - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) - } - - // Sum of count for each label - val leftrightNodeAgg: Array[Double] = - leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => - leftCount + rightCount - } - - def indexOfLargestArrayElement(array: Array[Double]): Int = { - val result = array.foldLeft(-1, Double.MinValue, 0) { - case ((maxIndex, maxValue, currentIndex), currentValue) => - if (currentValue > maxValue) { - (currentIndex, currentValue, currentIndex + 1) - } else { - (maxIndex, maxValue, currentIndex + 1) - } - } - if (result._1 < 0) { - throw new RuntimeException("DecisionTree internal error:" + - " calculateGainForSplit failed in indexOfLargestArrayElement") - } - result._1 + return new InformationGainStats(0, topImpurity, topImpurity, topImpurity) } - val predict = indexOfLargestArrayElement(leftrightNodeAgg) - val prob = leftrightNodeAgg(predict) / totalCount - val leftImpurity = if (leftTotalCount == 0) { topImpurity } else { @@ -909,7 +877,7 @@ object DecisionTree extends Serializable with Logging { val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) } else { // Regression @@ -935,12 +903,11 @@ object DecisionTree extends Serializable with Logging { } if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity) } if (rightCount == 0) { return new InformationGainStats(0, topImpurity, topImpurity, - Double.MinValue, leftSum / leftCount) + Double.MinValue) } val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares) @@ -951,8 +918,7 @@ object DecisionTree extends Serializable with Logging { val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) } } @@ -1162,6 +1128,46 @@ object DecisionTree extends Serializable with Logging { } } + def calculatePredict(leftNodeAgg: Array[Double], rightNodeAgg: Array[Double]) = { + if (metadata.isClassification) { + val totalCount = leftNodeAgg.sum + rightNodeAgg.sum + val leftrightNodeAgg: Array[Double] = + leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => + leftCount + rightCount + } + + def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } + } + if (result._1 < 0) { + throw new RuntimeException("DecisionTree internal error:" + + " calculateGainForSplit failed in indexOfLargestArrayElement") + } + result._1 + } + + val predict = indexOfLargestArrayElement(leftrightNodeAgg) + val prob = leftrightNodeAgg(predict) / totalCount + + new Predict(predict, prob) + } else { + val leftCount = leftNodeAgg(0) + val leftSum = leftNodeAgg(1) + + val rightCount = rightNodeAgg(0) + val rightSum = rightNodeAgg(1) + + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new Predict(predict) + } + } + /** * Find the best split for a node. * @param binData Bin data slice for this node, given by getBinDataForNode. @@ -1170,13 +1176,16 @@ object DecisionTree extends Serializable with Logging { */ def binsToBestSplit( binData: Array[Double], - nodeImpurity: Double): (Split, InformationGainStats) = { + nodeImpurity: Double): (Split, InformationGainStats, Predict) = { logDebug("node impurity = " + nodeImpurity) // Extract left right node aggregates. val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + // Calculate prediction value for current node. + val predict = calculatePredict(leftNodeAgg(0)(0), rightNodeAgg(0)(0)) + // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) @@ -1184,7 +1193,7 @@ object DecisionTree extends Serializable with Logging { // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue - var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) + var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) // Iterate over features. var featureIndex = 0 while (featureIndex < numFeatures) { @@ -1208,7 +1217,7 @@ object DecisionTree extends Serializable with Logging { logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) - (splits(bestFeatureIndex)(bestSplitIndex), gainStats) + (splits(bestFeatureIndex)(bestSplitIndex), gainStats, predict) } /** @@ -1243,7 +1252,7 @@ object DecisionTree extends Serializable with Logging { // Calculate best splits for all nodes at a given level timer.start("chooseSplits") - val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes) // Iterating over all nodes at this level var node = 0 while (node < numNodes) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index fb12298e0f5d3..9a009ae39b938 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -26,20 +26,16 @@ import org.apache.spark.annotation.DeveloperApi * @param impurity current node impurity * @param leftImpurity left node impurity * @param rightImpurity right node impurity - * @param predict predicted value - * @param prob probability of the label (classification only) */ @DeveloperApi class InformationGainStats( val gain: Double, val impurity: Double, val leftImpurity: Double, - val rightImpurity: Double, - val predict: Double, - val prob: Double = 0.0) extends Serializable { + val rightImpurity: Double) extends Serializable { override def toString = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f" - .format(gain, impurity, leftImpurity, rightImpurity, predict, prob) + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" + .format(gain, impurity, leftImpurity, rightImpurity) } } From c205eb8775a8dabfd567501972e2c9732c2fe80a Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Thu, 28 Aug 2014 16:05:20 +0800 Subject: [PATCH 2/4] commit Predict.scala --- .../spark/mllib/tree/model/Predict.scala | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala new file mode 100644 index 0000000000000..733cfa8cf9575 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Predicted value for a node + * @param predict predicted value + * @param prob probability of the label (classification only) + */ +@DeveloperApi +class Predict( + val predict: Double, + val prob: Double = 0.0) extends Serializable{ + + override def toString() = { + "predict = %f, prob = %f".format(predict, prob) + } +} From d92b3d47666e1c907222605b873172ef4a2c770c Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Thu, 28 Aug 2014 16:19:59 +0800 Subject: [PATCH 3/4] fix decision tree suite --- .../spark/mllib/tree/DecisionTreeSuite.scala | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2f36fd907772c..d765047ff0251 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -34,9 +34,9 @@ import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { def validateClassifier( - model: DecisionTreeModel, - input: Seq[LabeledPoint], - requiredAccuracy: Double) { + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { val predictions = input.map(x => model.predict(x.features)) val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label @@ -47,9 +47,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } def validateRegressor( - model: DecisionTreeModel, - input: Seq[LabeledPoint], - requiredMSE: Double) { + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { val predictions = input.map(x => model.predict(x.features)) val squaredError = predictions.zip(input).map { case (prediction, expected) => val err = prediction - expected.label @@ -446,9 +446,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(split.threshold === Double.MinValue) val stats = bestSplits(0)._2 + val predict = bestSplits(0)._3 assert(stats.gain > 0) - assert(stats.predict === 1) - assert(stats.prob === 0.6) + assert(predict.predict === 1) + assert(predict.prob === 0.6) assert(stats.impurity > 0.2) } @@ -475,8 +476,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(split.threshold === Double.MinValue) val stats = bestSplits(0)._2 + val predict = bestSplits(0)._3 assert(stats.gain > 0) - assert(stats.predict === 0.6) + assert(predict.predict === 0.6) assert(stats.impurity > 0.2) } @@ -543,7 +545,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 1) + assert(bestSplits(0)._3.predict === 1) } test("stump with fixed label 0 for Entropy") { @@ -568,7 +570,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 0) + assert(bestSplits(0)._3.predict === 0) } test("stump with fixed label 1 for Entropy") { @@ -593,7 +595,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 1) + assert(bestSplits(0)._3.predict === 1) } test("second level node building with/without groups") { @@ -644,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) - assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) + assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict) } } @@ -885,7 +887,7 @@ object DecisionTreeSuite { } def generateCategoricalDataPointsForMulticlassForOrderedFeatures(): - Array[LabeledPoint] = { + Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { if (i < 1000) { @@ -900,4 +902,4 @@ object DecisionTreeSuite { } -} +} \ No newline at end of file From e6af523e56da40b303b64599d70d4d01ba1baa1d Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Thu, 28 Aug 2014 19:16:16 +0800 Subject: [PATCH 4/4] fix indentation --- .../spark/mllib/tree/DecisionTreeSuite.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index d765047ff0251..622bcb9058ddc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -34,9 +34,9 @@ import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { def validateClassifier( - model: DecisionTreeModel, - input: Seq[LabeledPoint], - requiredAccuracy: Double) { + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { val predictions = input.map(x => model.predict(x.features)) val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label @@ -47,9 +47,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } def validateRegressor( - model: DecisionTreeModel, - input: Seq[LabeledPoint], - requiredMSE: Double) { + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { val predictions = input.map(x => model.predict(x.features)) val squaredError = predictions.zip(input).map { case (prediction, expected) => val err = prediction - expected.label @@ -887,7 +887,7 @@ object DecisionTreeSuite { } def generateCategoricalDataPointsForMulticlassForOrderedFeatures(): - Array[LabeledPoint] = { + Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { if (i < 1000) {