Skip to content

Commit b7a47e0

Browse files
committed
Implement a Chi-Squared test statistic option for measuring split quality when training decision trees
1 parent 6347ff5 commit b7a47e0

File tree

11 files changed

+277
-12
lines changed

11 files changed

+277
-12
lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ private[spark] class DTStatsAggregator(
3737
val impurityAggregator: ImpurityAggregator = metadata.impurity match {
3838
case Gini => new GiniAggregator(metadata.numClasses)
3939
case Entropy => new EntropyAggregator(metadata.numClasses)
40+
case ChiSquared => new ChiSquaredAggregator(metadata.numClasses)
4041
case Variance => new VarianceAggregator()
4142
case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
4243
}

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel
2929
import org.apache.spark.ml.tree._
3030
import org.apache.spark.ml.util.Instrumentation
3131
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
32-
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
32+
import org.apache.spark.mllib.tree.impurity.{Impurity, ImpurityCalculator}
3333
import org.apache.spark.mllib.tree.model.ImpurityStats
3434
import org.apache.spark.rdd.RDD
3535
import org.apache.spark.storage.StorageLevel
@@ -657,14 +657,32 @@ private[spark] object RandomForest extends Logging {
657657
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
658658
val rightImpurity = rightImpurityCalculator.calculate()
659659

660-
val leftWeight = leftCount / totalCount.toDouble
661-
val rightWeight = rightCount / totalCount.toDouble
660+
val gain = metadata.impurity match {
661+
case imp if (imp.isTestStatistic) =>
662+
// For split quality measures based on a test-statistic, run the test on the
663+
// left and right sub-populations to get a p-value for the null hypothesis
664+
val pval = imp.calculate(leftImpurityCalculator, rightImpurityCalculator)
665+
// Transform the test statistic p-val into a larger-is-better gain value
666+
Impurity.pValToGain(pval)
667+
668+
case _ =>
669+
// Default purity-gain logic:
670+
// measure the weighted decrease in impurity from parent to the left and right
671+
val leftWeight = leftCount / totalCount.toDouble
672+
val rightWeight = rightCount / totalCount.toDouble
673+
674+
impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
675+
}
662676

663-
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
677+
// If the impurity being used is a test statistic p-val, apply a standard transform into
678+
// a larger-is-better gain value for the minimum-gain threshold
679+
val minGain =
680+
if (metadata.impurity.isTestStatistic) Impurity.pValToGain(metadata.minInfoGain)
681+
else metadata.minInfoGain
664682

665683
// if information gain doesn't satisfy minimum information gain,
666684
// then this split is invalid, return invalid information gain stats.
667-
if (gain < metadata.minInfoGain) {
685+
if (gain < minGain) {
668686
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
669687
}
670688

mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param._
2424
import org.apache.spark.ml.param.shared._
2525
import org.apache.spark.ml.util.SchemaUtils
2626
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
27-
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
27+
import org.apache.spark.mllib.tree.impurity.{ChiSquared => OldChiSquared, Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
2828
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
2929
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
3030

@@ -185,7 +185,7 @@ private[ml] trait TreeClassifierParams extends Params {
185185

186186
/**
187187
* Criterion used for information gain calculation (case-insensitive).
188-
* Supported: "entropy" and "gini".
188+
* Supported: "entropy", "gini", "chisquared".
189189
* (default = gini)
190190
* @group param
191191
*/
@@ -207,6 +207,7 @@ private[ml] trait TreeClassifierParams extends Params {
207207
getImpurity match {
208208
case "entropy" => OldEntropy
209209
case "gini" => OldGini
210+
case "chisquared" => OldChiSquared
210211
case _ =>
211212
// Should never happen because of check in setter method.
212213
throw new RuntimeException(
@@ -217,7 +218,8 @@ private[ml] trait TreeClassifierParams extends Params {
217218

218219
private[ml] object TreeClassifierParams {
219220
// These options should be lowercase.
220-
final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
221+
final val supportedImpurities: Array[String] = Array("entropy", "gini", "chisquared")
222+
.map(_.toLowerCase)
221223
}
222224

223225
private[ml] trait DecisionTreeClassifierParams

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
2323
import org.apache.spark.annotation.Since
2424
import org.apache.spark.mllib.tree.configuration.Algo._
2525
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
26-
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
26+
import org.apache.spark.mllib.tree.impurity.{ChiSquared, Entropy, Gini, Impurity, Variance}
2727

2828
/**
2929
* Stores all the configuration options for tree construction
@@ -140,7 +140,7 @@ class Strategy @Since("1.3.0") (
140140
require(numClasses >= 2,
141141
s"DecisionTree Strategy for Classification must have numClasses >= 2," +
142142
s" but numClasses = $numClasses.")
143-
require(Set(Gini, Entropy).contains(impurity),
143+
require(Set(Gini, Entropy, ChiSquared).contains(impurity),
144144
s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
145145
s" Valid settings: Gini, Entropy")
146146
case Regression =>
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.tree.impurity
19+
20+
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
21+
22+
/**
 * :: Experimental ::
 * Split-quality measure for classification based on the
 * [[https://en.wikipedia.org/wiki/Chi-squared_test chi-squared]] test.
 *
 * This is a test-statistic measure (see [[isTestStatistic]]): it does not score a single
 * node, it compares the label counts of the left and right sides of a candidate split and
 * returns a p-value. The per-node `calculate` overloads are therefore unsupported.
 */
@Since("2.0.0")
@Experimental
object ChiSquared extends Impurity {
  private object CSTest extends org.apache.commons.math3.stat.inference.ChiSquareTest()

  /**
   * Get this impurity instance.
   * This is useful for passing impurity parameters to a Strategy in Java.
   */
  @Since("2.0.0")
  def instance: this.type = this

  /**
   * :: DeveloperApi ::
   * Placeholder for the classification-based per-node impurity; not supported.
   * @param counts Array[Double] with counts for each label
   * @param totalCount sum of counts for all labels
   * @return never returns normally
   * @throws UnsupportedOperationException always, for [[ChiSquared]]
   */
  @Since("2.0.0")
  @DeveloperApi
  override def calculate(counts: Array[Double], totalCount: Double): Double =
    throw new UnsupportedOperationException("ChiSquared.calculate")

  /**
   * :: DeveloperApi ::
   * Placeholder for the regression-based per-node impurity; not supported.
   * @param count number of instances
   * @param sum sum of labels
   * @param sumSquares summation of squares of the labels
   * @return never returns normally
   * @throws UnsupportedOperationException always, for [[ChiSquared]]
   */
  @Since("2.0.0")
  @DeveloperApi
  override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
    throw new UnsupportedOperationException("ChiSquared.calculate")

  /**
   * :: DeveloperApi ::
   * Chi-squared p-value computed from the [[ImpurityCalculator]] statistics of the left and
   * right split populations, treated as a 2 x (number of labels) contingency table.
   *
   * Each per-label statistic is truncated to a `Long` count. Labels that were not observed
   * on either side are dropped before testing: their expected count is zero, which would
   * make the statistic 0/0. If fewer than two labels remain, or one side is empty, the table
   * carries no evidence against the null hypothesis and 1.0 is returned.
   *
   * @param calcL impurity calculator for the left split population
   * @param calcR impurity calculator for the right split population
   * @return The p-value for the chi squared null hypothesis; that left and right split
   *         populations represent the same distribution of categorical values
   */
  @Since("2.0.0")
  @DeveloperApi
  override def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double = {
    val leftCounts = calcL.stats.map(_.toLong)
    val rightCounts = calcR.stats.map(_.toLong)
    require(leftCounts.length == rightCounts.length,
      s"ChiSquared.calculate given populations with different numbers of labels:" +
      s" ${leftCounts.length} vs ${rightCounts.length}")

    // Keep only the label columns that have at least one non-zero entry.
    val (obsLeft, obsRight) = leftCounts.zip(rightCounts)
      .filter { case (l, r) => l != 0L || r != 0L }
      .unzip

    if (obsLeft.length < 2 || obsLeft.sum == 0L || obsRight.sum == 0L) {
      // Degenerate table: the test is undefined, so report "no evidence of a difference".
      1.0
    } else {
      CSTest.chiSquareTest(Array(obsLeft, obsRight))
    }
  }

  /**
   * :: DeveloperApi ::
   * Determine if this impurity measure is a test-statistic measure (true for Chi-squared)
   * @return For [[ChiSquared]] will return true
   */
  @Since("2.0.0")
  @DeveloperApi
  override def isTestStatistic: Boolean = true
}
92+
93+
/**
 * Class for updating views of a vector of sufficient statistics,
 * in order to compute impurity from a sample.
 * The sufficient statistics are one weighted count per class label.
 * Note: Instances of this class do not hold the data; they operate on views of the data.
 * @param numClasses Number of classes for label.
 */
private[spark] class ChiSquaredAggregator(numClasses: Int)
  extends ImpurityAggregator(numClasses) with Serializable {

  /**
   * Update stats for one (node, feature, bin) with the given label.
   * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
   * @param offset Start index of stats for this (node, feature, bin).
   * @param label Class label of the instance; its integer part selects the count to update.
   * @param instanceWeight Weight added to the count for `label`.
   * @throws IllegalArgumentException if the label is outside [0, numClasses)
   */
  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
    val labelIndex = label.toInt
    // Without this check an out-of-range label would be added to the statistics of a
    // neighbouring (node, feature, bin) in the flat array instead of failing.
    require(labelIndex >= 0 && labelIndex < numClasses,
      s"ChiSquaredAggregator given label $label but requires label in [0, $numClasses)")
    allStats(offset + labelIndex) += instanceWeight
  }

  /**
   * Get an [[ImpurityCalculator]] for a (node, feature, bin).
   * The returned calculator holds its own copy of the relevant slice of `allStats`.
   * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
   * @param offset Start index of stats for this (node, feature, bin).
   */
  def getCalculator(allStats: Array[Double], offset: Int): ChiSquaredCalculator = {
    new ChiSquaredCalculator(allStats.view(offset, offset + statsSize).toArray)
  }
}
120+
121+
/**
 * Stores statistics for one (node, feature, bin) for calculating impurity.
 * This class stores its own data and is for a specific (node, feature, bin).
 * @param stats Array of sufficient statistics for a (node, feature, bin):
 *              one (weighted) count per class label.
 */
private[spark] class ChiSquaredCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {

  /**
   * Make a deep copy of this [[ImpurityCalculator]].
   */
  def copy: ChiSquaredCalculator = new ChiSquaredCalculator(stats.clone())

  /**
   * Per-node impurity placeholder: always 1.0.
   * Chi-squared is a test-statistic measure, so split quality comes from the two-population
   * p-value (`Impurity.calculate(calcL, calcR)`) rather than from this value.
   */
  def calculate(): Double = 1.0

  /**
   * Number of data points accounted for in the sufficient statistics
   * (sum of the per-label counts, truncated to a Long).
   */
  def count: Long = stats.sum.toLong

  /**
   * Prediction which should be made based on the sufficient statistics:
   * 0 when no data has been counted, otherwise the result of
   * `indexOfLargestArrayElement` over the per-label counts.
   */
  def predict: Double =
    if (count == 0) 0 else indexOfLargestArrayElement(stats)

  /**
   * Fraction of the counted data points that carry the given label.
   * @param label Class label; its integer part must be in [0, stats.length).
   * @return stats(label) / count, or 0 when no data has been counted.
   */
  override def prob(label: Double): Double = {
    val lbl = label.toInt
    require(lbl < stats.length,
      s"ChiSquaredCalculator.prob given invalid label: $lbl (should be < ${stats.length})")
    require(lbl >= 0, "ChiSquaredImpurity does not support negative labels")
    val cnt = count
    if (cnt == 0) 0.0 else stats(lbl) / cnt
  }

  override def toString: String = s"ChiSquaredCalculator(stats = [${stats.mkString(", ")}])"
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ private[mllib] object Impurities {
2626
case "gini" => Gini
2727
case "entropy" => Entropy
2828
case "variance" => Variance
29+
case "chisquared" => ChiSquared
2930
case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name")
3031
}
3132

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,50 @@ trait Impurity extends Serializable {
5252
@Since("1.0.0")
5353
@DeveloperApi
5454
def calculate(count: Double, sum: Double, sumSquares: Double): Double
55+
56+
/**
57+
* :: DeveloperApi ::
58+
* Compute a test-statistic p-value quality measure from left and right split populations
59+
* @param calcL impurity calculator for the left split population
60+
* @param calcR impurity calculator for the right split population
61+
* @return The p-value for the null hypothesis; that left and right split populations
62+
* represent the same distribution
63+
* @note Unless overridden this method will fail with an exception, for backward compatibility
64+
*/
65+
@Since("2.0.0")
66+
@DeveloperApi
67+
def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double =
68+
throw new UnsupportedOperationException("Impurity.calculate")
69+
70+
/**
71+
* :: DeveloperApi ::
72+
* Determine if this impurity measure is a test-statistic measure
73+
* @return True if this is a split quality measure based on a test statistic (i.e. returns a
74+
* p-value) or false otherwise.
75+
* @note Unless overridden this method returns false by default, for backward compatibility
76+
*/
77+
@Since("2.0.0")
78+
@DeveloperApi
79+
def isTestStatistic: Boolean = false
80+
}
81+
82+
/**
 * :: DeveloperApi ::
 * Utility functions for Impurity measures
 */
@Since("2.0.0")
@DeveloperApi
object Impurity {
  /**
   * :: DeveloperApi ::
   * Convert a test-statistic p-value into a "larger-is-better" gain value.
   * @param pval The test statistic p-value
   * @return The negative logarithm of the p-value. Any p-values smaller than 10^-20 are clipped
   *         to 10^-20 to prevent arithmetic errors. A NaN p-value (an undefined test) is
   *         treated like a p-value of 1 and yields a gain of 0.
   */
  @Since("2.0.0")
  @DeveloperApi
  def pValToGain(pval: Double): Double = {
    // math.max propagates NaN, and a NaN gain compares false against any threshold,
    // so it would slip through a `gain < minGain` rejection test. Map it to "no gain".
    if (pval.isNaN) 0.0 else -math.log(math.max(1e-20, pval))
  }
}
56100

57101
/**

mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,41 @@ class DecisionTreeClassifierSuite
236236
compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
237237
}
238238

239+
test("split quality using chi-squared and minimum gain") {
  // Seeded generator so the generated data set is identical on every run.
  val rng = new scala.util.Random(42)

  // Generate a data set where the 1st feature is useful and the others are noise
  val features = Vector.fill(200) {
    Array.fill(3) { rng.nextInt(2).toDouble }
  }
  // The label is fully determined by feature 0.
  val labels = features.map { fv =>
    LabeledPoint(if (fv(0) == 1.0) 1.0 else 0.0, Vectors.dense(fv))
  }
  val rdd = sc.parallelize(labels)

  // two-class learning problem
  val numClasses = 2
  // all binary features
  val catFeatures = Map(Vector.tabulate(features.head.length) { j => (j, 2) } : _*)

  // Chi-squared split quality with a p-value threshold of 0.01 should allow
  // only the first feature to be used since the others are uncorrelated noise
  val train: DataFrame = TreeTests.setMetadata(rdd, catFeatures, numClasses)
  val dt = new DecisionTreeClassifier()
    .setImpurity("chisquared")
    .setMaxDepth(5)
    .setMinInfoGain(0.01)
  val treeModel = dt.fit(train)

  // The tree should use exactly one of the 3 features: feature(0)
  val featImps = treeModel.featureImportances
  assert(treeModel.depth === 1)
  assert(featImps.size === 3)
  assert(featImps(0) === 1.0)
  assert(featImps(1) === 0.0)
  assert(featImps(2) === 0.0)

  compareAPIs(rdd, dt, catFeatures, numClasses)
}
273+
239274
test("predictRaw and predictProbability") {
240275
val rdd = continuousDataPointsForMulticlassRDD
241276
val dt = new DecisionTreeClassifier()

project/MimaExcludes.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,6 +1511,8 @@ object MimaExcludes {
15111511
"org.apache.spark.mllib.tree.impurity.Gini.calculate"),
15121512
ProblemFilters.exclude[IncompatibleMethTypeProblem](
15131513
"org.apache.spark.mllib.tree.impurity.Entropy.calculate"),
1514+
ProblemFilters.exclude[IncompatibleMethTypeProblem](
1515+
"org.apache.spark.mllib.tree.impurity.ChiSquared.calculate"),
15141516
ProblemFilters.exclude[IncompatibleMethTypeProblem](
15151517
"org.apache.spark.mllib.tree.impurity.Variance.calculate")
15161518
) ++

python/pyspark/ml/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ class TreeClassifierParams(object):
451451
452452
.. versionadded:: 1.4.0
453453
"""
454-
supportedImpurities = ["entropy", "gini"]
454+
supportedImpurities = ["entropy", "gini", "chisquared"]
455455

456456
impurity = Param(Params._dummy(), "impurity",
457457
"Criterion used for information gain calculation (case-insensitive). " +

0 commit comments

Comments
 (0)