From 64d066b90b152ceb71b185b7e17313486974ae77 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 27 Jun 2016 13:42:44 -0700 Subject: [PATCH 1/8] Add calculateGain method to all Impurity objects --- .../ml/tree/impl/DTStatsAggregator.scala | 21 +++++ .../spark/ml/tree/impl/RandomForest.scala | 15 +++- .../spark/mllib/tree/impurity/Entropy.scala | 78 +++++++++++++++++-- .../spark/mllib/tree/impurity/Gini.scala | 53 +++++++++++++ .../spark/mllib/tree/impurity/Impurity.scala | 8 ++ .../spark/mllib/tree/impurity/Variance.scala | 40 ++++++++++ 6 files changed, 207 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala index 61091bb803e49..9af0f365743ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -94,6 +94,27 @@ private[spark] class DTStatsAggregator( impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) } + def calculateGain( + featureOffset: Int, + leftBinIndex: Int, + parentBinIndex: Int): Double = { + val leftChildOffset = featureOffset + leftBinIndex * statsSize + val parentOffset = featureOffset + parentBinIndex * statsSize + val gain = metadata.impurity match { + case Gini => Gini.calculateGain( + allStats, leftChildOffset, parentOffset, statsSize, metadata.minInstancesPerNode, + metadata.minInfoGain) + case Entropy => Entropy.calculateGain( + allStats, leftChildOffset, parentOffset, statsSize, metadata.minInstancesPerNode, + metadata.minInfoGain) + case Variance => Variance.calculateGain( + allStats, leftChildOffset, parentOffset, statsSize, metadata.minInstancesPerNode, + metadata.minInfoGain) + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + gain + } + /** * Get an [[ImpurityCalculator]] for the parent node. 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 71c8c42ce5eba..917b6e99f0368 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -712,7 +712,7 @@ private[spark] object RandomForest extends Logging { splitIndex += 1 } // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = + val (temp, bestFeatureGainStats) = Range(0, numSplits).map { case splitIdx => val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = @@ -722,6 +722,11 @@ private[spark] object RandomForest extends Logging { leftChildStats, rightChildStats, binAggregates.metadata) (splitIdx, gainAndImpurityStats) }.maxBy(_._2.gain) + val (bestFeatureSplitIndex, maxGain) = + Range(0, numSplits).map { case splitIdx => + val gain = binAggregates.calculateGain(nodeFeatureOffset, splitIdx, numSplits) + (splitIdx, gain) + }.maxBy(_._2) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature @@ -794,7 +799,7 @@ private[spark] object RandomForest extends Logging { // lastCategory = index of bin with total aggregates for this (node, feature) val lastCategory = categoriesSortedByCentroid.last._1 // Find best split. 
- val (bestFeatureSplitIndex, bestFeatureGainStats) = + val (temp, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val featureValue = categoriesSortedByCentroid(splitIndex)._1 val leftChildStats = @@ -806,6 +811,12 @@ private[spark] object RandomForest extends Logging { leftChildStats, rightChildStats, binAggregates.metadata) (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) + val (bestFeatureSplitIndex, maxGain) = + Range(0, numSplits).map { case splitIdx => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val gain = binAggregates.calculateGain(nodeFeatureOffset, featureValue, lastCategory) + (splitIdx, gain) + }.maxBy(_._2) val categoriesForSplit = categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) val bestFeatureSplit = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 3a731f45d6a07..a8728dd7290ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -27,14 +27,32 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} @Experimental object Entropy extends Impurity { - private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + private[tree] def log2(x: Double): Double = { + if (x == 0) { + return 0.0 + } else { + return scala.math.log(x) / scala.math.log(2) + } + } /** * :: DeveloperApi :: * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value, or 0 if totalCount = 0 + * @return information value, or 0 if totalCount = 0 var leftCount = 0.0 + var totalCount = 0.0 + var i = 0 + while (i < statsSize) { + leftCount += allStats(leftChildOffset + i) + totalCount += allStats(parentOffset + i) + } + val rightCount 
= totalCount - leftCount + + if ((leftCount < minInstancesPerNode) || + (rightCount < minInstancesPerNode)) { + return Double.MinValue + } */ @Since("1.1.0") @DeveloperApi @@ -47,10 +65,8 @@ object Entropy extends Impurity { var classIndex = 0 while (classIndex < numClasses) { val classCount = counts(classIndex) - if (classCount != 0) { - val freq = classCount / totalCount - impurity -= freq * log2(freq) - } + val freq = classCount / totalCount + impurity -= freq * log2(freq) classIndex += 1 } impurity @@ -76,6 +92,56 @@ object Entropy extends Impurity { @Since("1.1.0") def instance: this.type = this + override def calculateGain( + allStats: Array[Double], + leftChildOffset: Int, + parentOffset: Int, + statsSize: Int, + minInstancesPerNode: Int, + minInfoGain: Double): Double = { + var leftCount = 0.0 + var totalCount = 0.0 + var i = 0 + while (i < statsSize) { + leftCount += allStats(leftChildOffset + i) + totalCount += allStats(parentOffset + i) + i += 1 + } + val rightCount = totalCount - leftCount + + if ((leftCount < minInstancesPerNode) || + (rightCount < minInstancesPerNode)) { + return Double.MinValue + } + + var leftImpurity = 0.0 + var rightImpurity = 0.0 + var parentImpurity = 0.0 + + i = 0 + while (i < statsSize) { + val leftStats = allStats(leftChildOffset + i) + val parentStats = allStats(parentOffset + i) + + val leftFreq = leftStats / leftCount + val rightFreq = (parentStats - leftStats) / rightCount + val parentFreq = parentStats / totalCount + + leftImpurity -= leftFreq * log2(leftFreq) + rightImpurity -= rightFreq * log2(rightFreq) + parentImpurity -= parentFreq * log2(parentFreq) + + i += 1 + } + val leftWeighted = leftCount / totalCount * leftImpurity + val rightWeighted = rightCount / totalCount * rightImpurity + val gain = parentImpurity - leftWeighted - rightWeighted + + if (gain < minInfoGain) { + return Double.MinValue + } + gain + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 7730c0a8c1117..5024c8ba48942 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -73,6 +73,59 @@ object Gini extends Impurity { @Since("1.1.0") def instance: this.type = this + override def calculateGain( + allStats: Array[Double], + leftChildOffset: Int, + parentOffset: Int, + statsSize: Int, + minInstancesPerNode: Int, + minInfoGain: Double): Double = { + + var leftCount = 0.0 + var totalCount = 0.0 + var i = 0 + while (i < statsSize) { + leftCount += allStats(leftChildOffset + i) + totalCount += allStats(parentOffset + i) + i += 1 + } + val rightCount = totalCount - leftCount + + if ((leftCount < minInstancesPerNode) || + (rightCount < minInstancesPerNode)) { + return Double.MinValue + } + + var leftImpurity = 1.0 + var rightImpurity = 1.0 + var parentImpurity = 1.0 + + i = 0 + while (i < statsSize) { + val leftStats = allStats(leftChildOffset + i) + val parentStats = allStats(parentOffset + i) + + val leftFreq = leftStats / leftCount + val rightFreq = (parentStats - leftStats) / rightCount + val parentFreq = parentStats / totalCount + + leftImpurity -= leftFreq * leftFreq + rightImpurity -= rightFreq * rightFreq + parentImpurity -= parentFreq * parentFreq + + i += 1 + } + + val leftWeighted = leftCount / totalCount * leftImpurity + val rightWeighted = rightCount / totalCount * rightImpurity + val gain = parentImpurity - leftWeighted - rightWeighted + + if (gain < minInfoGain) { + return Double.MinValue + } + gain + } + } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 65f0163ec6059..ae4c856c5a675 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala 
@@ -52,6 +52,14 @@ trait Impurity extends Serializable { @Since("1.0.0") @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double + + protected def calculateGain( + allStats: Array[Double], + leftChildOffset: Int, + parentOffset: Int, + statsSize: Int, + minInstancesPerNode: Int, + minInfoGain: Double): Double } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 2423516123b82..89a77163af663 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -64,6 +64,46 @@ object Variance extends Impurity { @Since("1.0.0") def instance: this.type = this + override def calculateGain( + allStats: Array[Double], + leftChildOffset: Int, + parentOffset: Int, + statsSize: Int, + minInstancesPerNode: Int, + minInfoGain: Double): Double = { + val leftCount = allStats(leftChildOffset) + val totalCount = allStats(parentOffset) + val rightCount = totalCount - leftCount + + if ((leftCount < minInstancesPerNode) || + (rightCount < minInstancesPerNode)) { + return Double.MinValue + } + + val leftSum = allStats(leftChildOffset + 1) + val leftSumSquares = allStats(leftChildOffset + 2) + + val parentSum = allStats(parentOffset + 1) + val parentSumSquares = allStats(parentOffset + 2) + + val rightSum = parentSum - leftSum + val rightSumSquares = parentSumSquares - leftSumSquares + + val parentImpurity = (parentSumSquares - (parentSum * parentSum) / totalCount) / totalCount + val leftImpurity = (leftSumSquares - (leftSum * leftSum) / leftCount) / leftCount + val rightImpurity = (rightSumSquares - (rightSum * rightSum) / rightCount) / leftCount + + val leftWeighted = leftImpurity * leftCount / totalCount + val rightWeighted = rightImpurity * rightCount / totalCount + val gain = parentImpurity - leftWeighted - rightWeighted + + if (gain < 
minInfoGain) { + return Double.MinValue + } + gain + + } + } /** From f1d8c8950f8adace6ee175cd569b20ed6468bb61 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 27 Jun 2016 14:32:56 -0700 Subject: [PATCH 2/8] Refactor gain calculation for categorical splits --- .../ml/tree/impl/DTStatsAggregator.scala | 25 ++++++++++++++++--- .../spark/ml/tree/impl/RandomForest.scala | 7 +++++- .../spark/mllib/tree/impurity/Entropy.scala | 9 ++++--- .../spark/mllib/tree/impurity/Gini.scala | 9 ++++--- .../spark/mllib/tree/impurity/Impurity.scala | 1 + .../spark/mllib/tree/impurity/Variance.scala | 5 ++-- 6 files changed, 42 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala index 9af0f365743ad..6e7c475971b60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -102,13 +102,32 @@ private[spark] class DTStatsAggregator( val parentOffset = featureOffset + parentBinIndex * statsSize val gain = metadata.impurity match { case Gini => Gini.calculateGain( - allStats, leftChildOffset, parentOffset, statsSize, metadata.minInstancesPerNode, + allStats, leftChildOffset, allStats, parentOffset, statsSize, + metadata.minInstancesPerNode, metadata.minInfoGain) + case Entropy => Entropy.calculateGain( + allStats, leftChildOffset, allStats, parentOffset, statsSize, + metadata.minInstancesPerNode, metadata.minInfoGain) + case Variance => Variance.calculateGain( + allStats, leftChildOffset, allStats, parentOffset, statsSize, + metadata.minInstancesPerNode, metadata.minInfoGain) + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + gain + } + + def calculateGain( + featureOffset: Int, + leftBinIndex: Int): Double = { + val leftChildOffset = featureOffset + leftBinIndex * statsSize + val gain = 
metadata.impurity match { + case Gini => Gini.calculateGain( + allStats, leftChildOffset, parentStats, 0, statsSize, metadata.minInstancesPerNode, metadata.minInfoGain) case Entropy => Entropy.calculateGain( - allStats, leftChildOffset, parentOffset, statsSize, metadata.minInstancesPerNode, + allStats, leftChildOffset, parentStats, 0, statsSize, metadata.minInstancesPerNode, metadata.minInfoGain) case Variance => Variance.calculateGain( - allStats, leftChildOffset, parentOffset, statsSize, metadata.minInstancesPerNode, + allStats, leftChildOffset, parentStats, 0, statsSize, metadata.minInstancesPerNode, metadata.minInfoGain) case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 917b6e99f0368..359ec93e03cda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -731,7 +731,7 @@ private[spark] object RandomForest extends Logging { } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = + val (temp, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getParentImpurityCalculator() @@ -740,6 +740,11 @@ private[spark] object RandomForest extends Logging { leftChildStats, rightChildStats, binAggregates.metadata) (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) + val (bestFeatureSplitIndex, maxGain) = + Range(0, numSplits).map { case splitIdx => + val gain = binAggregates.calculateGain(leftChildOffset, splitIdx) + (splitIdx, gain) + }.maxBy(_._2) 
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index a8728dd7290ad..081d230d45864 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -95,6 +95,7 @@ object Entropy extends Impurity { override def calculateGain( allStats: Array[Double], leftChildOffset: Int, + parentStats: Array[Double], parentOffset: Int, statsSize: Int, minInstancesPerNode: Int, @@ -104,7 +105,7 @@ object Entropy extends Impurity { var i = 0 while (i < statsSize) { leftCount += allStats(leftChildOffset + i) - totalCount += allStats(parentOffset + i) + totalCount += parentStats(parentOffset + i) i += 1 } val rightCount = totalCount - leftCount @@ -121,11 +122,11 @@ object Entropy extends Impurity { i = 0 while (i < statsSize) { val leftStats = allStats(leftChildOffset + i) - val parentStats = allStats(parentOffset + i) + val totalStats = parentStats(parentOffset + i) val leftFreq = leftStats / leftCount - val rightFreq = (parentStats - leftStats) / rightCount - val parentFreq = parentStats / totalCount + val rightFreq = (totalStats - leftStats) / rightCount + val parentFreq = totalStats / totalCount leftImpurity -= leftFreq * log2(leftFreq) rightImpurity -= rightFreq * log2(rightFreq) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 5024c8ba48942..99bd93c78d1ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -76,6 +76,7 @@ object Gini extends Impurity { override def calculateGain( allStats: Array[Double], leftChildOffset: Int, + parentStats: 
Array[Double], parentOffset: Int, statsSize: Int, minInstancesPerNode: Int, @@ -86,7 +87,7 @@ object Gini extends Impurity { var i = 0 while (i < statsSize) { leftCount += allStats(leftChildOffset + i) - totalCount += allStats(parentOffset + i) + totalCount += parentStats(parentOffset + i) i += 1 } val rightCount = totalCount - leftCount @@ -103,11 +104,11 @@ object Gini extends Impurity { i = 0 while (i < statsSize) { val leftStats = allStats(leftChildOffset + i) - val parentStats = allStats(parentOffset + i) + val totalStats = parentStats(parentOffset + i) val leftFreq = leftStats / leftCount - val rightFreq = (parentStats - leftStats) / rightCount - val parentFreq = parentStats / totalCount + val rightFreq = (totalStats - leftStats) / rightCount + val parentFreq = totalStats / totalCount leftImpurity -= leftFreq * leftFreq rightImpurity -= rightFreq * rightFreq diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index ae4c856c5a675..045744a2a4367 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -56,6 +56,7 @@ trait Impurity extends Serializable { protected def calculateGain( allStats: Array[Double], leftChildOffset: Int, + parentStats: Array[Double], parentOffset: Int, statsSize: Int, minInstancesPerNode: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 89a77163af663..ead2ce8c327c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -67,6 +67,7 @@ object Variance extends Impurity { override def calculateGain( allStats: Array[Double], leftChildOffset: Int, + parentStats: Array[Double], parentOffset: Int, 
statsSize: Int, minInstancesPerNode: Int, @@ -83,8 +84,8 @@ object Variance extends Impurity { val leftSum = allStats(leftChildOffset + 1) val leftSumSquares = allStats(leftChildOffset + 2) - val parentSum = allStats(parentOffset + 1) - val parentSumSquares = allStats(parentOffset + 2) + val parentSum = parentStats(parentOffset + 1) + val parentSumSquares = parentStats(parentOffset + 2) val rightSum = parentSum - leftSum val rightSumSquares = parentSumSquares - leftSumSquares From 6e31e3a7b36981c8ccbf867e013363aa6f784e39 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 27 Jun 2016 15:58:10 -0700 Subject: [PATCH 3/8] Remove impurity calculation to outside the for loop --- .../spark/ml/tree/impl/RandomForest.scala | 53 ++++++++----------- 1 file changed, 21 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 359ec93e03cda..03efd112e7f3e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -712,39 +712,33 @@ private[spark] object RandomForest extends Logging { splitIndex += 1 } // Find best split. 
- val (temp, bestFeatureGainStats) = - Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIdx, gainAndImpurityStats) - }.maxBy(_._2.gain) val (bestFeatureSplitIndex, maxGain) = Range(0, numSplits).map { case splitIdx => val gain = binAggregates.calculateGain(nodeFeatureOffset, splitIdx, numSplits) (splitIdx, gain) }.maxBy(_._2) + val leftChildStats = binAggregates.getImpurityCalculator( + nodeFeatureOffset, bestFeatureSplitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator() + rightChildStats.subtract(leftChildStats) + val bestFeatureGainStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val (temp, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - .subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) val (bestFeatureSplitIndex, maxGain) = Range(0, numSplits).map { case splitIdx => val gain = binAggregates.calculateGain(leftChildOffset, splitIdx) (splitIdx, gain) }.maxBy(_._2) + val leftChildStats = binAggregates.getImpurityCalculator( + leftChildOffset, bestFeatureSplitIndex) + 
val rightChildStats = binAggregates.getParentImpurityCalculator() + rightChildStats.subtract(leftChildStats) + val bestFeatureGainStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature @@ -804,28 +798,23 @@ private[spark] object RandomForest extends Logging { // lastCategory = index of bin with total aggregates for this (node, feature) val lastCategory = categoriesSortedByCentroid.last._1 // Find best split. - val (temp, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) val (bestFeatureSplitIndex, maxGain) = Range(0, numSplits).map { case splitIdx => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val featureValue = categoriesSortedByCentroid(splitIdx)._1 val gain = binAggregates.calculateGain(nodeFeatureOffset, featureValue, lastCategory) (splitIdx, gain) }.maxBy(_._2) + val bestFeatureValue = categoriesSortedByCentroid(bestFeatureSplitIndex)._1 val categoriesForSplit = categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) val bestFeatureSplit = new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) + val leftChildStats = binAggregates.getImpurityCalculator( + nodeFeatureOffset, bestFeatureValue) + val rightChildStats = binAggregates.getParentImpurityCalculator() + rightChildStats.subtract(leftChildStats) + val bestFeatureGainStats = 
calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) (bestFeatureSplit, bestFeatureGainStats) } }.maxBy(_._2.gain) From ea4a0735c14ff91ad1071fb517da3fd890080354 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 27 Jun 2016 17:45:36 -0700 Subject: [PATCH 4/8] Remove per feature impurityCalculator initialization --- .../spark/ml/tree/impl/RandomForest.scala | 41 +++++++------------ 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 03efd112e7f3e..40cf062ea13f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -693,7 +693,7 @@ private[spark] object RandomForest extends Logging { } // For each (feature, split), calculate the gain, and select the best (feature, split). - val (bestSplit, bestSplitStats) = + val (bestSplit, bestGain, bestFeatureOffset, bestSplitIndex) = Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => val featureIndex = if (featuresForNode.nonEmpty) { featuresForNode.get.apply(featureIndexIdx) @@ -717,14 +717,8 @@ private[spark] object RandomForest extends Logging { val gain = binAggregates.calculateGain(nodeFeatureOffset, splitIdx, numSplits) (splitIdx, gain) }.maxBy(_._2) - val leftChildStats = binAggregates.getImpurityCalculator( - nodeFeatureOffset, bestFeatureSplitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - rightChildStats.subtract(leftChildStats) - val bestFeatureGainStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + val bestFeatureSplit = splits(featureIndex)(bestFeatureSplitIndex) + (bestFeatureSplit, maxGain, nodeFeatureOffset, 
bestFeatureSplitIndex) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) @@ -733,13 +727,8 @@ private[spark] object RandomForest extends Logging { val gain = binAggregates.calculateGain(leftChildOffset, splitIdx) (splitIdx, gain) }.maxBy(_._2) - val leftChildStats = binAggregates.getImpurityCalculator( - leftChildOffset, bestFeatureSplitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - rightChildStats.subtract(leftChildStats) - val bestFeatureGainStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + val bestFeatureSplit = splits(featureIndex)(bestFeatureSplitIndex) + (bestFeatureSplit, maxGain, leftChildOffset, bestFeatureSplitIndex) } else { // Ordered categorical feature val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) @@ -809,17 +798,17 @@ private[spark] object RandomForest extends Logging { categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) val bestFeatureSplit = new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) - val leftChildStats = binAggregates.getImpurityCalculator( - nodeFeatureOffset, bestFeatureValue) - val rightChildStats = binAggregates.getParentImpurityCalculator() - rightChildStats.subtract(leftChildStats) - val bestFeatureGainStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (bestFeatureSplit, bestFeatureGainStats) + (bestFeatureSplit, maxGain, nodeFeatureOffset, bestFeatureValue) } - }.maxBy(_._2.gain) - - (bestSplit, bestSplitStats) + }.maxBy(_._2) + + val leftChildStats = binAggregates.getImpurityCalculator( + bestFeatureOffset, bestSplitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator() + 
rightChildStats.subtract(leftChildStats) + val bestFeatureGainStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (bestSplit, bestFeatureGainStats) } /** From ca8b36088b74cacb7f162fb793070c4d3c6a1a8c Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 28 Jun 2016 10:27:40 -0700 Subject: [PATCH 5/8] Get rid of calculateImpurityStats --- .../spark/ml/tree/impl/RandomForest.scala | 84 ++++--------------- 1 file changed, 14 insertions(+), 70 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 40cf062ea13f5..f7978885e7069 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -613,65 +613,6 @@ private[spark] object RandomForest extends Logging { } } - /** - * Calculate the impurity statistics for a given (feature, split) based upon left/right - * aggregates. 
- * - * @param stats the recycle impurity statistics for this feature's all splits, - * only 'impurity' and 'impurityCalculator' are valid between each iteration - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @param metadata learning and dataset metadata for DecisionTree - * @return Impurity statistics for this (feature, split) - */ - private def calculateImpurityStats( - stats: ImpurityStats, - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): ImpurityStats = { - - val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { - leftImpurityCalculator.copy.add(rightImpurityCalculator) - } else { - stats.impurityCalculator - } - - val impurity: Double = if (stats == null) { - parentImpurityCalculator.calculate() - } else { - stats.impurity - } - - val leftCount = leftImpurityCalculator.count - val rightCount = rightImpurityCalculator.count - - val totalCount = leftCount + rightCount - - // If left child or right child doesn't satisfy minimum instances per node, - // then this split is invalid, return invalid information gain stats. - if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) - } - - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 - val rightImpurity = rightImpurityCalculator.calculate() - - val leftWeight = leftCount / totalCount.toDouble - val rightWeight = rightCount / totalCount.toDouble - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - // if information gain doesn't satisfy minimum information gain, - // then this split is invalid, return invalid information gain stats. 
- if (gain < metadata.minInfoGain) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) - } - - new ImpurityStats(gain, impurity, parentImpurityCalculator, - leftImpurityCalculator, rightImpurityCalculator) - } - /** * Find the best split for a node. * @@ -684,13 +625,7 @@ private[spark] object RandomForest extends Logging { featuresForNode: Option[Array[Int]], node: LearningNode): (Split, ImpurityStats) = { - // Calculate InformationGain and ImpurityStats if current node is top node val level = LearningNode.indexToLevel(node.id) - var gainAndImpurityStats: ImpurityStats = if (level == 0) { - null - } else { - node.stats - } // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestGain, bestFeatureOffset, bestSplitIndex) = @@ -802,12 +737,21 @@ private[spark] object RandomForest extends Logging { } }.maxBy(_._2) - val leftChildStats = binAggregates.getImpurityCalculator( + val leftImpurityCalculator = binAggregates.getImpurityCalculator( bestFeatureOffset, bestSplitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - rightChildStats.subtract(leftChildStats) - val bestFeatureGainStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) + val parentImpurityCalculator = binAggregates.getParentImpurityCalculator() + val rightImpurityCalculator = parentImpurityCalculator.copy.subtract( + leftImpurityCalculator) + val bestFeatureGainStats = { + if (bestGain == Double.MinValue) { + ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + else { + new ImpurityStats(bestGain, parentImpurityCalculator.calculate(), + parentImpurityCalculator, leftImpurityCalculator, + rightImpurityCalculator) + } + } (bestSplit, bestFeatureGainStats) } From 67b401a6a0e59b48a167e4f3036ca9f3f6a5df1f Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 28 Jun 2016 11:17:32 -0700 Subject: [PATCH 6/8] where did that come from? 
--- .../apache/spark/mllib/tree/impurity/Entropy.scala | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 081d230d45864..e0938e9374658 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -40,19 +40,7 @@ object Entropy extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value, or 0 if totalCount = 0 var leftCount = 0.0 - var totalCount = 0.0 - var i = 0 - while (i < statsSize) { - leftCount += allStats(leftChildOffset + i) - totalCount += allStats(parentOffset + i) - } - val rightCount = totalCount - leftCount - - if ((leftCount < minInstancesPerNode) || - (rightCount < minInstancesPerNode)) { - return Double.MinValue - } + * @return information value, or 0 if totalCount = 0 */ @Since("1.1.0") @DeveloperApi From e8b89141f6cabfef5f582fe9521f4443afa9ec65 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 28 Jun 2016 17:01:55 -0700 Subject: [PATCH 7/8] Add documentation --- .../spark/ml/tree/impl/DTStatsAggregator.scala | 18 ++++++++++++++++++ .../spark/ml/tree/impl/RandomForest.scala | 6 +++--- .../spark/mllib/tree/impurity/Entropy.scala | 15 +++++++++++++++ .../spark/mllib/tree/impurity/Gini.scala | 15 +++++++++++++++ .../spark/mllib/tree/impurity/Impurity.scala | 15 +++++++++++++++ .../spark/mllib/tree/impurity/Variance.scala | 15 +++++++++++++++ 6 files changed, 81 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala index 6e7c475971b60..6a1e46049e19b 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -94,6 +94,16 @@ private[spark] class DTStatsAggregator( impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) } + /** + * Calculate gain for a given (featureOffset, leftBin, parentBin). + * + * @param featureOffset This is a pre-computed (node, feature) offset + * from [[getFeatureOffset]]. + * @param leftBinIndex Index of the leftChild in allStats + * Given by featureOffset + leftBinIndex * statsSize + * @param parentBinIndex Index of the parent in allStats + * Given by featureOffset + parentBinIndex * statsSize + */ def calculateGain( featureOffset: Int, leftBinIndex: Int, @@ -115,6 +125,14 @@ private[spark] class DTStatsAggregator( gain } + /** + * Calculate gain for a given (featureOffset, leftBin). + * The stats of the parent are inferred from parentStats. + * @param featureOffset This is a pre-computed (node, feature) offset + * from [[getFeatureOffset]]. + * @param leftBinIndex Index of the leftChild in allStats + * Given by featureOffset + leftBinIndex * statsSize + */ def calculateGain( featureOffset: Int, leftBinIndex: Int): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index f7978885e7069..d3297ac1693d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -648,7 +648,7 @@ private[spark] object RandomForest extends Logging { } // Find best split. 
val (bestFeatureSplitIndex, maxGain) = - Range(0, numSplits).map { case splitIdx => + Range(0, numSplits).map { splitIdx => val gain = binAggregates.calculateGain(nodeFeatureOffset, splitIdx, numSplits) (splitIdx, gain) }.maxBy(_._2) @@ -658,7 +658,7 @@ private[spark] object RandomForest extends Logging { // Unordered categorical feature val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) val (bestFeatureSplitIndex, maxGain) = - Range(0, numSplits).map { case splitIdx => + Range(0, numSplits).map { splitIdx => val gain = binAggregates.calculateGain(leftChildOffset, splitIdx) (splitIdx, gain) }.maxBy(_._2) @@ -723,7 +723,7 @@ private[spark] object RandomForest extends Logging { val lastCategory = categoriesSortedByCentroid.last._1 // Find best split. val (bestFeatureSplitIndex, maxGain) = - Range(0, numSplits).map { case splitIdx => + Range(0, numSplits).map { splitIdx => val featureValue = categoriesSortedByCentroid(splitIdx)._1 val gain = binAggregates.calculateGain(nodeFeatureOffset, featureValue, lastCategory) (splitIdx, gain) }.maxBy(_._2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index e0938e9374658..a2e37db73515f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -80,6 +80,21 @@ object Entropy extends Impurity { @Since("1.1.0") def instance: this.type = this + /** + * Information gain calculation. + * allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity + * information of the leftChild. + * parentStats(parentOffset: parentOffset + statsSize) contains the impurity + * information of the parent. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param leftChildOffset Start index of stats for the left child. 
+ * @param parentStats Flat stats array for impurity calculation of the parent. + * @param parentOffset Start index of stats for the parent. + * @param statsSize Size of the stats for the left child and the parent. + * @param minInstancesPerNode minimum no. of instances in the child nodes for non-zero gain. + * @param minInfoGain return zero if gain < minInfoGain. + * @return information gain. + */ override def calculateGain( allStats: Array[Double], leftChildOffset: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 99bd93c78d1ef..1e90ec78af7aa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -73,6 +73,21 @@ object Gini extends Impurity { @Since("1.1.0") def instance: this.type = this + /** + * Information gain calculation. + * allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity + * information of the leftChild. + * parentStats(parentOffset: parentOffset + statsSize) contains the impurity + * information of the parent. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param leftChildOffset Start index of stats for the left child. + * @param parentStats Flat stats array for impurity calculation of the parent. + * @param parentOffset Start index of stats for the parent. + * @param statsSize Size of the stats for the left child and the parent. + * @param minInstancesPerNode minimum no. of instances in the child nodes for non-zero gain. + * @param minInfoGain return zero if gain < minInfoGain. + * @return information gain. 
+ */ override def calculateGain( allStats: Array[Double], leftChildOffset: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 045744a2a4367..fe1c43dc944fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -53,6 +53,21 @@ trait Impurity extends Serializable { @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double + /** + * Information gain calculation. + * allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity + * information of the leftChild. + * parentStats(parentOffset: parentOffset + statsSize) contains the impurity + * information of the parent. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param leftChildOffset Start index of stats for the left child. + * @param parentStats Flat stats array for impurity calculation of the parent. + * @param parentOffset Start index of stats for the parent. + * @param statsSize Size of the stats for the left child and the parent. + * @param minInstancesPerNode minimum no. of instances in the child nodes for non-zero gain. + * @param minInfoGain return zero if gain < minInfoGain. + * @return information gain. + */ protected def calculateGain( allStats: Array[Double], leftChildOffset: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index ead2ce8c327c3..4c3837f2d8613 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -64,6 +64,21 @@ object Variance extends Impurity { @Since("1.0.0") def instance: this.type = this + /** + * Information gain calculation. 
+ * allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity + * information of the leftChild. + * parentStats(parentOffset: parentOffset + statsSize) contains the impurity + * information of the parent. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param leftChildOffset Start index of stats for the left child. + * @param parentStats Flat stats array for impurity calculation of the parent. + * @param parentOffset Start index of stats for the parent. + * @param statsSize Size of the stats for the left child and the parent. + * @param minInstancesPerNode minimum no. of instances in the child nodes for non-zero gain. + * @param minInfoGain return zero if gain < minInfoGain. + * @return information gain. + */ override def calculateGain( allStats: Array[Double], leftChildOffset: Int, From af1ff66153103874e97b24282bd8a207958578da Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 28 Jun 2016 17:39:07 -0700 Subject: [PATCH 8/8] minor change to variance calculation --- .../org/apache/spark/mllib/tree/impurity/Variance.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 4c3837f2d8613..84eadcda36280 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -106,11 +106,8 @@ object Variance extends Impurity { val rightSumSquares = parentSumSquares - leftSumSquares val parentImpurity = (parentSumSquares - (parentSum * parentSum) / totalCount) / totalCount - val leftImpurity = (leftSumSquares - (leftSum * leftSum) / leftCount) / leftCount - val rightImpurity = (rightSumSquares - (rightSum * rightSum) / rightCount) / leftCount - - val leftWeighted = leftImpurity * leftCount / totalCount - val rightWeighted = rightImpurity * 
rightCount / totalCount + val leftWeighted = (leftSumSquares - (leftSum * leftSum) / leftCount) / totalCount + val rightWeighted = (rightSumSquares - (rightSum * rightSum) / rightCount) / totalCount val gain = parentImpurity - leftWeighted - rightWeighted if (gain < minInfoGain) {