From bbdbf209e2d58917916fb62091325fd220acbf07 Mon Sep 17 00:00:00 2001
From: sethah
Date: Wed, 13 Apr 2016 16:27:26 -0700
Subject: [PATCH 1/9] remove extra split for continuous features

---
 .../spark/ml/tree/impl/RandomForest.scala     | 40 +++++++++++--------
 .../ml/tree/impl/RandomForestSuite.scala      | 33 ++++++++++++++-
 2 files changed, 54 insertions(+), 19 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 71c8c42ce5eb..b169e0a4cec2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -712,17 +712,23 @@ private[spark] object RandomForest extends Logging {
             splitIndex += 1
           }
           // Find best split.
-          val (bestFeatureSplitIndex, bestFeatureGainStats) =
-            Range(0, numSplits).map { case splitIdx =>
-              val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
-              rightChildStats.subtract(leftChildStats)
-              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
-                leftChildStats, rightChildStats, binAggregates.metadata)
-              (splitIdx, gainAndImpurityStats)
-            }.maxBy(_._2.gain)
-          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+          if (numSplits == 0) {
+            (new ContinuousSplit(featureIndex, Double.MinValue),
+              ImpurityStats.getInvalidImpurityStats(gainAndImpurityStats.impurityCalculator))
+          } else {
+            val (bestFeatureSplitIndex, bestFeatureGainStats) =
+              Range(0, numSplits).map { case splitIdx =>
+                val leftChildStats =
+                  binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+                val rightChildStats =
+                  binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+                rightChildStats.subtract(leftChildStats)
+                gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+                  leftChildStats, rightChildStats, binAggregates.metadata)
+                (splitIdx, gainAndImpurityStats)
+              }.maxBy(_._2.gain)
+            (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+          }
       } else if (binAggregates.metadata.isUnordered(featureIndex)) {
         // Unordered categorical feature
         val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
@@ -974,9 +980,9 @@ private[spark] object RandomForest extends Logging {
       val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
 
       // if possible splits is not enough or just enough, just return all possible splits
-      val possibleSplits = valueCounts.length
+      val possibleSplits = valueCounts.length - 1
       if (possibleSplits <= numSplits) {
-        valueCounts.map(_._1)
+        valueCounts.map(_._1).init
       } else {
         // stride between splits
         val stride: Double = numSamples.toDouble / (numSplits + 1)
@@ -1011,10 +1017,10 @@ private[spark] object RandomForest extends Logging {
       }
     }
 
-    // TODO: Do not fail; just ignore the useless feature.
-    assert(splits.length > 0,
-      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
-      " Please remove this feature and then try again.")
+//    // TODO: Do not fail; just ignore the useless feature.
+//    assert(splits.length > 0,
+//      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
+//      " Please remove this feature and then try again.")
 
     splits
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index dcc2f305df75..8e94ade922c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -114,7 +114,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     )
     val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-    assert(splits.length === 3)
+    assert(splits.length === 2)
     // check returned splits are distinct
     assert(splits.distinct.length === splits.length)
   }
@@ -137,7 +137,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
   {
     val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
       Map(), Set(),
-      Array(3), Gini, QuantileStrategy.Sort,
+      Array(2), Gini, QuantileStrategy.Sort,
       0, 0, 0.0, 0, 0
     )
     val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
@@ -145,6 +145,35 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(splits.length === 1)
     assert(splits(0) === 1.0)
   }
+
+    // find splits for constant feature
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(0, 0, 0).map(_.toDouble)
+      val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits === Array[Double]())
+    }
+  }
+
+  test("train with constant features") {
+    val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
+    val data = Array.fill(5)(lp)
+    val rdd = sc.parallelize(data)
+    val strategy = new OldStrategy(
+      OldAlgo.Classification,
+      Gini,
+      maxDepth = 2,
+      numClasses = 100,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 2, 1 -> 5))
+    val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L)
+    assert(tree.rootNode.impurity === -1.0)
+    assert(tree.depth === 0)
+    assert(tree.rootNode.prediction === lp.label)
   }
 
   test("Multiclass classification with unordered categorical features: split calculations") {
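[Series note] The core of PATCH 1: a continuous feature with n distinct values
admits only n - 1 usable thresholds, since a "split" at the largest value sends
every sample to the left child and nothing to the right. Hence
`possibleSplits = valueCounts.length - 1` and the `.init` that drops the last
distinct value. Below is a standalone sketch of the corrected rule; the names
are hypothetical and only the few-distinct-values path mirrors the real method
exactly (the real code picks thresholds by walking cumulative sample counts
when distinct values outnumber the allowed splits; the stand-in below just
subsamples).

    object ThresholdSketch {
      def candidateThresholds(featureSamples: Seq[Double], numSplits: Int): Array[Double] = {
        if (featureSamples.isEmpty) {
          Array.empty[Double]  // guard added later in this series (PATCH 4)
        } else {
          val distinctSorted = featureSamples.distinct.sorted.toArray
          // n distinct values yield at most n - 1 thresholds.
          val possibleSplits = distinctSorted.length - 1
          if (possibleSplits <= numSplits) {
            distinctSorted.init  // every distinct value except the last
          } else {
            // Simplified stand-in for the real stride-based selection.
            val stride = distinctSorted.length.toDouble / (numSplits + 1)
            Array.tabulate(numSplits)(i => distinctSorted(((i + 1) * stride).toInt))
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val samples = Seq(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
        println(candidateThresholds(samples, 3).toSeq)             // List(1.0, 2.0)
        println(candidateThresholds(Seq(0.0, 0.0, 0.0), 3).toSeq)  // List(): constant feature
      }
    }

This is what the updated suite above asserts: samples over {1, 2, 3} now yield
two splits instead of three, a constant feature yields none (which is why the
`splits.length > 0` assert is commented out), and training on constant features
is expected to produce a depth-0 tree.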
+ +// " Please remove this feature and then try again.") splits } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dcc2f305df75..8e94ade922c7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -114,7 +114,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 3) + assert(splits.length === 2) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -137,7 +137,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, + Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) @@ -145,6 +145,35 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits.length === 1) assert(splits(0) === 1.0) } + + // find splits for constant feature + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 0, 0).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits === Array[Double]()) + } + } + + test("train with constant features") { + val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)) + val data = Array.fill(5)(lp) + val rdd = sc.parallelize(data) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 5)) + val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L) + assert(tree.rootNode.impurity === -1.0) + assert(tree.depth === 0) + assert(tree.rootNode.prediction === lp.label) } test("Multiclass classification with unordered categorical features: split calculations") { From 2da84740ef7be8928ba9059c517ea94ca93d77b5 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 13 Apr 2016 16:34:08 -0700 Subject: [PATCH 2/9] cleanup --- .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index b169e0a4cec2..93ec678351bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1016,12 +1016,6 @@ private[spark] object RandomForest extends Logging { splitsBuilder.result() } } - -// // TODO: Do not fail; just ignore the useless feature. -// assert(splits.length > 0, -// s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." 
+ -// " Please remove this feature and then try again.") - splits } From ab5694a1ec04bda19d43bc8bc563577dff65c793 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 13 Apr 2016 17:35:34 -0700 Subject: [PATCH 3/9] unit test failure --- .../main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 93ec678351bc..a61798da1e70 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -714,7 +714,7 @@ private[spark] object RandomForest extends Logging { // Find best split. if (numSplits == 0) { (new ContinuousSplit(featureIndex, Double.MinValue), - ImpurityStats.getInvalidImpurityStats(gainAndImpurityStats.impurityCalculator)) + ImpurityStats.getInvalidImpurityStats(binAggregates.getParentImpurityCalculator())) } else { val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { case splitIdx => From 8835f64c07cd1baa7c764ac4deb589f14a042de9 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 21 Jun 2016 15:56:59 -0700 Subject: [PATCH 4/9] handle empty case --- .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 4 +++- .../org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index a61798da1e70..70ef6ab82e89 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -968,7 +968,9 @@ private[spark] object RandomForest extends Logging { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - val splits = { + val splits = if (featureSamples.isEmpty) { + Array.empty[Double] + } else { val numSplits = metadata.numSplits(featureIndex) // get count for each distinct value diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 8e94ade922c7..ec4211468b32 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -154,8 +154,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { 0, 0, 0.0, 0, 0 ) val featureSamples = Array(0, 0, 0).map(_.toDouble) + val featureSamplesEmpty = Array.empty[Double] val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits === Array[Double]()) + val splitsEmpty = + RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0) + assert(splitsEmpty === Array[Double]()) } } From c707b256e4af8aadb821946c5a1872168f7b8db0 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 21 Jun 2016 16:19:09 -0700 Subject: [PATCH 5/9] add instrumentation variable to test --- .../scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 
From 3bb28fe671b7f2aeac486ea8eb87809a997e6245 Mon Sep 17 00:00:00 2001
From: sethah
Date: Mon, 11 Jul 2016 11:00:09 -0700
Subject: [PATCH 6/9] address some review comments

---
 .../spark/ml/tree/impl/RandomForest.scala     | 20 +++++++++++++-------
 .../ml/tree/impl/RandomForestSuite.scala      | 11 ++++-------
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 70ef6ab82e89..03530cc96fa6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -692,14 +692,20 @@ private[spark] object RandomForest extends Logging {
       node.stats
     }
 
-    // For each (feature, split), calculate the gain, and select the best (feature, split).
-    val (bestSplit, bestSplitStats) =
-      Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
-        val featureIndex = if (featuresForNode.nonEmpty) {
-          featuresForNode.get.apply(featureIndexIdx)
+    val validFeatureSplits =
+      Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
+        if (featuresForNode.nonEmpty) {
+          (featureIndexIdx, featuresForNode.get.apply(featureIndexIdx))
         } else {
-          featureIndexIdx
+          (featureIndexIdx, featureIndexIdx)
         }
+      }.withFilter { case (featureIndexIdx, featureIndex) =>
+        binAggregates.metadata.numSplits(featureIndex) != 0
+      }
+
+    // For each (feature, split), calculate the gain, and select the best (feature, split).
+    val (bestSplit, bestSplitStats) =
+      validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
         val numSplits = binAggregates.metadata.numSplits(featureIndex)
         if (binAggregates.metadata.isContinuous(featureIndex)) {
           // Cumulative sum (scanLeft) of bin statistics.
@@ -959,7 +965,7 @@ private[spark] object RandomForest extends Logging {
    * NOTE: `metadata.numbins` will be changed accordingly
    *       if there are not enough splits to be found
    * @param featureIndex feature index to find splits
-   * @return array of splits
+   * @return array of split thresholds
    */
   private[tree] def findSplitsForContinuousFeature(
       featureSamples: Iterable[Double],
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index f54a1b8916ee..5478d8845d03 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -114,7 +114,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     )
     val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-    assert(splits.length === 2)
+    assert(splits === Array(1.0, 2.0))
     // check returned splits are distinct
     assert(splits.distinct.length === splits.length)
   }
@@ -128,9 +128,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     )
     val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-    assert(splits.length === 2)
-    assert(splits(0) === 2.0)
-    assert(splits(1) === 3.0)
+    assert(splits === Array(2.0, 3.0))
   }
 
   // find splits when most samples close to the maximum
@@ -140,8 +138,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     )
     val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-    assert(splits.length === 1)
-    assert(splits(0) === 1.0)
+    assert(splits === Array(1.0))
   }
 
   // find splits for constant feature
@@ -173,7 +170,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       maxDepth = 2,
       numClasses = 100,
       maxBins = 100,
-      categoricalFeaturesInfo = Map(0 -> 2, 1 -> 5))
+      categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
     val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
    assert(tree.rootNode.impurity === -1.0)
     assert(tree.depth === 0)
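[Series note] With constant features producing zero candidate splits, the
special-cased `(ContinuousSplit(featureIndex, Double.MinValue), invalid stats)`
result becomes unnecessary: PATCH 6 filters such features out before any gain
is computed. `Range(...).view` keeps the index pairing lazy, and `withFilter`
avoids materializing an intermediate collection. A self-contained model of the
pipeline, with made-up data but the same shapes:

    object ValidFeatureSplitsSketch {
      def main(args: Array[String]): Unit = {
        val numFeaturesPerNode = 4
        // Global indices of the features subsampled for this node
        // (None would mean "use all features"):
        val featuresForNode: Option[Array[Int]] = Some(Array(0, 3, 5, 7))
        // Pretend feature 5 is constant, i.e. has zero candidate splits:
        val numSplits = Map(0 -> 2, 3 -> 4, 5 -> 0, 7 -> 1)

        val validFeatureSplits =
          Range(0, numFeaturesPerNode).view.map { featureIndexIdx =>
            if (featuresForNode.nonEmpty) {
              (featureIndexIdx, featuresForNode.get.apply(featureIndexIdx))
            } else {
              (featureIndexIdx, featureIndexIdx)
            }
          }.withFilter { case (_, featureIndex) => numSplits(featureIndex) != 0 }

        // Prints (0,0), (1,3), (3,7): constant feature 5 never reaches the scan.
        validFeatureSplits.foreach(println)
      }
    }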
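[Series note] After PATCH 6's filter, `numSplits == 0` can no longer occur in
the continuous branch, so PATCH 7 below deletes the now-dead branch and
restores the single-path scan. That scan relies on the bins having just been
cumulatively merged in place: for split index s, the left child's statistics
are the prefix sums at bin s, and the right child's are the node totals minus
that prefix (the `subtract` call). A toy numeric version using plain label
counts in place of Spark's impurity calculators (illustrative only):

    object BestSplitScanSketch {
      // Gini impurity of a label-count vector.
      def gini(counts: Array[Double]): Double = {
        val total = counts.sum
        if (total == 0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
      }

      def main(args: Array[String]): Unit = {
        // Per-bin label counts for one continuous feature (2 classes, 4 bins).
        val bins = Array(Array(2.0, 0.0), Array(1.0, 1.0), Array(0.0, 3.0), Array(1.0, 0.0))
        val numSplits = bins.length - 1

        // Cumulative sum across bins, mirroring the in-place merge loop.
        for (i <- 1 until bins.length; c <- bins(i).indices) bins(i)(c) += bins(i - 1)(c)

        val total = bins(numSplits)  // last prefix = label totals for the node
        val parentImpurity = gini(total)

        val (bestSplitIndex, bestGain) = Range(0, numSplits).map { splitIdx =>
          val left = bins(splitIdx)                                       // prefix sum
          val right = total.indices.map(c => total(c) - left(c)).toArray  // "subtract"
          val weighted = (left.sum * gini(left) + right.sum * gini(right)) / total.sum
          (splitIdx, parentImpurity - weighted)
        }.maxBy(_._2)

        println(s"best split index = $bestSplitIndex, gain = $bestGain")  // split 0 wins here
      }
    }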
- if (numSplits == 0 && false) { - (new ContinuousSplit(featureIndex, Double.MinValue), - ImpurityStats.getInvalidImpurityStats(binAggregates.getParentImpurityCalculator())) - } else { - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { case splitIdx => - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIdx, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { case splitIdx => + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIdx, gainAndImpurityStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) From 3c73726ae986518bb9ef6fb4b0eb0fdb8e7e6616 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 25 Jul 2016 10:38:36 -0700 Subject: [PATCH 8/9] update numClasses in test --- .../scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 5478d8845d03..c88bceb5179e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -168,7 +168,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { OldAlgo.Classification, Gini, maxDepth = 2, - numClasses = 100, + numClasses = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) From 928a834a50d6921c825021fffc6f2a810122cb6c Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 10 Oct 2016 13:10:35 -0700 Subject: [PATCH 9/9] small cleanups --- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 711041eb3a04..43867f92405f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -694,12 +694,9 @@ private[spark] object RandomForest extends Logging { val validFeatureSplits = Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx => - if (featuresForNode.nonEmpty) { - (featureIndexIdx, featuresForNode.get.apply(featureIndexIdx)) - } else { - (featureIndexIdx, featureIndexIdx) - } - }.withFilter { case (featureIndexIdx, featureIndex) => + featuresForNode.map(features => 
From 3c73726ae986518bb9ef6fb4b0eb0fdb8e7e6616 Mon Sep 17 00:00:00 2001
From: sethah
Date: Mon, 25 Jul 2016 10:38:36 -0700
Subject: [PATCH 8/9] update numClasses in test

---
 .../scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 5478d8845d03..c88bceb5179e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -168,7 +168,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       OldAlgo.Classification,
       Gini,
       maxDepth = 2,
-      numClasses = 100,
+      numClasses = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
     val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)

From 928a834a50d6921c825021fffc6f2a810122cb6c Mon Sep 17 00:00:00 2001
From: sethah
Date: Mon, 10 Oct 2016 13:10:35 -0700
Subject: [PATCH 9/9] small cleanups

---
 .../org/apache/spark/ml/tree/impl/RandomForest.scala | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 711041eb3a04..43867f92405f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -694,12 +694,9 @@ private[spark] object RandomForest extends Logging {
 
     val validFeatureSplits =
       Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
-        if (featuresForNode.nonEmpty) {
-          (featureIndexIdx, featuresForNode.get.apply(featureIndexIdx))
-        } else {
-          (featureIndexIdx, featureIndexIdx)
-        }
-      }.withFilter { case (featureIndexIdx, featureIndex) =>
+        featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
+          .getOrElse((featureIndexIdx, featureIndexIdx))
+      }.withFilter { case (_, featureIndex) =>
         binAggregates.metadata.numSplits(featureIndex) != 0
       }
 
@@ -720,8 +717,7 @@ private[spark] object RandomForest extends Logging {
           // Find best split.
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { case splitIdx =>
-              val leftChildStats =
-                binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+              val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
               val rightChildStats =
                 binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
               rightChildStats.subtract(leftChildStats)
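[Series note] PATCH 9 is behavior-preserving cleanup: the if/else pairing
collapses into an Option combinator, the unused binding in the filter becomes a
wildcard, and `val leftChildStats` fits on one line again after the de-nesting
in PATCH 7. A quick equivalence check for the rewritten pairing (sketch):

    object Patch9Equivalence {
      def main(args: Array[String]): Unit = {
        for (featuresForNode <- Seq(None, Some(Array(4, 9, 2))); featureIndexIdx <- 0 to 2) {
          val verbose = if (featuresForNode.nonEmpty) {
            (featureIndexIdx, featuresForNode.get.apply(featureIndexIdx))
          } else {
            (featureIndexIdx, featureIndexIdx)
          }
          val concise = featuresForNode
            .map(features => (featureIndexIdx, features(featureIndexIdx)))
            .getOrElse((featureIndexIdx, featureIndexIdx))
          assert(verbose == concise)
        }
        println("equivalent for all tested inputs")
      }
    }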