From def90521a01b5b677497d595ba3e1ed1d60da353 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 30 Aug 2015 01:49:35 -0700 Subject: [PATCH 1/4] Some progress towards not including one-category features --- .../tree/impl/DecisionTreeMetadata.scala | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 21ee49c45788c..bb2f043a72546 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -51,7 +51,8 @@ private[spark] class DecisionTreeMetadata( val minInstancesPerNode: Int, val minInfoGain: Double, val numTrees: Int, - val numFeaturesPerNode: Int) extends Serializable { + val numFeaturesPerNode: Int, + val featureIndexes: IndexedSeq[Int]) extends Serializable { def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) @@ -111,6 +112,14 @@ private[spark] object DecisionTreeMetadata extends Logging { throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " + s"but was given by empty one.") } + // Construct the feature indexes that we can use (one-category features are not useful) + val featureIndexes = if (strategy.categoricalFeaturesInfo.nonEmpty) { + val singleCategoryIndexes = strategy.categoricalFeaturesInfo.filter(_._2 < 2).map(_._1).toSet + 0.to(numFeatures).filterNot(singleCategoryIndexes.contains) + } else { + 0.to(numFeatures) + } + val numActiveFeatures = featureIndexes.size val numExamples = input.count() val numClasses = strategy.algo match { case Classification => strategy.numClasses @@ -144,8 +153,7 @@ private[spark] object DecisionTreeMetadata extends Logging { val maxCategoriesForUnorderedFeature = ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Hack: If a categorical feature has only 1 category, we treat it as continuous. - // TODO(SPARK-9957): Handle this properly by filtering out those features. + // Set number of bins to -1 if we are skipping a feature if (numCategories > 1) { // Decide if some categorical features should be treated as unordered features, // which require 2 * ((1 << numCategories - 1) - 1) bins.
@@ -157,14 +165,18 @@ private[spark] object DecisionTreeMetadata extends Logging { } else { numBins(featureIndex) = numCategories } + } else { + numBins(featureIndex) = -1 } } } else { // Binary classification or regression strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 + // Set number of bins to -1 if we are skipping a feature if (numCategories > 1) { numBins(featureIndex) = numCategories + } else { + numBins(featureIndex) = -1 } } } @@ -184,16 +196,17 @@ private[spark] object DecisionTreeMetadata extends Logging { case _ => featureSubsetStrategy } val numFeaturesPerNode: Int = _featureSubsetStrategy match { - case "all" => numFeatures - case "sqrt" => math.sqrt(numFeatures).ceil.toInt - case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) - case "onethird" => (numFeatures / 3.0).ceil.toInt + case "all" => numActiveFeatures + case "sqrt" => math.sqrt(numActiveFeatures).ceil.toInt + case "log2" => math.max(1, (math.log(numActiveFeatures) / math.log(2)).ceil.toInt) + case "onethird" => (numActiveFeatures / 3.0).ceil.toInt } new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) + strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode, + featureIndexes) } /** From 446af7a0ccab2e50abc6ce532c08618c5bb28d27 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 30 Aug 2015 21:23:55 -0700 Subject: [PATCH 2/4] Fix the test compile --- .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 356d957f15909..738a4884d3109 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -113,7 +113,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0, 0, 0.to(5).toArray ) val featureSamples = Array.fill(200000)(math.random) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -130,7 +130,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0, 0, 0.to(4).toArray ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -146,7 +146,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0, 0, 0.to(2).toArray ) val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -162,7 +162,7 @@ class 
DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0, 0, 0.to(2).toArray ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) From 93949b381dac4b8a9bbdad25ec3b4699a545ad50 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 30 Aug 2015 22:11:37 -0700 Subject: [PATCH 3/4] Fix some tests to use the expected array of active features, and add some scaladoc to clarify numFeatures now that it's a bit tricky --- .../apache/spark/ml/tree/impl/RandomForest.scala | 16 ++++------------ .../mllib/tree/impl/DecisionTreeMetadata.scala | 6 ++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 8 ++++---- 3 files changed, 12 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 4ac51a475474a..c13acbb6c9122 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -329,13 +329,9 @@ private[ml] object RandomForest extends Logging { featureIndexIdx += 1 } } else { - // Use all features - val numFeatures = agg.metadata.numFeatures - var featureIndex = 0 - while (featureIndex < numFeatures) { + agg.metadata.featureIndexes.foreach{featureIndex => val binIndex = treePoint.binnedFeatures(featureIndex) agg.update(featureIndex, binIndex, label, instanceWeight) - featureIndex += 1 } } } @@ -863,10 +859,8 @@ private[ml] object RandomForest extends Logging { logDebug("isMulticlass = " + metadata.isMulticlass) - val numFeatures = metadata.numFeatures - // Sample the input only if there are continuous features. - val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous) + val hasContinuousFeatures = metadata.featureIndexes.exists(metadata.isContinuous) val sampledInput = if (hasContinuousFeatures) { // Calculate the number of samples for approximate quantile calculation. val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) @@ -881,12 +875,11 @@ private[ml] object RandomForest extends Logging { new Array[LabeledPoint](0) } - val splits = new Array[Array[Split]](numFeatures) + val splits = new Array[Array[Split]](metadata.numFeatures) // Find all splits. // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { + metadata.featureIndexes.foreach{featureIndex => if (metadata.isContinuous(featureIndex)) { val featureSamples = sampledInput.map(_.features(featureIndex)) val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex) @@ -927,7 +920,6 @@ private[ml] object RandomForest extends Logging { splits(featureIndex) = new Array[Split](0) } } - featureIndex += 1 } splits } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index bb2f043a72546..30acf83aceeb2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -30,12 +30,14 @@ import org.apache.spark.rdd.RDD /** * Learning and dataset metadata for DecisionTree.
* + * @param numFeatures Total number of features (including single-category features) * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. * For regression: fixed at 0 (no meaning). * @param maxBins Maximum number of bins, for all features. * @param featureArity Map: categorical feature index --> arity. * I.e., the feature takes values in {0, ..., arity - 1}. * @param numBins Number of bins for each feature. + * @param featureIndexes Indexes of usable (i.e., non-single-category) features. */ private[spark] class DecisionTreeMetadata( val numFeatures: Int, @@ -115,9 +117,9 @@ private[spark] object DecisionTreeMetadata extends Logging { // Construct the feature indexes that we can use (one-category features are not useful) val featureIndexes = if (strategy.categoricalFeaturesInfo.nonEmpty) { val singleCategoryIndexes = strategy.categoricalFeaturesInfo.filter(_._2 < 2).map(_._1).toSet - 0.to(numFeatures).filterNot(singleCategoryIndexes.contains) + 0.to(numFeatures-1).filterNot(singleCategoryIndexes.contains) } else { - 0.to(numFeatures) + 0.to(numFeatures-1) } val numActiveFeatures = featureIndexes.size val numExamples = input.count() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 738a4884d3109..12494d5444cfb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -113,7 +113,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0, 0.to(5).toArray + 0, 0, 0.0, 0, 0, 0.to(0).toArray ) val featureSamples = Array.fill(200000)(math.random) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -130,7 +130,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0, 0.to(4).toArray + 0, 0, 0.0, 0, 0, 0.to(0).toArray ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -146,7 +146,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0, 0.to(2).toArray + 0, 0, 0.0, 0, 0, 0.to(0).toArray ) val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -162,7 +162,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0, 0.to(2).toArray + 0, 0, 0.0, 0, 0, 0.to(0).toArray ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) From fef404b331fd1c16898a690c31feb2336fdbfda5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 10 Nov 2015 22:39:02 -0800 Subject: [PATCH 4/4] progress --- .../spark/mllib/tree/DecisionTree.scala | 18 ++----------
.../mllib/tree/impl/DTStatsAggregator.scala | 4 +++ .../tree/impl/DecisionTreeMetadata.scala | 6 ++-- .../spark/mllib/tree/RandomForestSuite.scala | 29 +++++++++++++++++-- 4 files changed, 36 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index af1f7e74c004d..d8ffaf0c9601a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -347,21 +347,8 @@ object DecisionTree extends Serializable with Logging { unorderedFeatures: Set[Int], instanceWeight: Double, featuresForNode: Option[Array[Int]]): Unit = { - val numFeaturesPerNode = if (featuresForNode.nonEmpty) { - // Use subsampled features - featuresForNode.get.size - } else { - // Use all features - agg.metadata.numFeatures - } - // Iterate over features. - var featureIndexIdx = 0 - while (featureIndexIdx < numFeaturesPerNode) { - val featureIndex = if (featuresForNode.nonEmpty) { - featuresForNode.get.apply(featureIndexIdx) - } else { - featureIndexIdx - } + val features: Array[Int] = featuresForNode.getOrElse(agg.metadata.featureIndexes.toArray) + features.zipWithIndex.foreach{case (featureIndex, featureIndexIdx) => if (unorderedFeatures.contains(featureIndex)) { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) @@ -385,7 +372,6 @@ object DecisionTree extends Serializable with Logging { val binIndex = treePoint.binnedFeatures(featureIndex) agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) } - featureIndexIdx += 1 } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 7985ed4b4c0fa..dc9f17e4a138d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -96,7 +96,11 @@ private[spark] class DTStatsAggregator( * Update the stats for a given (feature, bin) for ordered features, using the given label. 
*/ def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { + println("Updating for "+featureIndex+" binIndex "+binIndex) val i = featureOffsets(featureIndex) + binIndex * statsSize + println("featureOffsets are "+featureOffsets.toList) + println("featureOffsets(featureIndex) is "+featureOffsets(featureIndex)) + println("i "+i) impurityAggregator.update(allStats, i, label, instanceWeight) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 30acf83aceeb2..d6e96bc83526b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -168,17 +168,17 @@ private[spark] object DecisionTreeMetadata extends Logging { numBins(featureIndex) = numCategories } } else { - numBins(featureIndex) = -1 + numBins(featureIndex) = 0 } } } else { // Binary classification or regression strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Set number of bins to -1 if we are skipping a feature + // Set number of bins to 0 if we are skipping a feature if (numCategories > 1) { numBins(featureIndex) = numCategories } else { - numBins(featureIndex) = -1 + numBins(featureIndex) = 0 } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index e6df5d974bf36..745a70de5c371 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -53,7 +53,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Make sure trees are the same. assert(rfTree.toString == dt.toString) } - +/* test("Binary classification with continuous features:" + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -196,7 +196,32 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, featureSubsetStrategy = "sqrt", seed = 12345) } + */ + test("filtering of single-category categorical features") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0, 0.0, 3.0, 1.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0, 0.0, 6.0, 3.0)) + arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + val categoricalFeaturesInfo = Map(0 -> 1, 2 -> 2, 4 -> 4) + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) + val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, + featureSubsetStrategy = "sqrt", seed = 12345) + // Check that none of the trees use the filtered single-category feature (feature 0).
+ def assertTreeDoesNotContain(node: Node, feature: Long): Unit = { + node.split.foreach(split => assert(split.feature != feature)) + node.leftNode.foreach(assertTreeDoesNotContain(_, feature)) + node.rightNode.foreach(assertTreeDoesNotContain(_, feature)) + } + model.trees.foreach{tree => + assertTreeDoesNotContain(tree.topNode, 0) + } + } + +/* test("subsampling rate in RandomForest"){ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20) val rdd = sc.parallelize(arr) @@ -233,5 +258,5 @@ } } } - + */ }
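For readers following the series, here is a minimal standalone sketch of the filtering rule these patches introduce, written outside the patch context. A categorical feature with arity < 2 can never produce a split, so its index is dropped from the set of usable features. The object and method names below are illustrative, not part of the patches.

object SingleCategoryFilterSketch {
  // Mirrors the featureIndexes construction from PATCH 1, with the
  // off-by-one fix from PATCH 3: indexes run 0 until numFeatures, and
  // any categorical feature with fewer than 2 categories is excluded.
  def usableFeatureIndexes(
      numFeatures: Int,
      categoricalFeaturesInfo: Map[Int, Int]): IndexedSeq[Int] = {
    val singleCategoryIndexes = categoricalFeaturesInfo.filter(_._2 < 2).keySet
    (0 until numFeatures).filterNot(singleCategoryIndexes.contains)
  }

  def main(args: Array[String]): Unit = {
    // Same categoricalFeaturesInfo shape as the new RandomForestSuite test:
    // feature 0 has a single category, so only features 1-4 remain usable.
    println(usableFeatureIndexes(5, Map(0 -> 1, 2 -> 2, 4 -> 4)))
    // prints: Vector(1, 2, 3, 4)
  }
}

The size of this filtered sequence is what PATCH 1 feeds into the per-node feature-subset sizing ("sqrt", "log2", "onethird"), so subset counts are taken over the active features rather than the raw feature count.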