From ac4237808090237fe4c562da8c88c55c330d451f Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 11:17:58 +0800 Subject: [PATCH 01/16] add min info gain and min instances per node parameters in decision tree --- .../spark/mllib/tree/DecisionTree.scala | 23 +++++++++-- .../mllib/tree/configuration/Strategy.scala | 2 + .../tree/impl/DecisionTreeMetadata.scala | 7 +++- .../tree/model/InformationGainStats.scala | 5 +++ .../apache/spark/mllib/tree/model/Split.scala | 6 +++ .../spark/mllib/tree/DecisionTreeSuite.scala | 41 ++++++++++++++++++- 6 files changed, 77 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index dd766c12d28a4..2070fa7efd664 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -738,12 +738,15 @@ object DecisionTree extends Serializable with Logging { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count - val totalCount = leftCount + rightCount - if (totalCount == 0) { - // Return arbitrary prediction. - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + // If left child or right child doesn't satisfy minimum instances per node, + // then this split is invalid, return invalid information gain stats + if ((leftCount < metadata.minInstancesPerNode) || + (rightCount < metadata.minInstancesPerNode)) { + return InformationGainStats.invalidInformationGainStats } + val totalCount = leftCount + rightCount + val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) // impurity of parent node @@ -763,6 +766,9 @@ object DecisionTree extends Serializable with Logging { val rightWeight = rightCount / totalCount.toDouble val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + if (gain < metadata.minInfoGain) { + return InformationGainStats.invalidInformationGainStats + } new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) } @@ -807,6 +813,9 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) + if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { + (Split.noSplit, InformationGainStats.invalidInformationGainStats) + } (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (metadata.isUnordered(featureIndex)) { // Unordered categorical feature @@ -820,6 +829,9 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) + if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { + (Split.noSplit, InformationGainStats.invalidInformationGainStats) + } (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature @@ -891,6 +903,9 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) + if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { + (Split.noSplit, InformationGainStats.invalidInformationGainStats) + } val categoriesForSplit = categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) val bestFeatureSplit = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index cfc8192a85abd..48b958fb95975 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -61,6 +61,8 @@ class Strategy ( val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + val minInstancesPerNode: Int = 0, + val minInfoGain: Double = 0.0, val maxMemoryInMB: Int = 128) extends Serializable { if (algo == Classification) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index e95add7558bcf..5ceaa8154d11a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata( val unorderedFeatures: Set[Int], val numBins: Array[Int], val impurity: Impurity, - val quantileStrategy: QuantileStrategy) extends Serializable { + val quantileStrategy: QuantileStrategy, + val minInstancesPerNode: Int, + val minInfoGain: Double) extends Serializable { def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) @@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata { new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, - strategy.impurity, strategy.quantileCalculationStrategy) + strategy.impurity, strategy.quantileCalculationStrategy, + strategy.minInstancesPerNode, strategy.minInfoGain) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index fb12298e0f5d3..dce9d2ec8a5f2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -43,3 +43,8 @@ class InformationGainStats( .format(gain, impurity, leftImpurity, rightImpurity, predict, prob) } } + + +private[tree] object InformationGainStats { + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, 0.0) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 50fb48b40de3d..da1aefb01c6ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType +import org.apache.spark.mllib.tree.configuration.FeatureType +import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType /** * :: DeveloperApi :: @@ -66,3 +68,7 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType) private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) + +private[tree] object Split { + val noSplit = new Split(-1, Double.MinValue, FeatureType.Continuous, List()) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 8e556c917b2e7..bb3d1e03c69a4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} +import org.apache.spark.mllib.tree.model.{Split, DecisionTreeModel, Node} import org.apache.spark.mllib.util.LocalSparkContext @@ -684,6 +684,45 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.6) } + test("split must satisfy min instances per node requirements") { + val arr = new Array[LabeledPoint](3) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, minInstancesPerNode = 4) + + val model = DecisionTree.train(input, strategy) + assert(model.topNode.isLeaf) + assert(model.topNode.predict == 0.0) + assert(model.topNode.split.get == Split.noSplit) + val predicts = input.map(p => model.predict(p.features)).collect() + predicts.foreach { predict => + assert(predict == 0.0) + } + } + + test("split must satisfy min info gain requirements") { + val arr = new Array[LabeledPoint](3) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, minInfoGain = 1.0) + + val model = DecisionTree.train(input, strategy) + assert(model.topNode.isLeaf) + assert(model.topNode.predict == 0.0) + assert(model.topNode.split.get == Split.noSplit) + val predicts = input.map(p => model.predict(p.features)).collect() + predicts.foreach { predict => + assert(predict == 0.0) + } + } } object DecisionTreeSuite { From ff34845c8e43f5b9755dd1fdf428be8b2284c68b Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 12:29:12 +0800 Subject: [PATCH 02/16] separate calculation of predict of node from calculation of info gain --- .../spark/mllib/tree/DecisionTree.scala | 47 +++++++++++++------ .../tree/model/InformationGainStats.scala | 12 ++--- .../spark/mllib/tree/model/Predict.scala | 36 ++++++++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 16 ++++--- 4 files changed, 82 insertions(+), 29 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 2070fa7efd664..eb491b2dbd101 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. timer.start("findBestSplits") - val splitsStatsForLevel: Array[(Split, InformationGainStats)] = + val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") @@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.start("extractNodeInfo") val split = nodeSplitStats._1 val stats = nodeSplitStats._2 + val predict = nodeSplitStats._3 val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) + val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node timer.stop("extractNodeInfo") @@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging { splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int, - timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { + timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, @@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging { // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. val numGroups = 1 << level - maxLevelForSingleGroup logDebug("numGroups = " + numGroups) - var bestSplits = new Array[(Split, InformationGainStats)](0) + var bestSplits = new Array[(Split, InformationGainStats, Predict)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { @@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging { bins: Array[Array[Bin]], timer: TimeTracker, numGroups: Int = 1, - groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { + groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = { /* * The high-level descriptions of the best split optimizations are noted here. @@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging { // Calculate best splits for all nodes at a given level timer.start("chooseSplits") - val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes) // Iterating over all nodes at this level var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -747,18 +748,16 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) + // impurity of parent node val impurity = if (level > 0) { topImpurity } else { + val parentNodeAgg = leftImpurityCalculator.copy + parentNodeAgg.add(rightImpurityCalculator) parentNodeAgg.calculate() } - val predict = parentNodeAgg.predict - val prob = parentNodeAgg.prob(predict) - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -770,7 +769,18 @@ object DecisionTree extends Serializable with Logging { return InformationGainStats.invalidInformationGainStats } - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) + } + + private def calculatePredict( + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator): Predict = { + val parentNodeAgg = leftImpurityCalculator.copy + parentNodeAgg.add(rightImpurityCalculator) + val predict = parentNodeAgg.predict + val prob = parentNodeAgg.prob(predict) + + new Predict(predict, prob) } /** @@ -786,12 +796,14 @@ object DecisionTree extends Serializable with Logging { nodeImpurity: Double, level: Int, metadata: DecisionTreeMetadata, - splits: Array[Array[Split]]): (Split, InformationGainStats) = { + splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = { logDebug("node impurity = " + nodeImpurity) + var predict: Option[Predict] = None + // For each (feature, split), calculate the gain, and select the best (feature, split). - Range(0, metadata.numFeatures).map { featureIndex => + val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex => val numSplits = metadata.numSplits(featureIndex) if (metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. @@ -809,6 +821,7 @@ object DecisionTree extends Serializable with Logging { val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) + predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIdx, gainStats) @@ -825,6 +838,7 @@ object DecisionTree extends Serializable with Logging { Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIndex, gainStats) @@ -899,6 +913,7 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) + predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIndex, gainStats) @@ -913,6 +928,10 @@ object DecisionTree extends Serializable with Logging { (bestFeatureSplit, bestFeatureGainStats) } }.maxBy(_._2.gain) + + require(predict.isDefined, "must calculate predict for each node") + + (bestSplit, bestSplitStats, predict.get) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index dce9d2ec8a5f2..4a133e21f461a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -26,25 +26,21 @@ import org.apache.spark.annotation.DeveloperApi * @param impurity current node impurity * @param leftImpurity left node impurity * @param rightImpurity right node impurity - * @param predict predicted value - * @param prob probability of the label (classification only) */ @DeveloperApi class InformationGainStats( val gain: Double, val impurity: Double, val leftImpurity: Double, - val rightImpurity: Double, - val predict: Double, - val prob: Double = 0.0) extends Serializable { + val rightImpurity: Double) extends Serializable { override def toString = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f" - .format(gain, impurity, leftImpurity, rightImpurity, predict, prob) + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" + .format(gain, impurity, leftImpurity, rightImpurity) } } private[tree] object InformationGainStats { - val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, 0.0) + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala new file mode 100644 index 0000000000000..6a9e9a1dc5568 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Predicted value for a node + * @param predict predicted value + * @param prob probability of the label (classification only) + */ +@DeveloperApi +class Predict( + val predict: Double, + val prob: Double = 0.0) extends Serializable{ + + override def toString = { + "predict = %f, prob = %f".format(predict, prob) + } +} \ No newline at end of file diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index bb3d1e03c69a4..a8127579261a1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -280,9 +280,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(split.threshold === Double.MinValue) val stats = bestSplits(0)._2 + val predict = bestSplits(0)._3 assert(stats.gain > 0) - assert(stats.predict === 1) - assert(stats.prob === 0.6) + assert(predict.predict === 1) + assert(predict.prob === 0.6) assert(stats.impurity > 0.2) } @@ -313,8 +314,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(split.threshold === Double.MinValue) val stats = bestSplits(0)._2 + val predict = bestSplits(0)._3.predict assert(stats.gain > 0) - assert(stats.predict === 0.6) + assert(predict === 0.6) assert(stats.impurity > 0.2) } @@ -392,7 +394,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 1) + assert(bestSplits(0)._3.predict === 1) } test("Binary classification stump with fixed label 0 for Entropy") { @@ -421,7 +423,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 0) + assert(bestSplits(0)._3.predict === 0) } test("Binary classification stump with fixed label 1 for Entropy") { @@ -450,7 +452,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 1) + assert(bestSplits(0)._3.predict === 1) } test("Second level node building with vs. without groups") { @@ -501,7 +503,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) - assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) + assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict) } } From 987cbf4b177f29e232bf2ba2ca595ea7015694da Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 12:30:01 +0800 Subject: [PATCH 03/16] fix bug --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index eb491b2dbd101..499fa2dff7be6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -145,7 +145,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val stats = nodeSplitStats._2 val predict = nodeSplitStats._3 val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats)) + val node = new Node(nodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node timer.stop("extractNodeInfo") From f195e830a94097e5d6d42f22c67c32ca8900d848 Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 14:04:20 +0800 Subject: [PATCH 04/16] fix style --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 499fa2dff7be6..23cf6bce6dcdc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -143,9 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.start("extractNodeInfo") val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val predict = nodeSplitStats._3 + val predict = nodeSplitStats._3.predict val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats)) + val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node timer.stop("extractNodeInfo") @@ -735,14 +735,13 @@ object DecisionTree extends Serializable with Logging { topImpurity: Double, level: Int, metadata: DecisionTreeMetadata): InformationGainStats = { - val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count // If left child or right child doesn't satisfy minimum instances per node, // then this split is invalid, return invalid information gain stats if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { + (rightCount < metadata.minInstancesPerNode)) { return InformationGainStats.invalidInformationGainStats } From 845c6fa58c00bfba426e56e71eb46a6f8c3f5985 Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 14:05:37 +0800 Subject: [PATCH 05/16] fix style --- .../main/scala/org/apache/spark/mllib/tree/model/Predict.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 6a9e9a1dc5568..b28c2f7671e54 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -33,4 +33,4 @@ class Predict( override def toString = { "predict = %f, prob = %f".format(predict, prob) } -} \ No newline at end of file +} From e72c7e4d0ad015fdf25ea2959bdbf524056e38ca Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 14:52:24 +0800 Subject: [PATCH 06/16] add comments --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 13 ++++++++++++- .../mllib/tree/model/InformationGainStats.scala | 5 +++++ .../org/apache/spark/mllib/tree/model/Split.scala | 6 +++++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 23cf6bce6dcdc..03f9cbdb9d0a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -739,7 +739,7 @@ object DecisionTree extends Serializable with Logging { val rightCount = rightImpurityCalculator.count // If left child or right child doesn't satisfy minimum instances per node, - // then this split is invalid, return invalid information gain stats + // then this split is invalid, return invalid information gain stats. if ((leftCount < metadata.minInstancesPerNode) || (rightCount < metadata.minInstancesPerNode)) { return InformationGainStats.invalidInformationGainStats @@ -764,6 +764,9 @@ object DecisionTree extends Serializable with Logging { val rightWeight = rightCount / totalCount.toDouble val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + + // if information gain doesn't satisfy minimum information gain, + // then this split is invalid, return invalid information gain stats. if (gain < metadata.minInfoGain) { return InformationGainStats.invalidInformationGainStats } @@ -771,6 +774,13 @@ object DecisionTree extends Serializable with Logging { new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) } + /** + * Calculate predict value for current node, given stats of any split. + * Note that this function is called only once for each node. + * @param leftImpurityCalculator left node aggregates for a split + * @param rightImpurityCalculator right node aggregates for a node + * @return predict value for current node + */ private def calculatePredict( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator): Predict = { @@ -799,6 +809,7 @@ object DecisionTree extends Serializable with Logging { logDebug("node impurity = " + nodeImpurity) + // calculate predict only once var predict: Option[Predict] = None // For each (feature, split), calculate the gain, and select the best (feature, split). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 4a133e21f461a..f3e2619bd8ba0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -42,5 +42,10 @@ class InformationGainStats( private[tree] object InformationGainStats { + /** + * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to + * denote that current split doesn't satisfies minimum info gain or + * minimum number of instances per node. + */ val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index da1aefb01c6ef..91ce51a3dfdbc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -68,7 +68,11 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType) private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) - private[tree] object Split { + /** + * A [[org.apache.spark.mllib.tree.model.Split]] object to denote that + * we can't find a valid split that satisfies minimum info gain + * or minimum number of instances per node. + */ val noSplit = new Split(-1, Double.MinValue, FeatureType.Continuous, List()) } From 46b891fd7f30b9f2d439134931b35dab387fe2b1 Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 16:09:34 +0800 Subject: [PATCH 07/16] fix bug --- .../spark/mllib/tree/DecisionTree.scala | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 03f9cbdb9d0a7..3377e92805436 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -836,10 +836,11 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) - if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { + if (bestFeatureGainStats.gain < metadata.minInfoGain) { (Split.noSplit, InformationGainStats.invalidInformationGainStats) + } else { + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (metadata.isUnordered(featureIndex)) { // Unordered categorical feature val (leftChildOffset, rightChildOffset) = @@ -855,8 +856,9 @@ object DecisionTree extends Serializable with Logging { }.maxBy(_._2.gain) if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { (Split.noSplit, InformationGainStats.invalidInformationGainStats) + } else { + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) @@ -930,12 +932,13 @@ object DecisionTree extends Serializable with Logging { }.maxBy(_._2.gain) if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { (Split.noSplit, InformationGainStats.invalidInformationGainStats) + } else { + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) + (bestFeatureSplit, bestFeatureGainStats) } - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) - (bestFeatureSplit, bestFeatureGainStats) } }.maxBy(_._2.gain) From cadd569cf64d6eb7b9c9979a5066a2f63f15fed9 Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 16:48:51 +0800 Subject: [PATCH 08/16] add api docs --- .../apache/spark/mllib/tree/configuration/Strategy.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 48b958fb95975..e9e7247cf9274 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -49,6 +49,12 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. + * @param minInstancesPerNode Minimum number of nodes each child must have after split. Default value is 0. + * If a split cause left or right child to have less than minInstancesPerNode, + * this split will not be considered as a valid split. + * @param minInfoGain Minimum information gain a split must get. Default value is 0.0. + * If a split has less information gain than minInfoGain, + * this split will not be considered as a valid split. * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. */ From 6728fad304511030611c61592b1a590214e7f434 Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 17:16:27 +0800 Subject: [PATCH 09/16] minor fix: remove empty lines --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 8d170563c7f3c..425f50bcb137e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -747,7 +747,6 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount - // impurity of parent node val impurity = if (level > 0) { topImpurity @@ -836,7 +835,7 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) - if (bestFeatureGainStats.gain < metadata.minInfoGain) { + if (bestFeatureGainStats.gain == metadata.minInfoGain) { (Split.noSplit, InformationGainStats.invalidInformationGainStats) } else { (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) From 10b801269864cda2c00159518688942b1985061b Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 18:10:24 +0800 Subject: [PATCH 10/16] fix style --- .../org/apache/spark/mllib/tree/configuration/Strategy.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 3a55ec630e04d..9eeda1c9040fb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -49,8 +49,9 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. - * @param minInstancesPerNode Minimum number of nodes each child must have after split. Default value is 0. - * If a split cause left or right child to have less than minInstancesPerNode, + * @param minInstancesPerNode Minimum number of nodes each child must have after split. + * Default value is 0. If a split cause left or right child + * to have less than minInstancesPerNode, * this split will not be considered as a valid split. * @param minInfoGain Minimum information gain a split must get. Default value is 0.0. * If a split has less information gain than minInfoGain, From efcc7369f7f52de2810446c6fb976ab1743a63cf Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Tue, 9 Sep 2014 20:33:37 +0800 Subject: [PATCH 11/16] fix bug --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 425f50bcb137e..44d1446609c0a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -835,7 +835,7 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) - if (bestFeatureGainStats.gain == metadata.minInfoGain) { + if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { (Split.noSplit, InformationGainStats.invalidInformationGainStats) } else { (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) From d593ec70d70b633b72e260c38e89d87ab14fcd69 Mon Sep 17 00:00:00 2001 From: chouqin Date: Wed, 10 Sep 2014 07:57:27 +0800 Subject: [PATCH 12/16] fix docs and change minInstancesPerNode to 1 --- .../apache/spark/mllib/tree/configuration/Strategy.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 9eeda1c9040fb..987fe632c91ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -49,8 +49,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. - * @param minInstancesPerNode Minimum number of nodes each child must have after split. - * Default value is 0. If a split cause left or right child + * @param minInstancesPerNode Minimum number of instances each child must have after split. + * Default value is 1. If a split cause left or right child * to have less than minInstancesPerNode, * this split will not be considered as a valid split. * @param minInfoGain Minimum information gain a split must get. Default value is 0.0. @@ -68,7 +68,7 @@ class Strategy ( val maxBins: Int = 32, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val minInstancesPerNode: Int = 0, + val minInstancesPerNode: Int = 1, val minInfoGain: Double = 0.0, val maxMemoryInMB: Int = 256) extends Serializable { From 0278a1198017aae578be3109a8311abc1f9a8e14 Mon Sep 17 00:00:00 2001 From: chouqin Date: Wed, 10 Sep 2014 10:31:57 +0800 Subject: [PATCH 13/16] remove `noSplit` and set `Predict` private to tree --- .../spark/mllib/tree/DecisionTree.scala | 26 +++++-------------- .../spark/mllib/tree/model/Predict.scala | 2 +- .../apache/spark/mllib/tree/model/Split.scala | 8 ------ .../spark/mllib/tree/DecisionTreeSuite.scala | 26 ++++++++++++++++--- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 44d1446609c0a..98596569b8c95 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -835,11 +835,7 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) - if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { - (Split.noSplit, InformationGainStats.invalidInformationGainStats) - } else { - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (metadata.isUnordered(featureIndex)) { // Unordered categorical feature val (leftChildOffset, rightChildOffset) = @@ -853,11 +849,7 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) - if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { - (Split.noSplit, InformationGainStats.invalidInformationGainStats) - } else { - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) @@ -929,15 +921,11 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) - if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { - (Split.noSplit, InformationGainStats.invalidInformationGainStats) - } else { - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) - (bestFeatureSplit, bestFeatureGainStats) - } + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) + (bestFeatureSplit, bestFeatureGainStats) } }.maxBy(_._2.gain) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index b28c2f7671e54..6fac2be2797bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.DeveloperApi * @param prob probability of the label (classification only) */ @DeveloperApi -class Predict( +private[tree] class Predict( val predict: Double, val prob: Double = 0.0) extends Serializable{ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 91ce51a3dfdbc..b7a85f58544a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -68,11 +68,3 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType) private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) -private[tree] object Split { - /** - * A [[org.apache.spark.mllib.tree.model.Split]] object to denote that - * we can't find a valid split that satisfies minimum info gain - * or minimum number of instances per node. - */ - val noSplit = new Split(-1, Double.MinValue, FeatureType.Continuous, List()) -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a61cc8934f952..74be91fd9e1b9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{Split, DecisionTreeModel, Node} +import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} import org.apache.spark.mllib.util.LocalSparkContext class DecisionTreeSuite extends FunSuite with LocalSparkContext { @@ -689,11 +689,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(input, strategy) assert(model.topNode.isLeaf) assert(model.topNode.predict == 0.0) - assert(model.topNode.split.get == Split.noSplit) val predicts = input.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) } + + // test for findBestSplits when no valid split can be found + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, + new Array[Node](0), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestInfoStats = bestSplits(0)._2 + assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) } test("split must satisfy min info gain requirements") { @@ -709,11 +719,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(input, strategy) assert(model.topNode.isLeaf) assert(model.topNode.predict == 0.0) - assert(model.topNode.split.get == Split.noSplit) val predicts = input.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) } + + // test for findBestSplits when no valid split can be found + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, + new Array[Node](0), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestInfoStats = bestSplits(0)._2 + assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) } } From 39f9b60907050b4e1c78f7413282df13b7e6552c Mon Sep 17 00:00:00 2001 From: chouqin Date: Wed, 10 Sep 2014 22:15:46 +0800 Subject: [PATCH 14/16] change edge `minInstancesPerNode` to 2 and add one more test --- .../spark/mllib/tree/DecisionTreeSuite.scala | 34 ++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 74be91fd9e1b9..16af62da6f1e4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -683,8 +683,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, minInstancesPerNode = 4) + val strategy = new Strategy(algo = Classification, impurity = Gini, + maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2) val model = DecisionTree.train(input, strategy) assert(model.topNode.isLeaf) @@ -701,11 +701,37 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) + assert(bestSplits.length == 1) val bestInfoStats = bestSplits(0)._2 assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) } + test("don't chose split that doesn't satify min instance per node requirements") { + // if a split doesn't satisfy min instances per node requirements, + // this split is invalid, even though the information gain of split is large. + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, + maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), + numClassesForClassification = 2, minInstancesPerNode = 2) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, + new Array[Node](0), splits, bins, 10) + + assert(bestSplits.length == 1) + val bestSplit = bestSplits(0)._1 + val bestSplitStats = bestSplits(0)._1 + assert(bestSplit.feature == 1) + assert(bestSplitStats != InformationGainStats.invalidInformationGainStats) + } + test("split must satisfy min info gain requirements") { val arr = new Array[LabeledPoint](3) arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) @@ -731,7 +757,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) + assert(bestSplits.length == 1) val bestInfoStats = bestSplits(0)._2 assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) } From c7ebaf1721ba414ed1539bfc4721c3bbfd70b77a Mon Sep 17 00:00:00 2001 From: chouqin Date: Wed, 10 Sep 2014 22:27:08 +0800 Subject: [PATCH 15/16] fix typo --- .../scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 16af62da6f1e4..ac0f31b538eda 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -706,7 +706,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) } - test("don't chose split that doesn't satify min instance per node requirements") { + test("don't chose split that doesn't satisfy min instance per node requirements") { // if a split doesn't satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. val arr = new Array[LabeledPoint](4) From f1d11d15fe519f9ef9d4e1158b309dc6af38864e Mon Sep 17 00:00:00 2001 From: chouqin Date: Wed, 10 Sep 2014 22:30:22 +0800 Subject: [PATCH 16/16] fix typo --- .../scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index ac0f31b538eda..fd8547c1660fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -706,7 +706,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) } - test("don't chose split that doesn't satisfy min instance per node requirements") { + test("don't choose split that doesn't satisfy min instance per node requirements") { // if a split doesn't satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. val arr = new Array[LabeledPoint](4)