@@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
130130
131131 // Find best split for all nodes at a level.
132132 timer.start(" findBestSplits" )
133- val splitsStatsForLevel : Array [(Split , InformationGainStats )] =
133+ val splitsStatsForLevel : Array [(Split , InformationGainStats , Predict )] =
134134 DecisionTree .findBestSplits(treeInput, parentImpurities,
135135 metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
136136 timer.stop(" findBestSplits" )
@@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
143143 timer.start(" extractNodeInfo" )
144144 val split = nodeSplitStats._1
145145 val stats = nodeSplitStats._2
146+ val predict = nodeSplitStats._3.predict
146147 val isLeaf = (stats.gain <= 0 ) || (level == strategy.maxDepth)
147- val node = new Node (nodeIndex, stats. predict, isLeaf, Some (split), None , None , Some (stats))
148+ val node = new Node (nodeIndex, predict, isLeaf, Some (split), None , None , Some (stats))
148149 logDebug(" Node = " + node)
149150 nodes(nodeIndex) = node
150151 timer.stop(" extractNodeInfo" )
@@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
425426 splits : Array [Array [Split ]],
426427 bins : Array [Array [Bin ]],
427428 maxLevelForSingleGroup : Int ,
428- timer : TimeTracker = new TimeTracker ): Array [(Split , InformationGainStats )] = {
429+ timer : TimeTracker = new TimeTracker ): Array [(Split , InformationGainStats , Predict )] = {
429430 // split into groups to avoid memory overflow during aggregation
430431 if (level > maxLevelForSingleGroup) {
431432 // When information for all nodes at a given level cannot be stored in memory,
@@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
434435 // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
435436 val numGroups = 1 << level - maxLevelForSingleGroup
436437 logDebug(" numGroups = " + numGroups)
437- var bestSplits = new Array [(Split , InformationGainStats )](0 )
438+ var bestSplits = new Array [(Split , InformationGainStats , Predict )](0 )
438439 // Iterate over each group of nodes at a level.
439440 var groupIndex = 0
440441 while (groupIndex < numGroups) {
@@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
605606 bins : Array [Array [Bin ]],
606607 timer : TimeTracker ,
607608 numGroups : Int = 1 ,
608- groupIndex : Int = 0 ): Array [(Split , InformationGainStats )] = {
609+ groupIndex : Int = 0 ): Array [(Split , InformationGainStats , Predict )] = {
609610
610611 /*
611612 * The high-level descriptions of the best split optimizations are noted here.
@@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {
705706
706707 // Calculate best splits for all nodes at a given level
707708 timer.start(" chooseSplits" )
708- val bestSplits = new Array [(Split , InformationGainStats )](numNodes)
709+ val bestSplits = new Array [(Split , InformationGainStats , Predict )](numNodes)
709710 // Iterating over all nodes at this level
710711 var nodeIndex = 0
711712 while (nodeIndex < numNodes) {
@@ -734,28 +735,27 @@ object DecisionTree extends Serializable with Logging {
734735 topImpurity : Double ,
735736 level : Int ,
736737 metadata : DecisionTreeMetadata ): InformationGainStats = {
737-
738738 val leftCount = leftImpurityCalculator.count
739739 val rightCount = rightImpurityCalculator.count
740740
741- val totalCount = leftCount + rightCount
742- if (totalCount == 0 ) {
743- // Return arbitrary prediction.
744- return new InformationGainStats (0 , topImpurity, topImpurity, topImpurity, 0 )
741+ // If left child or right child doesn't satisfy minimum instances per node,
742+ // then this split is invalid, return invalid information gain stats.
743+ if ((leftCount < metadata.minInstancesPerNode) ||
744+ (rightCount < metadata.minInstancesPerNode)) {
745+ return InformationGainStats .invalidInformationGainStats
745746 }
746747
747- val parentNodeAgg = leftImpurityCalculator.copy
748- parentNodeAgg.add(rightImpurityCalculator)
748+ val totalCount = leftCount + rightCount
749+
749750 // impurity of parent node
750751 val impurity = if (level > 0 ) {
751752 topImpurity
752753 } else {
754+ val parentNodeAgg = leftImpurityCalculator.copy
755+ parentNodeAgg.add(rightImpurityCalculator)
753756 parentNodeAgg.calculate()
754757 }
755758
756- val predict = parentNodeAgg.predict
757- val prob = parentNodeAgg.prob(predict)
758-
759759 val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
760760 val rightImpurity = rightImpurityCalculator.calculate()
761761
@@ -764,7 +764,31 @@ object DecisionTree extends Serializable with Logging {
764764
765765 val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
766766
767- new InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict, prob)
767+ // if information gain doesn't satisfy minimum information gain,
768+ // then this split is invalid, return invalid information gain stats.
769+ if (gain < metadata.minInfoGain) {
770+ return InformationGainStats .invalidInformationGainStats
771+ }
772+
773+ new InformationGainStats (gain, impurity, leftImpurity, rightImpurity)
774+ }
775+
776+ /**
777+ * Calculate the predict value for the current node by merging the left and right child aggregates of a split.
778+ * Intended to be evaluated only once per node and reused for every candidate split (assumes left + right of any split cover the same node statistics -- TODO confirm).
779+ * @param leftImpurityCalculator left child aggregates for a split
780+ * @param rightImpurityCalculator right child aggregates for the same split
781+ * @return a Predict built from the merged aggregate's `predict` value and the matching `prob(predict)` value
782+ */
783+ private def calculatePredict (
784+ leftImpurityCalculator : ImpurityCalculator ,
785+ rightImpurityCalculator : ImpurityCalculator ): Predict = {
786+ val parentNodeAgg = leftImpurityCalculator.copy // work on a copy; the add() below is used only for its effect on parentNodeAgg
787+ parentNodeAgg.add(rightImpurityCalculator)
788+ val predict = parentNodeAgg.predict
789+ val prob = parentNodeAgg.prob(predict)
790+
791+ new Predict (predict, prob)
768792 }
769793
770794 /**
@@ -780,12 +804,15 @@ object DecisionTree extends Serializable with Logging {
780804 nodeImpurity : Double ,
781805 level : Int ,
782806 metadata : DecisionTreeMetadata ,
783- splits : Array [Array [Split ]]): (Split , InformationGainStats ) = {
807+ splits : Array [Array [Split ]]): (Split , InformationGainStats , Predict ) = {
784808
785809 logDebug(" node impurity = " + nodeImpurity)
786810
811+ // calculate predict only once
812+ var predict : Option [Predict ] = None
813+
787814 // For each (feature, split), calculate the gain, and select the best (feature, split).
788- Range (0 , metadata.numFeatures).map { featureIndex =>
815+ val (bestSplit, bestSplitStats) = Range (0 , metadata.numFeatures).map { featureIndex =>
789816 val numSplits = metadata.numSplits(featureIndex)
790817 if (metadata.isContinuous(featureIndex)) {
791818 // Cumulative sum (scanLeft) of bin statistics.
@@ -803,6 +830,7 @@ object DecisionTree extends Serializable with Logging {
803830 val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
804831 val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
805832 rightChildStats.subtract(leftChildStats)
833+ predict = Some (predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
806834 val gainStats =
807835 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
808836 (splitIdx, gainStats)
@@ -816,6 +844,7 @@ object DecisionTree extends Serializable with Logging {
816844 Range (0 , numSplits).map { splitIndex =>
817845 val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
818846 val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
847+ predict = Some (predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
819848 val gainStats =
820849 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
821850 (splitIndex, gainStats)
@@ -887,6 +916,7 @@ object DecisionTree extends Serializable with Logging {
887916 val rightChildStats =
888917 binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
889918 rightChildStats.subtract(leftChildStats)
919+ predict = Some (predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
890920 val gainStats =
891921 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
892922 (splitIndex, gainStats)
@@ -898,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
898928 (bestFeatureSplit, bestFeatureGainStats)
899929 }
900930 }.maxBy(_._2.gain)
931+
932+ require(predict.isDefined, " must calculate predict for each node" )
933+
934+ (bestSplit, bestSplitStats, predict.get)
901935 }
902936
903937 /**
0 commit comments