Skip to content

Commit 07dd1ee

Browse files
committed
Fixed overflow bug with computing maxMemoryUsage in DecisionTree. Also fixed bug with over-allocating space in DTStatsAggregator for unordered features.
1 parent debe072 commit 07dd1ee

File tree

2 files changed

+5
-12
lines changed

2 files changed

+5
-12
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
9191
// Calculate level for single group construction
9292

9393
// Max memory usage for aggregates
94-
val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
94+
val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L
9595
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
9696
// TODO: Calculate memory usage more precisely.
9797
val numElementsPerNode = DecisionTree.getElementsPerNode(metadata)
@@ -906,8 +906,8 @@ object DecisionTree extends Serializable with Logging {
906906
/**
907907
* Get the number of values to be stored per node in the bin aggregates.
908908
*/
909-
private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = {
910-
val totalBins = metadata.numBins.sum
909+
private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = {
910+
val totalBins = metadata.numBins.map(_.toLong).sum
911911
if (metadata.isClassification) {
912912
metadata.numClasses * totalBins
913913
} else {

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator(
6565
* Offset for each feature for calculating indices into the [[allStats]] array.
6666
*/
6767
private val featureOffsets: Array[Int] = {
68-
def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
69-
if (isUnordered(featureIndex)) {
70-
total + 2 * numBins(featureIndex)
71-
} else {
72-
total + numBins(featureIndex)
73-
}
74-
}
75-
Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
68+
numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
7669
}
7770

7871
/**
@@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator(
149142
s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
150143
s" but was called for ordered feature $featureIndex.")
151144
val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
152-
(baseOffset, baseOffset + numBins(featureIndex) * statsSize)
145+
(baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
153146
}
154147

155148
/**

0 commit comments

Comments
 (0)