Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 59 additions & 50 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* Extract the decision tree node information for the given tree level and node index
*/
private def extractNodeInfo(
nodeSplitStats: (Split, InformationGainStats),
nodeSplitStats: (Split, InformationGainStats, Predict),
level: Int,
index: Int,
nodes: Array[Node]): Unit = {
// Unpack the tuple: the chosen split, its gain statistics and the node's prediction.
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
val predict = nodeSplitStats._3
// (1 << level) - 1 is added as the offset of this level, so `index` (the position within
// the level) becomes the position of the node inside `nodes`.
val nodeIndex = (1 << level) - 1 + index
// A node is a leaf when the split brings no positive gain or the level equals strategy.maxDepth.
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
val node = new Node(nodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
// Write the node into the caller-supplied array at its computed index.
nodes(nodeIndex) = node
}
Expand All @@ -207,7 +208,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
level: Int,
index: Int,
maxDepth: Int,
nodeSplitStats: (Split, InformationGainStats),
nodeSplitStats: (Split, InformationGainStats, Predict),
parentImpurities: Array[Double]): Unit = {

if (level >= maxDepth) {
Expand Down Expand Up @@ -450,7 +451,7 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int,
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
Expand All @@ -459,7 +460,7 @@ object DecisionTree extends Serializable with Logging {
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
val numGroups = 1 << level - maxLevelForSingleGroup
logDebug("numGroups = " + numGroups)
var bestSplits = new Array[(Split, InformationGainStats)](0)
var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
// Iterate over each group of nodes at a level.
var groupIndex = 0
while (groupIndex < numGroups) {
Expand Down Expand Up @@ -497,7 +498,7 @@ object DecisionTree extends Serializable with Logging {
bins: Array[Array[Bin]],
timer: TimeTracker,
numGroups: Int = 1,
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {

/*
* The high-level descriptions of the best split optimizations are noted here.
Expand Down Expand Up @@ -599,14 +600,6 @@ object DecisionTree extends Serializable with Logging {
}
}

/**
 * Map a level-order node index to the tree level that contains it.
 *
 * Nodes are numbered from 0 at the root, so level `l` starts at index `(1 << l) - 1`
 * and the level of a node is `floor(log2(idx + 1))`. The earlier floating-point formula
 * `floor(log(idx) / log(2))` put the first node of every level below the root one level
 * too high (index 1 was reported as level 0, index 3 as level 1, ...).
 *
 * @param idx zero-based, level-order node index
 * @return zero-based level of the node; 0 for the root and for any non-positive index
 */
def nodeIndexToLevel(idx: Int): Int = {
  if (idx <= 0) {
    // Root node. Negative indices also map to 0, as they did with the previous formula.
    0
  } else {
    // Position of the highest set bit of (idx + 1), i.e. floor(log2(idx + 1)), computed
    // with exact integer arithmetic. For idx == Int.MaxValue the addition wraps to
    // Int.MinValue, whose highest set bit is bit 31 -- still the correct level.
    31 - Integer.numberOfLeadingZeros(idx + 1)
  }
}

// Used for treePointToNodeIndex
val levelOffset = (1 << level) - 1

Expand Down Expand Up @@ -865,34 +858,9 @@ object DecisionTree extends Serializable with Logging {
val totalCount = leftTotalCount + rightTotalCount
if (totalCount == 0) {
// Return arbitrary prediction.
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
}

// Sum of count for each label
val leftrightNodeAgg: Array[Double] =
leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) =>
leftCount + rightCount
}

def indexOfLargestArrayElement(array: Array[Double]): Int = {
val result = array.foldLeft(-1, Double.MinValue, 0) {
case ((maxIndex, maxValue, currentIndex), currentValue) =>
if (currentValue > maxValue) {
(currentIndex, currentValue, currentIndex + 1)
} else {
(maxIndex, maxValue, currentIndex + 1)
}
}
if (result._1 < 0) {
throw new RuntimeException("DecisionTree internal error:" +
" calculateGainForSplit failed in indexOfLargestArrayElement")
}
result._1
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity)
}

val predict = indexOfLargestArrayElement(leftrightNodeAgg)
val prob = leftrightNodeAgg(predict) / totalCount

val leftImpurity = if (leftTotalCount == 0) {
topImpurity
} else {
Expand All @@ -909,7 +877,7 @@ object DecisionTree extends Serializable with Logging {

val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)

} else {
// Regression
Expand All @@ -935,12 +903,11 @@ object DecisionTree extends Serializable with Logging {
}

if (leftCount == 0) {
return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
rightSum / rightCount)
return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity)
}
if (rightCount == 0) {
return new InformationGainStats(0, topImpurity, topImpurity,
Double.MinValue, leftSum / leftCount)
Double.MinValue)
}

val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares)
Expand All @@ -951,8 +918,7 @@ object DecisionTree extends Serializable with Logging {

val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity

val predict = (leftSum + rightSum) / (leftCount + rightCount)
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
}
}

Expand Down Expand Up @@ -1162,6 +1128,46 @@ object DecisionTree extends Serializable with Logging {
}
}

/**
 * Calculate the prediction for a node from the aggregates of its left and right sides.
 *
 * Classification: the two arrays are added element-wise and the index of the largest
 * entry becomes the predicted label (the first one wins on ties); `prob` is that entry
 * divided by the sum of all entries, or 0.0 when that sum is zero.
 * Regression: element 0 of each array is read as a count and element 1 as a sum; the
 * prediction is the combined sum divided by the combined count.
 *
 * @param leftNodeAgg aggregate for the left side
 * @param rightNodeAgg aggregate for the right side
 * @return the predicted value and, for classification, its probability
 * @throws RuntimeException if no largest element can be found (e.g. empty aggregates)
 */
def calculatePredict(leftNodeAgg: Array[Double], rightNodeAgg: Array[Double]): Predict = {
  if (metadata.isClassification) {
    val totalCount = leftNodeAgg.sum + rightNodeAgg.sum
    // Element-wise sum of the two sides.
    val leftrightNodeAgg: Array[Double] =
      leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) =>
        leftCount + rightCount
      }

    // Index of the first strictly largest element; fails if there is none.
    def indexOfLargestArrayElement(array: Array[Double]): Int = {
      var maxIndex = -1
      var maxValue = Double.MinValue
      var currentIndex = 0
      while (currentIndex < array.length) {
        if (array(currentIndex) > maxValue) {
          maxIndex = currentIndex
          maxValue = array(currentIndex)
        }
        currentIndex += 1
      }
      if (maxIndex < 0) {
        throw new RuntimeException("DecisionTree internal error:" +
          " calculatePredict failed in indexOfLargestArrayElement")
      }
      maxIndex
    }

    val predict = indexOfLargestArrayElement(leftrightNodeAgg)
    // Guard the division: with a zero total the quotient would be NaN. Fall back to
    // 0.0, which is also the default `prob` of Predict.
    val prob = if (totalCount == 0) 0.0 else leftrightNodeAgg(predict) / totalCount

    new Predict(predict, prob)
  } else {
    val leftCount = leftNodeAgg(0)
    val leftSum = leftNodeAgg(1)

    val rightCount = rightNodeAgg(0)
    val rightSum = rightNodeAgg(1)

    // NOTE(review): this is NaN when both counts are 0 -- confirm whether a node
    // without any instances can reach this point in regression.
    val predict = (leftSum + rightSum) / (leftCount + rightCount)
    new Predict(predict)
  }
}

/**
* Find the best split for a node.
* @param binData Bin data slice for this node, given by getBinDataForNode.
Expand All @@ -1170,21 +1176,24 @@ object DecisionTree extends Serializable with Logging {
*/
def binsToBestSplit(
binData: Array[Double],
nodeImpurity: Double): (Split, InformationGainStats) = {
nodeImpurity: Double): (Split, InformationGainStats, Predict) = {

logDebug("node impurity = " + nodeImpurity)

// Extract left right node aggregates.
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)

// Calculate prediction value for current node.
val predict = calculatePredict(leftNodeAgg(0)(0), rightNodeAgg(0)(0))

// Calculate gains for all splits.
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)

val (bestFeatureIndex, bestSplitIndex, gainStats) = {
// Initialize with infeasible values.
var bestFeatureIndex = Int.MinValue
var bestSplitIndex = Int.MinValue
var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0)
var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
// Iterate over features.
var featureIndex = 0
while (featureIndex < numFeatures) {
Expand All @@ -1208,7 +1217,7 @@ object DecisionTree extends Serializable with Logging {
logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex))
logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))

(splits(bestFeatureIndex)(bestSplitIndex), gainStats)
(splits(bestFeatureIndex)(bestSplitIndex), gainStats, predict)
}

/**
Expand Down Expand Up @@ -1243,7 +1252,7 @@ object DecisionTree extends Serializable with Logging {

// Calculate best splits for all nodes at a given level
timer.start("chooseSplits")
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
// Iterating over all nodes at this level
var node = 0
while (node < numNodes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,16 @@ import org.apache.spark.annotation.DeveloperApi
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
val rightImpurity: Double,
val predict: Double,
val prob: Double = 0.0) extends Serializable {
val rightImpurity: Double) extends Serializable {

// Render every field as a "name = value" pair, each value formatted with %f.
override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
.format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
.format(gain, impurity, leftImpurity, rightImpurity)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.model

import org.apache.spark.annotation.DeveloperApi

/**
 * :: DeveloperApi ::
 * Prediction attached to a decision tree node.
 *
 * @param predict predicted value
 * @param prob probability of the predicted label (classification only); defaults to 0.0
 */
@DeveloperApi
class Predict(val predict: Double, val prob: Double = 0.0) extends Serializable {

  /** Both fields as "name = value" pairs, each value formatted with %f. */
  override def toString(): String = {
    val template = "predict = %f, prob = %f"
    template.format(predict, prob)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(split.threshold === Double.MinValue)

val stats = bestSplits(0)._2
val predict = bestSplits(0)._3
assert(stats.gain > 0)
assert(stats.predict === 1)
assert(stats.prob === 0.6)
assert(predict.predict === 1)
assert(predict.prob === 0.6)
assert(stats.impurity > 0.2)
}

Expand All @@ -475,8 +476,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(split.threshold === Double.MinValue)

val stats = bestSplits(0)._2
val predict = bestSplits(0)._3
assert(stats.gain > 0)
assert(stats.predict === 0.6)
assert(predict.predict === 0.6)
assert(stats.impurity > 0.2)
}

Expand Down Expand Up @@ -543,7 +545,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 1)
assert(bestSplits(0)._3.predict === 1)
}

test("stump with fixed label 0 for Entropy") {
Expand All @@ -568,7 +570,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 0)
assert(bestSplits(0)._3.predict === 0)
}

test("stump with fixed label 1 for Entropy") {
Expand All @@ -593,7 +595,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 1)
assert(bestSplits(0)._3.predict === 1)
}

test("second level node building with/without groups") {
Expand Down Expand Up @@ -644,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict)
}
}

Expand Down Expand Up @@ -900,4 +902,4 @@ object DecisionTreeSuite {
}


}
}