Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -692,14 +692,17 @@ private[spark] object RandomForest extends Logging {
node.stats
}

val validFeatureSplits =
Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
.getOrElse((featureIndexIdx, featureIndexIdx))
}.withFilter { case (_, featureIndex) =>
binAggregates.metadata.numSplits(featureIndex) != 0
}

// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
Expand Down Expand Up @@ -953,7 +956,7 @@ private[spark] object RandomForest extends Logging {
* NOTE: `metadata.numbins` will be changed accordingly
* if there are not enough splits to be found
* @param featureIndex feature index to find splits
* @return array of splits
* @return array of split thresholds
*/
private[tree] def findSplitsForContinuousFeature(
featureSamples: Iterable[Double],
Expand All @@ -962,7 +965,9 @@ private[spark] object RandomForest extends Logging {
require(metadata.isContinuous(featureIndex),
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")

val splits = {
val splits = if (featureSamples.isEmpty) {
Array.empty[Double]
} else {
val numSplits = metadata.numSplits(featureIndex)

// get count for each distinct value
Expand All @@ -974,9 +979,9 @@ private[spark] object RandomForest extends Logging {
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray

// if possible splits is not enough or just enough, just return all possible splits
val possibleSplits = valueCounts.length
val possibleSplits = valueCounts.length - 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The number of possible bins should be valueCounts.length, and the number of possible splits should therefore be valueCounts.length - 1.

if (possibleSplits <= numSplits) {
valueCounts.map(_._1)
valueCounts.map(_._1).init
} else {
// stride between splits
val stride: Double = numSamples.toDouble / (numSplits + 1)
Expand Down Expand Up @@ -1010,12 +1015,6 @@ private[spark] object RandomForest extends Logging {
splitsBuilder.result()
}
}

// TODO: Do not fail; just ignore the useless feature.
assert(splits.length > 0,
s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
" Please remove this feature and then try again.")

splits
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 3)
assert(splits === Array(1.0, 2.0))
// check returned splits are distinct
assert(splits.distinct.length === splits.length)
}
Expand All @@ -128,23 +128,53 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
)
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 2)
assert(splits(0) === 2.0)
assert(splits(1) === 3.0)
assert(splits === Array(2.0, 3.0))
}

// find splits when most samples close to the maximum
{
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
Array(2), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 1)
assert(splits(0) === 1.0)
assert(splits === Array(1.0))
}

// find splits for constant feature
{
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(0, 0, 0).map(_.toDouble)
val featureSamplesEmpty = Array.empty[Double]
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array[Double]())
val splitsEmpty =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although it is not currently possible for the findSplitsForContinuousFeature method to receive an empty array, we still handle it. I appreciate feedback on this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When will this ever happen, or in other words what corner case does this fix?

Copy link
Contributor Author

@sethah sethah Jul 11, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The iterable passed to findSplitsForContinuousFeature is the result of a groupByKey, so I don't see how this could ever happen. Still, if this method is used elsewhere ever in the future, it would fail with java.lang.UnsupportedOperationException if it received an empty iterable. I'm open to changing this.

RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
assert(splitsEmpty === Array[Double]())
}
}

test("train with constant features") {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test would have failed before due to the assertion that splits.length > 0.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"train with constant features" -> "train with constant continuous features"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is not specific to continuous features.

val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
val data = Array.fill(5)(lp)
val rdd = sc.parallelize(data)
val strategy = new OldStrategy(
OldAlgo.Classification,
Gini,
maxDepth = 2,
numClasses = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
assert(tree.rootNode.impurity === -1.0)
assert(tree.depth === 0)
assert(tree.rootNode.prediction === lp.label)
}

test("Multiclass classification with unordered categorical features: split calculations") {
Expand Down