Skip to content
Closed
191 changes: 80 additions & 111 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class Strategy (
if (algo == Classification) {
require(numClassesForClassification >= 2)
}
require(minInstancesPerNode >= 1,
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")

val isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
val isMulticlassWithCategoricalFeatures
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator(
* Offset for each feature for calculating indices into the [[allStats]] array.
*/
private val featureOffsets: Array[Int] = {
def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
if (isUnordered(featureIndex)) {
total + 2 * numBins(featureIndex)
} else {
total + numBins(featureIndex)
}
}
Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
}

/**
Expand Down Expand Up @@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator(
s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
s" but was called for ordered feature $featureIndex.")
val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
(baseOffset, baseOffset + numBins(featureIndex) * statsSize)
(baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ private[tree] class DecisionTreeMetadata(
val numBins: Array[Int],
val impurity: Impurity,
val quantileStrategy: QuantileStrategy,
val maxDepth: Int,
val minInstancesPerNode: Int,
val minInfoGain: Double) extends Serializable {

Expand Down Expand Up @@ -129,7 +130,7 @@ private[tree] object DecisionTreeMetadata {

new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy,
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
strategy.minInstancesPerNode, strategy.minInfoGain)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
* Predict values for the given data set using the model trained.
*
* @param features RDD representing data points to be predicted
* @return RDD[Int] where each entry contains the corresponding prediction
* @return RDD of predictions for each of the given data points
*/
def predict(features: RDD[Vector]): RDD[Double] = {
features.map(x => predict(x))
Expand Down
37 changes: 37 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class Node (
* build the left node and right nodes if not leaf
* @param nodes array of nodes
*/
@deprecated("build should no longer be used since trees are constructed on-the-fly in training",
"1.2.0")
def build(nodes: Array[Node]): Unit = {
logDebug("building node " + id + " at level " + Node.indexToLevel(id))
logDebug("id = " + id + ", split = " + split)
Expand Down Expand Up @@ -93,6 +95,23 @@ class Node (
}
}

/**
* Returns a deep copy of the subtree rooted at this node.
*/
private[tree] def deepCopy(): Node = {
val leftNodeCopy = if (leftNode.isEmpty) {
None
} else {
Some(leftNode.get.deepCopy())
}
val rightNodeCopy = if (rightNode.isEmpty) {
None
} else {
Some(rightNode.get.deepCopy())
}
new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
}

/**
* Get the number of nodes in tree below this node, including leaf nodes.
* E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
Expand Down Expand Up @@ -190,4 +209,22 @@ private[tree] object Node {
*/
def startIndexInLevel(level: Int): Int = 1 << level

/**
* Traces down from a root node to get the node with the given node index.
* This assumes the node exists.
*/
def getNode(nodeIndex: Int, rootNode: Node): Node = {
var tmpNode: Node = rootNode
var levelsToGo = indexToLevel(nodeIndex)
while (levelsToGo > 0) {
if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
tmpNode = tmpNode.leftNode.get
} else {
tmpNode = tmpNode.rightNode.get
}
levelsToGo -= 1
}
tmpNode
}

}
Loading