-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-19635][ML] DataFrame-based API for chi square test #17110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.ml.stat | ||
|
|
||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} | ||
| import org.apache.spark.ml.util.SchemaUtils | ||
| import org.apache.spark.mllib.linalg.{Vectors => OldVectors} | ||
| import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} | ||
| import org.apache.spark.mllib.stat.{Statistics => OldStatistics} | ||
| import org.apache.spark.sql.DataFrame | ||
| import org.apache.spark.sql.functions.col | ||
|
|
||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * | ||
| * Chi-square hypothesis testing for categorical data. | ||
| * | ||
| * See <a href="https://en.wikipedia.org/wiki/Chi-squared_test">Wikipedia</a> for more information | ||
| * on the Chi-squared test. | ||
| */ | ||
| @Experimental | ||
| @Since("2.2.0") | ||
| object ChiSquare { | ||
|
|
||
| /** | ||
| * Row type used to construct the output schema of `test`: the single result row holds the | ||
| * p-values, degrees of freedom and test statistics, each with one entry per feature. | ||
| */ | ||
| private case class ChiSquareResult( | ||
| pValues: Vector, | ||
| degreesOfFreedom: Array[Int], | ||
| statistics: Vector) | ||
|
|
||
| /** | ||
| * Conduct Pearson's independence test for every feature against the label across the input | ||
| * DataFrame. For each feature, the (feature, label) pairs are converted into a contingency | ||
| * matrix for which the Chi-squared statistic is computed. All label and feature values must be | ||
| * categorical. | ||
| * | ||
| * The null hypothesis is that the occurrence of the outcomes is statistically independent. | ||
| * | ||
| * The label column is cast to double, every row is converted to an RDD-based `LabeledPoint`, | ||
| * and the test itself is delegated to the RDD-based `Statistics.chiSqTest`. The accompanying | ||
| * test suite expects a `SparkException` when the label or a feature takes more distinct values | ||
| * than `ChiSqTest.maxCategories`. | ||
| * | ||
| * @param dataset DataFrame of categorical labels and categorical features. | ||
| * Real-valued features will be treated as categorical for each distinct value. | ||
| * @param featuresCol Name of features column in dataset, of type `Vector` (`VectorUDT`) | ||
| * @param labelCol Name of label column in dataset, of any numerical type | ||
| * @return DataFrame containing the test result for every feature against the label. | ||
| * This DataFrame will contain a single Row with the following fields: | ||
| * - `pValues: Vector` | ||
| * - `degreesOfFreedom: Array[Int]` | ||
| * - `statistics: Vector` | ||
| * Each of these fields has one value per feature. | ||
| */ | ||
| @Since("2.2.0") | ||
| def test(dataset: DataFrame, featuresCol: String, labelCol: String): DataFrame = { | ||
| val spark = dataset.sparkSession | ||
| import spark.implicits._ | ||
|
|
||
| SchemaUtils.checkColumnType(dataset.schema, featuresCol, new VectorUDT) | ||
| SchemaUtils.checkNumericType(dataset.schema, labelCol) | ||
| val rdd = dataset.select(col(labelCol).cast("double"), col(featuresCol)).as[(Double, Vector)] | ||
| .rdd.map { case (label, features) => OldLabeledPoint(label, OldVectors.fromML(features)) } | ||
| val testResults = OldStatistics.chiSqTest(rdd) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it would be nice to optimize this in the future -- since we have schema, if the label and features have been converted to categorical, we can get the unique values right away instead of having to re-generate the maps for distinct labels and features
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Definitely; feel free to make a JIRA for it. |
||
| val pValues: Vector = Vectors.dense(testResults.map(_.pValue)) | ||
| val degreesOfFreedom: Array[Int] = testResults.map(_.degreesOfFreedom) | ||
| val statistics: Vector = Vectors.dense(testResults.map(_.statistic)) | ||
| spark.createDataFrame(Seq(ChiSquareResult(pValues, degreesOfFreedom, statistics))) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.ml.stat | ||
|
|
||
| import java.util.Random | ||
|
|
||
| import org.apache.spark.{SparkException, SparkFunSuite} | ||
| import org.apache.spark.ml.feature.LabeledPoint | ||
| import org.apache.spark.ml.linalg.{Vector, Vectors} | ||
| import org.apache.spark.ml.util.DefaultReadWriteTest | ||
| import org.apache.spark.ml.util.TestingUtils._ | ||
| import org.apache.spark.mllib.stat.test.ChiSqTest | ||
| import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
|
|
||
| class ChiSquareSuite | ||
| extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { | ||
|
|
||
| import testImplicits._ | ||
|
|
||
| test("test DataFrame of labeled points") { | ||
| // Category frequencies in `data` below -- labels: 1.0 (2 / 6), 0.0 (4 / 6) | ||
| // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) | ||
| // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) | ||
| val data = Seq( | ||
| LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), | ||
| LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), | ||
| LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), | ||
| LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), | ||
| LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), | ||
| LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) | ||
| for (numParts <- List(2, 4, 6, 8)) { | ||
| val df = spark.createDataFrame(sc.parallelize(data, numParts)) | ||
| val chi = ChiSquare.test(df, "features", "label") | ||
| val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = | ||
| chi.select("pValues", "degreesOfFreedom", "statistics") | ||
| .as[(Vector, Array[Int], Vector)].head() | ||
| assert(pValues ~== Vectors.dense(0.6873, 0.6823) relTol 1e-4) | ||
| assert(degreesOfFreedom === Array(2, 3)) | ||
| assert(statistics ~== Vectors.dense(0.75, 1.5) relTol 1e-4) | ||
| } | ||
| } | ||
|
|
||
| test("large number of features (SPARK-3087)") { | ||
| // Test that one result is returned per feature column (numCols), not per explicit sparse entry | ||
| val numCols = 1001 | ||
| val sparseData = Array( | ||
| LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), | ||
| LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) | ||
| val df = spark.createDataFrame(sparseData) | ||
| val chi = ChiSquare.test(df, "features", "label") | ||
| val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = | ||
| chi.select("pValues", "degreesOfFreedom", "statistics") | ||
| .as[(Vector, Array[Int], Vector)].head() | ||
| assert(pValues.size === numCols) | ||
| assert(degreesOfFreedom.length === numCols) | ||
| assert(statistics.size === numCols) | ||
| assert(pValues(1000) !== null) // SPARK-3087; NOTE(review): vacuous if this is a primitive Double -- confirm | ||
| } | ||
|
|
||
| test("fail on continuous features or labels") { | ||
| val tooManyCategories: Int = 100000 | ||
| assert(tooManyCategories > ChiSqTest.maxCategories, "This unit test requires that " + | ||
| "tooManyCategories be large enough to cause ChiSqTest to throw an exception.") | ||
|
|
||
| val random = new Random(11L) | ||
| val continuousLabel = Seq.fill(tooManyCategories)( | ||
| LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) | ||
| withClue("ChiSquare should throw an exception when given a continuous-valued label") { | ||
| intercept[SparkException] { | ||
| val df = spark.createDataFrame(continuousLabel) | ||
| ChiSquare.test(df, "features", "label") | ||
| } | ||
| } | ||
| val continuousFeature = Seq.fill(tooManyCategories)( | ||
| LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) | ||
| withClue("ChiSquare should throw an exception when given continuous-valued features") { | ||
| intercept[SparkException] { | ||
| val df = spark.createDataFrame(continuousFeature) | ||
| ChiSquare.test(df, "features", "label") | ||
| } | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't the chi-square test work for binary types as well, or do we not want to support that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds reasonable, but let's do that in the future; this is already a lot more types than the RDD-based API supports.