
Commit b8634df

jkbradley authored and mengxr committed
[SPARK-3160] [SPARK-3494] [mllib] DecisionTree: eliminate pre-allocated nodes, parentImpurities arrays. Memory calc bug fix.
This PR includes some code simplifications and re-organization which will be helpful for implementing random forests. The main changes are that the nodes and parentImpurities arrays are no longer pre-allocated in the main train() method.

Also added 2 bug fixes:
* maxMemoryUsage calculation
* over-allocation of space for bins in DTStatsAggregator for unordered features

Relation to RFs:
* Since RFs will be deeper and will therefore be more likely sparse (not full trees), it could be a cost savings to avoid pre-allocating a full tree.
* The associated re-organization also reduces bookkeeping, which will make RFs easier to implement.
* The return code doneTraining may be generalized to include cases such as nodes ready for local training.

Details:

No longer pre-allocate parentImpurities array in main train() method.
* parentImpurities values are now stored in individual nodes (in Node.stats.impurity).
* These were not really needed. They were used in calculateGainForSplit(), but they can be calculated anyway using parentNodeAgg.

No longer using Node.build since tree structure is constructed on-the-fly.
* Did not eliminate it since it is public (Developer) API. Marked as deprecated.

Eliminated pre-allocated nodes array in main train() method.
* Nodes are constructed and added to the tree structure as needed during training.
* Moved tree construction from main train() method into findBestSplitsPerGroup() since there is no need to keep the (split, gain) array for an entire level of nodes. Only one element of that array is needed at a time, so we do not need the array.

findBestSplits() now returns 2 items:
* rootNode (newly created root node on first iteration, same root node on later iterations)
* doneTraining (indicating if all nodes at that level were leaves)

Updated DecisionTreeSuite. Notes:
* Improved test "Second level node building with vs. without groups"
  * generateOrderedLabeledPoints() modified so that it really does require 2 levels of internal nodes.
* Related update: Added Node.deepCopy (private[tree]), used for test suite

CC: mengxr

Author: Joseph K. Bradley <[email protected]>

Closes #2341 from jkbradley/dt-spark-3160 and squashes the following commits:

07dd1ee [Joseph K. Bradley] Fixed overflow bug with computing maxMemoryUsage in DecisionTree. Also fixed bug with over-allocating space in DTStatsAggregator for unordered features.
debe072 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure minInstancesPerNode >= 1
0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
306120f [Joseph K. Bradley] Fixed typo in DecisionTreeModel.scala doc
eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix
d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated
d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main train() method. Nodes are constructed and added to the tree structure as needed during training.
2ab763b [Joseph K. Bradley] Simplifications to DecisionTree code:
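Commit 07dd1ee above describes the maxMemoryUsage fix as an overflow bug. The exact Spark expression is not shown in this diff view, so the following is a minimal standalone sketch of that class of bug: multiplying `Int` values to convert megabytes to bytes wraps past `Int.MaxValue` before the result is widened to `Long`. The variable names are assumed for illustration.

```scala
// Sketch of an Int-overflow bug when converting a memory budget in MB to
// bytes. 4096 * 1024 * 1024 == 2^32, which wraps to 0 in 32-bit Int math.
object MaxMemoryDemo {
  def main(args: Array[String]): Unit = {
    val maxMemoryInMB = 4096
    // Buggy: the multiplication happens in Int, then the wrapped result
    // is widened to Long.
    val buggy: Long = maxMemoryInMB * 1024 * 1024
    // Fixed: promote to Long before multiplying.
    val fixed: Long = maxMemoryInMB.toLong * 1024L * 1024L
    println(buggy) // 0 (the Int product wrapped around)
    println(fixed) // 4294967296
  }
}
```

The fix pattern is simply to force the arithmetic into `Long` before any product can exceed `Int.MaxValue` (about 2047 MB worth of bytes).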
1 parent 42904b8 commit b8634df

File tree

7 files changed

+268
-256
lines changed


mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 80 additions & 111 deletions
Large diffs are not rendered by default.

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 3 additions & 0 deletions
```diff
@@ -75,6 +75,9 @@ class Strategy (
     if (algo == Classification) {
       require(numClassesForClassification >= 2)
     }
+    require(minInstancesPerNode >= 1,
+      s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+
     val isMulticlassClassification =
       algo == Classification && numClassesForClassification > 2
     val isMulticlassWithCategoricalFeatures
```
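The added check uses Scala's built-in `require`, which throws `IllegalArgumentException` with the interpolated message when the condition fails. A standalone sketch of this validation pattern (not the Spark class itself):

```scala
// Minimal sketch of parameter validation via require, mirroring the
// minInstancesPerNode check added to Strategy in this commit.
object StrategyCheckDemo {
  def check(minInstancesPerNode: Int): Unit =
    require(minInstancesPerNode >= 1,
      s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")

  def main(args: Array[String]): Unit = {
    check(1) // passes silently
    try {
      check(0)
    } catch {
      // Message is prefixed with "requirement failed: " by require.
      case e: IllegalArgumentException => println(e.getMessage)
    }
  }
}
```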

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala

Lines changed: 2 additions & 9 deletions
```diff
@@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator(
    * Offset for each feature for calculating indices into the [[allStats]] array.
    */
   private val featureOffsets: Array[Int] = {
-    def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
-      if (isUnordered(featureIndex)) {
-        total + 2 * numBins(featureIndex)
-      } else {
-        total + numBins(featureIndex)
-      }
-    }
-    Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
+    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
  }

  /**
@@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator(
       s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
       s" but was called for ordered feature $featureIndex.")
     val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
-    (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
+    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
   }

   /**
```
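The `scanLeft` in the new `featureOffsets` computes a cumulative offset per feature into the flat stats array: each feature's slice starts at `statsSize` times the running bin count, with no extra 2x allocation for unordered features (the over-allocation this commit fixes). A standalone sketch with assumed example values:

```scala
// Sketch of the cumulative-offset computation via scanLeft.
// statsSize and numBins values are illustrative, not from Spark.
object FeatureOffsetsDemo {
  def main(args: Array[String]): Unit = {
    val statsSize = 3            // e.g. sufficient stats stored per bin
    val numBins = Array(4, 8, 2) // bins per feature
    // scanLeft emits the seed plus one running total per element, so the
    // result has numFeatures + 1 entries; the last is the total array size.
    val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
    println(featureOffsets.mkString(",")) // 0,12,36,42
  }
}
```

Note that `scanLeft` yields one more element than the input; the extra trailing entry doubles as the end offset of the last feature.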

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -46,6 +46,7 @@ private[tree] class DecisionTreeMetadata(
     val numBins: Array[Int],
     val impurity: Impurity,
     val quantileStrategy: QuantileStrategy,
+    val maxDepth: Int,
     val minInstancesPerNode: Int,
     val minInfoGain: Double) extends Serializable {

@@ -129,7 +130,7 @@ private[tree] object DecisionTreeMetadata {

     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
-      strategy.impurity, strategy.quantileCalculationStrategy,
+      strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
       strategy.minInstancesPerNode, strategy.minInfoGain)
   }
```

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -46,7 +46,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
    * Predict values for the given data set using the model trained.
    *
    * @param features RDD representing data points to be predicted
-   * @return RDD[Int] where each entry contains the corresponding prediction
+   * @return RDD of predictions for each of the given data points
    */
   def predict(features: RDD[Vector]): RDD[Double] = {
     features.map(x => predict(x))
```

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 37 additions & 0 deletions
```diff
@@ -55,6 +55,8 @@ class Node (
    * build the left node and right nodes if not leaf
    * @param nodes array of nodes
    */
+  @deprecated("build should no longer be used since trees are constructed on-the-fly in training",
+    "1.2.0")
   def build(nodes: Array[Node]): Unit = {
     logDebug("building node " + id + " at level " + Node.indexToLevel(id))
     logDebug("id = " + id + ", split = " + split)
@@ -93,6 +95,23 @@ class Node (
     }
   }

+  /**
+   * Returns a deep copy of the subtree rooted at this node.
+   */
+  private[tree] def deepCopy(): Node = {
+    val leftNodeCopy = if (leftNode.isEmpty) {
+      None
+    } else {
+      Some(leftNode.get.deepCopy())
+    }
+    val rightNodeCopy = if (rightNode.isEmpty) {
+      None
+    } else {
+      Some(rightNode.get.deepCopy())
+    }
+    new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+  }
+
   /**
    * Get the number of nodes in tree below this node, including leaf nodes.
    * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
@@ -190,4 +209,22 @@ private[tree] object Node {
    */
   def startIndexInLevel(level: Int): Int = 1 << level

+  /**
+   * Traces down from a root node to get the node with the given node index.
+   * This assumes the node exists.
+   */
+  def getNode(nodeIndex: Int, rootNode: Node): Node = {
+    var tmpNode: Node = rootNode
+    var levelsToGo = indexToLevel(nodeIndex)
+    while (levelsToGo > 0) {
+      if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
+        tmpNode = tmpNode.leftNode.get
+      } else {
+        tmpNode = tmpNode.rightNode.get
+      }
+      levelsToGo -= 1
+    }
+    tmpNode
+  }
+
 }
```
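The new `Node.getNode` relies on the binary-heap index layout (root is 1, children of node `i` are `2i` and `2i+1`): the bits of `nodeIndex` below its leading 1, read from high to low, spell out the left/right path from the root. A standalone sketch of just that bit logic, independent of the Node class (`pathTo` is a made-up helper for illustration):

```scala
// Demonstrates the heap-index traversal used by Node.getNode: each bit of
// nodeIndex below the leading 1 selects left (0) or right (1) at one level.
object HeapIndexDemo {
  // Level of a 1-based heap index: floor(log2(nodeIndex)).
  def indexToLevel(nodeIndex: Int): Int =
    31 - java.lang.Integer.numberOfLeadingZeros(nodeIndex)

  // Returns the left/right path from the root as a string of 'L'/'R'.
  def pathTo(nodeIndex: Int): String = {
    var levelsToGo = indexToLevel(nodeIndex)
    val sb = new StringBuilder
    while (levelsToGo > 0) {
      sb += (if ((nodeIndex & (1 << (levelsToGo - 1))) == 0) 'L' else 'R')
      levelsToGo -= 1
    }
    sb.toString
  }

  def main(args: Array[String]): Unit = {
    println(pathTo(1)) // "" - the root itself
    println(pathTo(6)) // "RL" - 6 = 0b110: right child of root, then left
  }
}
```

For example, node 6 is the left child (6 = 2*3) of node 3, which is the right child (3 = 2*1 + 1) of the root, matching the "RL" path above.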
