Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 53 additions & 19 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

// Find best split for all nodes at a level.
timer.start("findBestSplits")
val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
DecisionTree.findBestSplits(treeInput, parentImpurities,
metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
timer.stop("findBestSplits")
Expand All @@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
timer.start("extractNodeInfo")
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
val predict = nodeSplitStats._3.predict
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
nodes(nodeIndex) = node
timer.stop("extractNodeInfo")
Expand Down Expand Up @@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int,
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
Expand All @@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
val numGroups = 1 << level - maxLevelForSingleGroup
logDebug("numGroups = " + numGroups)
var bestSplits = new Array[(Split, InformationGainStats)](0)
var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
// Iterate over each group of nodes at a level.
var groupIndex = 0
while (groupIndex < numGroups) {
Expand Down Expand Up @@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
bins: Array[Array[Bin]],
timer: TimeTracker,
numGroups: Int = 1,
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {

/*
* The high-level descriptions of the best split optimizations are noted here.
Expand Down Expand Up @@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {

// Calculate best splits for all nodes at a given level
timer.start("chooseSplits")
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
// Iterating over all nodes at this level
var nodeIndex = 0
while (nodeIndex < numNodes) {
Expand Down Expand Up @@ -734,28 +735,27 @@ object DecisionTree extends Serializable with Logging {
topImpurity: Double,
level: Int,
metadata: DecisionTreeMetadata): InformationGainStats = {

val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count

val totalCount = leftCount + rightCount
if (totalCount == 0) {
// Return arbitrary prediction.
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
// If left child or right child doesn't satisfy minimum instances per node,
// then this split is invalid, return invalid information gain stats.
if ((leftCount < metadata.minInstancesPerNode) ||
(rightCount < metadata.minInstancesPerNode)) {
return InformationGainStats.invalidInformationGainStats
}

val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
val totalCount = leftCount + rightCount

// impurity of parent node
val impurity = if (level > 0) {
topImpurity
} else {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
parentNodeAgg.calculate()
}

val predict = parentNodeAgg.predict
val prob = parentNodeAgg.prob(predict)

val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()

Expand All @@ -764,7 +764,31 @@ object DecisionTree extends Serializable with Logging {

val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
// if information gain doesn't satisfy minimum information gain,
// then this split is invalid, return invalid information gain stats.
if (gain < metadata.minInfoGain) {
return InformationGainStats.invalidInformationGainStats
}

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
}

/**
* Calculate predict value for current node, given stats of any split.
* Note that this function is called only once for each node.
* @param leftImpurityCalculator left node aggregates for a split
* @param rightImpurityCalculator right node aggregates for a node
* @return predict value for current node
*/
private def calculatePredict(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator): Predict = {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
val predict = parentNodeAgg.predict
val prob = parentNodeAgg.prob(predict)

new Predict(predict, prob)
}

/**
Expand All @@ -780,12 +804,15 @@ object DecisionTree extends Serializable with Logging {
nodeImpurity: Double,
level: Int,
metadata: DecisionTreeMetadata,
splits: Array[Array[Split]]): (Split, InformationGainStats) = {
splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {

logDebug("node impurity = " + nodeImpurity)

// calculate predict only once
var predict: Option[Predict] = None

// For each (feature, split), calculate the gain, and select the best (feature, split).
Range(0, metadata.numFeatures).map { featureIndex =>
val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
val numSplits = metadata.numSplits(featureIndex)
if (metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
Expand All @@ -803,6 +830,7 @@ object DecisionTree extends Serializable with Logging {
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIdx, gainStats)
Expand All @@ -816,6 +844,7 @@ object DecisionTree extends Serializable with Logging {
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
Expand Down Expand Up @@ -887,6 +916,7 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
Expand All @@ -898,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)

require(predict.isDefined, "must calculate predict for each node")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Use assert instead of require. The latter throws IllegalArgumentException, which doesn't apply here. (not necessary to update)


(bestSplit, bestSplitStats, predict.get)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* k) implies the feature n is categorical with k categories 0,
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
* @param minInstancesPerNode Minimum number of instances each child must have after split.
* Default value is 1. If a split cause left or right child
* to have less than minInstancesPerNode,
* this split will not be considered as a valid split.
* @param minInfoGain Minimum information gain a split must get. Default value is 0.0.
* If a split has less information gain than minInfoGain,
* this split will not be considered as a valid split.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
*/
Expand All @@ -61,6 +68,8 @@ class Strategy (
val maxBins: Int = 32,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
val minInstancesPerNode: Int = 1,
val minInfoGain: Double = 0.0,
val maxMemoryInMB: Int = 256) extends Serializable {

if (algo == Classification) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata(
val unorderedFeatures: Set[Int],
val numBins: Array[Int],
val impurity: Impurity,
val quantileStrategy: QuantileStrategy) extends Serializable {
val quantileStrategy: QuantileStrategy,
val minInstancesPerNode: Int,
val minInfoGain: Double) extends Serializable {

def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)

Expand Down Expand Up @@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata {

new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy)
strategy.impurity, strategy.quantileCalculationStrategy,
strategy.minInstancesPerNode, strategy.minInfoGain)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,26 @@ import org.apache.spark.annotation.DeveloperApi
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
val rightImpurity: Double,
val predict: Double,
val prob: Double = 0.0) extends Serializable {
val rightImpurity: Double) extends Serializable {

override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
.format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
.format(gain, impurity, leftImpurity, rightImpurity)
}
}


private[tree] object InformationGainStats {
/**
* An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
* denote that current split doesn't satisfies minimum info gain or
* minimum number of instances per node.
*/
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.model

import org.apache.spark.annotation.DeveloperApi

/**
* :: DeveloperApi ::
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
@DeveloperApi
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is a package private class, it is not necessary to mark DeveloperApi.

private[tree] class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space before {


override def toString = {
"predict = %f, prob = %f".format(predict, prob)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType

/**
* :: DeveloperApi ::
Expand Down
Loading