Skip to content

Commit 03c4020

Browse files
sethahjkbradley
authored andcommitted
[SPARK-14610][ML] Remove superfluous split for continuous features in decision tree training
## What changes were proposed in this pull request? A nonsensical split is produced from method `findSplitsForContinuousFeature` for decision trees. This PR removes the superfluous split and updates unit tests accordingly. Additionally, an assertion to check that the number of found splits is `> 0` is removed, and instead features with zero possible splits are ignored. ## How was this patch tested? A unit test was added to check that finding splits for a constant feature produces an empty array. Author: sethah <[email protected]> Closes #12374 from sethah/SPARK-14610.
1 parent 29f186b commit 03c4020

File tree

2 files changed

+52
-23
lines changed

2 files changed

+52
-23
lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -705,14 +705,17 @@ private[spark] object RandomForest extends Logging {
705705
node.stats
706706
}
707707

708+
val validFeatureSplits =
709+
Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
710+
featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
711+
.getOrElse((featureIndexIdx, featureIndexIdx))
712+
}.withFilter { case (_, featureIndex) =>
713+
binAggregates.metadata.numSplits(featureIndex) != 0
714+
}
715+
708716
// For each (feature, split), calculate the gain, and select the best (feature, split).
709717
val (bestSplit, bestSplitStats) =
710-
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
711-
val featureIndex = if (featuresForNode.nonEmpty) {
712-
featuresForNode.get.apply(featureIndexIdx)
713-
} else {
714-
featureIndexIdx
715-
}
718+
validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
716719
val numSplits = binAggregates.metadata.numSplits(featureIndex)
717720
if (binAggregates.metadata.isContinuous(featureIndex)) {
718721
// Cumulative sum (scanLeft) of bin statistics.
@@ -966,7 +969,7 @@ private[spark] object RandomForest extends Logging {
966969
* NOTE: `metadata.numbins` will be changed accordingly
967970
* if there are not enough splits to be found
968971
* @param featureIndex feature index to find splits
969-
* @return array of splits
972+
* @return array of split thresholds
970973
*/
971974
private[tree] def findSplitsForContinuousFeature(
972975
featureSamples: Iterable[Double],
@@ -975,7 +978,9 @@ private[spark] object RandomForest extends Logging {
975978
require(metadata.isContinuous(featureIndex),
976979
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
977980

978-
val splits = {
981+
val splits = if (featureSamples.isEmpty) {
982+
Array.empty[Double]
983+
} else {
979984
val numSplits = metadata.numSplits(featureIndex)
980985

981986
// get count for each distinct value
@@ -987,9 +992,9 @@ private[spark] object RandomForest extends Logging {
987992
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
988993

989994
// if possible splits is not enough or just enough, just return all possible splits
990-
val possibleSplits = valueCounts.length
995+
val possibleSplits = valueCounts.length - 1
991996
if (possibleSplits <= numSplits) {
992-
valueCounts.map(_._1)
997+
valueCounts.map(_._1).init
993998
} else {
994999
// stride between splits
9951000
val stride: Double = numSamples.toDouble / (numSplits + 1)
@@ -1023,12 +1028,6 @@ private[spark] object RandomForest extends Logging {
10231028
splitsBuilder.result()
10241029
}
10251030
}
1026-
1027-
// TODO: Do not fail; just ignore the useless feature.
1028-
assert(splits.length > 0,
1029-
s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
1030-
" Please remove this feature and then try again.")
1031-
10321031
splits
10331032
}
10341033

mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
115115
)
116116
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
117117
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
118-
assert(splits.length === 3)
118+
assert(splits === Array(1.0, 2.0))
119119
// check returned splits are distinct
120120
assert(splits.distinct.length === splits.length)
121121
}
@@ -129,23 +129,53 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
129129
)
130130
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
131131
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
132-
assert(splits.length === 2)
133-
assert(splits(0) === 2.0)
134-
assert(splits(1) === 3.0)
132+
assert(splits === Array(2.0, 3.0))
135133
}
136134

137135
// find splits when most samples close to the maximum
138136
{
139137
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
140138
Map(), Set(),
141-
Array(3), Gini, QuantileStrategy.Sort,
139+
Array(2), Gini, QuantileStrategy.Sort,
142140
0, 0, 0.0, 0, 0
143141
)
144142
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
145143
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
146-
assert(splits.length === 1)
147-
assert(splits(0) === 1.0)
144+
assert(splits === Array(1.0))
148145
}
146+
147+
// find splits for constant feature
148+
{
149+
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
150+
Map(), Set(),
151+
Array(3), Gini, QuantileStrategy.Sort,
152+
0, 0, 0.0, 0, 0
153+
)
154+
val featureSamples = Array(0, 0, 0).map(_.toDouble)
155+
val featureSamplesEmpty = Array.empty[Double]
156+
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
157+
assert(splits === Array[Double]())
158+
val splitsEmpty =
159+
RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
160+
assert(splitsEmpty === Array[Double]())
161+
}
162+
}
163+
164+
test("train with constant features") {
165+
val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
166+
val data = Array.fill(5)(lp)
167+
val rdd = sc.parallelize(data)
168+
val strategy = new OldStrategy(
169+
OldAlgo.Classification,
170+
Gini,
171+
maxDepth = 2,
172+
numClasses = 2,
173+
maxBins = 100,
174+
categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
175+
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
176+
assert(tree.rootNode.impurity === -1.0)
177+
assert(tree.depth === 0)
178+
assert(tree.rootNode.prediction === lp.label)
149179
}
150180

151181
test("Multiclass classification with unordered categorical features: split calculations") {

0 commit comments

Comments
 (0)