File tree Expand file tree Collapse file tree 2 files changed +5
-12
lines changed
mllib/src/main/scala/org/apache/spark/mllib/tree Expand file tree Collapse file tree 2 files changed +5
-12
lines changed Original file line number Diff line number Diff line change @@ -91,7 +91,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
9191 // Calculate level for single group construction
9292
9393 // Max memory usage for aggregates
94- val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
94+ val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L
9595 logDebug(" max memory usage for aggregates = " + maxMemoryUsage + " bytes." )
9696 // TODO: Calculate memory usage more precisely.
9797 val numElementsPerNode = DecisionTree .getElementsPerNode(metadata)
@@ -906,8 +906,8 @@ object DecisionTree extends Serializable with Logging {
906906 /**
907907 * Get the number of values to be stored per node in the bin aggregates.
908908 */
909- private def getElementsPerNode (metadata : DecisionTreeMetadata ): Int = {
910- val totalBins = metadata.numBins.sum
909+ private def getElementsPerNode (metadata : DecisionTreeMetadata ): Long = {
910+ val totalBins = metadata.numBins.map(_.toLong). sum
911911 if (metadata.isClassification) {
912912 metadata.numClasses * totalBins
913913 } else {
Original file line number Diff line number Diff line change @@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator(
6565 * Offset for each feature for calculating indices into the [[allStats ]] array.
6666 */
6767 private val featureOffsets : Array [Int ] = {
68- def featureOffsetsCalc (total : Int , featureIndex : Int ): Int = {
69- if (isUnordered(featureIndex)) {
70- total + 2 * numBins(featureIndex)
71- } else {
72- total + numBins(featureIndex)
73- }
74- }
75- Range (0 , numFeatures).scanLeft(0 )(featureOffsetsCalc).map(statsSize * _).toArray
68+ numBins.scanLeft(0 )((total, nBins) => total + statsSize * nBins)
7669 }
7770
7871 /**
@@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator(
149142 s " DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only, " +
150143 s " but was called for ordered feature $featureIndex. " )
151144 val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
152- (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
145+ (baseOffset, baseOffset + ( numBins(featureIndex) >> 1 ) * statsSize)
153146 }
154147
155148 /**
You can’t perform that action at this time.
0 commit comments