-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-14610][ML] Remove superfluous split for continuous features in decision tree training #12374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bbdbf20
2da8474
ab5694a
8835f64
c707b25
3bb28fe
eddac63
3c73726
928a834
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -114,7 +114,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| ) | ||
| val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) | ||
| val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) | ||
| assert(splits.length === 3) | ||
| assert(splits === Array(1.0, 2.0)) | ||
| // check returned splits are distinct | ||
| assert(splits.distinct.length === splits.length) | ||
| } | ||
|
|
@@ -128,23 +128,53 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| ) | ||
| val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) | ||
| val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) | ||
| assert(splits.length === 2) | ||
| assert(splits(0) === 2.0) | ||
| assert(splits(1) === 3.0) | ||
| assert(splits === Array(2.0, 3.0)) | ||
| } | ||
|
|
||
| // find splits when most samples close to the maximum | ||
| { | ||
| val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, | ||
| Map(), Set(), | ||
| Array(3), Gini, QuantileStrategy.Sort, | ||
| Array(2), Gini, QuantileStrategy.Sort, | ||
| 0, 0, 0.0, 0, 0 | ||
| ) | ||
| val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) | ||
| val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) | ||
| assert(splits.length === 1) | ||
| assert(splits(0) === 1.0) | ||
| assert(splits === Array(1.0)) | ||
| } | ||
|
|
||
| // find splits for constant feature | ||
| { | ||
| val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, | ||
| Map(), Set(), | ||
| Array(3), Gini, QuantileStrategy.Sort, | ||
| 0, 0, 0.0, 0, 0 | ||
| ) | ||
| val featureSamples = Array(0, 0, 0).map(_.toDouble) | ||
| val featureSamplesEmpty = Array.empty[Double] | ||
| val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) | ||
| assert(splits === Array[Double]()) | ||
| val splitsEmpty = | ||
|
||
| RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0) | ||
| assert(splitsEmpty === Array[Double]()) | ||
| } | ||
| } | ||
|
|
||
| test("train with constant features") { | ||
|
||
| val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)) | ||
| val data = Array.fill(5)(lp) | ||
| val rdd = sc.parallelize(data) | ||
| val strategy = new OldStrategy( | ||
| OldAlgo.Classification, | ||
| Gini, | ||
| maxDepth = 2, | ||
| numClasses = 2, | ||
| maxBins = 100, | ||
| categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) | ||
| val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) | ||
| assert(tree.rootNode.impurity === -1.0) | ||
| assert(tree.depth === 0) | ||
| assert(tree.rootNode.prediction === lp.label) | ||
| } | ||
|
|
||
| test("Multiclass classification with unordered categorical features: split calculations") { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The number of possible bins should be `valueCounts.length`, and the number of possible splits should therefore be `valueCounts.length - 1`.