diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 71c8c42ce5eba..43867f92405fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -692,14 +692,17 @@ private[spark] object RandomForest extends Logging { node.stats } + val validFeatureSplits = + Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx => + featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + }.withFilter { case (_, featureIndex) => + binAggregates.metadata.numSplits(featureIndex) != 0 + } + // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = - Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => - val featureIndex = if (featuresForNode.nonEmpty) { - featuresForNode.get.apply(featureIndexIdx) - } else { - featureIndexIdx - } + validFeatureSplits.map { case (featureIndexIdx, featureIndex) => val numSplits = binAggregates.metadata.numSplits(featureIndex) if (binAggregates.metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. @@ -953,7 +956,7 @@ private[spark] object RandomForest extends Logging { * NOTE: `metadata.numbins` will be changed accordingly * if there are not enough splits to be found * @param featureIndex feature index to find splits - * @return array of splits + * @return array of split thresholds */ private[tree] def findSplitsForContinuousFeature( featureSamples: Iterable[Double], @@ -962,7 +965,9 @@ private[spark] object RandomForest extends Logging { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - val splits = { + val splits = if (featureSamples.isEmpty) { + Array.empty[Double] + } else { val numSplits = metadata.numSplits(featureIndex) // get count for each distinct value @@ -974,9 +979,9 @@ private[spark] object RandomForest extends Logging { val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray // if possible splits is not enough or just enough, just return all possible splits - val possibleSplits = valueCounts.length + val possibleSplits = valueCounts.length - 1 if (possibleSplits <= numSplits) { - valueCounts.map(_._1) + valueCounts.map(_._1).init } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) @@ -1010,12 +1015,6 @@ private[spark] object RandomForest extends Logging { splitsBuilder.result() } } - - // TODO: Do not fail; just ignore the useless feature. - assert(splits.length > 0, - s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + - " Please remove this feature and then try again.") - splits } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dcc2f305df75a..c88bceb5179e1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -114,7 +114,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 3) + assert(splits === Array(1.0, 2.0)) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -128,23 +128,53 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 2) - assert(splits(0) === 2.0) - assert(splits(1) === 3.0) + assert(splits === Array(2.0, 3.0)) } // find splits when most samples close to the maximum { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, + Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 1) - assert(splits(0) === 1.0) + assert(splits === Array(1.0)) } + + // find splits for constant feature + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 0, 0).map(_.toDouble) + val featureSamplesEmpty = Array.empty[Double] + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits === Array[Double]()) + val splitsEmpty = + RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0) + assert(splitsEmpty === Array[Double]()) + } + } + + test("train with constant features") { + val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)) + val data = Array.fill(5)(lp) + val rdd = sc.parallelize(data) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 2, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) + val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) + assert(tree.rootNode.impurity === -1.0) + assert(tree.depth === 0) + assert(tree.rootNode.prediction === lp.label) } test("Multiclass classification with unordered categorical features: split calculations") {