[SPARK-11923][ML] Python API for ml.feature.ChiSqSelector #10186
Changes from all commits
python/pyspark/ml/feature.py

@@ -33,7 +33,8 @@
            'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula',
            'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel',
            'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer',
-           'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel']
+           'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel',
+           'ChiSqSelector', 'ChiSqSelectorModel']


 @inherit_doc
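As an aside (not part of the diff): __all__ controls what a wildcard import of pyspark.ml.feature exposes, so this change makes the two new classes visible to "from pyspark.ml.feature import *" as well as to explicit imports. A minimal check, assuming an active SparkContext as in the doctests below:

>>> from pyspark.ml.feature import ChiSqSelector, ChiSqSelectorModel
>>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
>>> selector.getNumTopFeatures()
1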
@@ -2237,6 +2238,101 @@ class RFormulaModel(JavaModel):
     """


+@inherit_doc
+class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol):
+    """
+    .. note:: Experimental
+
+    Chi-Squared feature selection, which selects categorical features to use for predicting a
+    categorical label.
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> df = sqlContext.createDataFrame(
+    ...    [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
+    ...     (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
+    ...     (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
+    ...    ["features", "label"])
+    >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
+    >>> model = selector.fit(df)
+    >>> model.transform(df).head().selectedFeatures
+    DenseVector([1.0])
+    >>> model.selectedFeatures
+    [3]
+
+    .. versionadded:: 2.0.0
+    """
+
+    # a placeholder to make it appear in the generated doc
+    numTopFeatures = \
+        Param(Params._dummy(), "numTopFeatures",
+              "Number of features that selector will select, ordered by statistics value " +
+              "descending. If the number of features is < numTopFeatures, then this will select " +
+              "all features.")
+
+    @keyword_only
+    def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"):
+        """
+        __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label")
+        """
+        super(ChiSqSelector, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
+        self.numTopFeatures = \
+            Param(self, "numTopFeatures",
+                  "Number of features that selector will select, ordered by statistics value " +
+                  "descending. If the number of features is < numTopFeatures, then this will " +
+                  "select all features.")
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("2.0.0")
+    def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
+                  labelCol="label"):
+        """
+        setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\
+                  labelCol="label")
+        Sets params for this ChiSqSelector.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    @since("2.0.0")
+    def setNumTopFeatures(self, value):
+        """
+        Sets the value of :py:attr:`numTopFeatures`.
+        """
+        self._paramMap[self.numTopFeatures] = value
+        return self
+
+    @since("2.0.0")
+    def getNumTopFeatures(self):
+        """
+        Gets the value of numTopFeatures or its default value.
+        """
+        return self.getOrDefault(self.numTopFeatures)
+
+    def _create_model(self, java_model):
+        return ChiSqSelectorModel(java_model)
+
+
+class ChiSqSelectorModel(JavaModel):
Contributor
This model is loadable and saveable in Java. I don't see us doing this elsewhere in ml/ yet (although we do it in mllib/), but do we maybe want to use the JavaLoader & JavaSaveable base classes?

Contributor (Author)
Model persistence is important in PySpark, but there is no need to add it in this PR. @yanboliang has a JIRA for adding pipeline persistence in PySpark: https://issues.apache.org/jira/browse/SPARK-11939

Contributor
Cool :)

Member
Could you please add the

Contributor (Author)
Sure
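For readers who haven't seen the mllib-style persistence mentioned above: in pyspark.mllib, models gain save/load by mixing in JavaSaveable and JavaLoader from pyspark.mllib.util, which delegate to the wrapped Java model. A rough sketch of that pattern follows; the wrapper class name is hypothetical, and this is not something the PR adds, since pipeline persistence was deferred to SPARK-11939.

# Hypothetical illustration of the mllib-style mixins discussed above; not part of this PR.
from pyspark.mllib.util import JavaSaveable, JavaLoader


class SomeJavaBackedModel(JavaSaveable, JavaLoader):
    """Invented wrapper holding a reference to a Java model object."""

    def __init__(self, java_model):
        self._java_model = java_model  # JavaSaveable.save() delegates to this object

# Usage sketch, assuming sc is a SparkContext and the matching Java class supports save/load:
#   model.save(sc, "/tmp/some-model")                        # calls the Java model's save()
#   restored = SomeJavaBackedModel.load(sc, "/tmp/some-model")
#
# JavaLoader.load() derives the Java class name from the Python module and class name,
# which is why the pyspark.mllib wrappers mirror their Java counterparts.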
+    """
+    .. note:: Experimental
+
+    Model fitted by ChiSqSelector.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @property
+    @since("2.0.0")
+    def selectedFeatures(self):
+        """
+        List of indices to select (filter). Must be ordered asc.
+        """
+        return self._call_java("selectedFeatures")
+
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext
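As an illustration of the "ordered by statistics value descending" wording in the numTopFeatures doc (not part of the diff; assumes a running SparkContext sc and uses the same data as the doctest), the per-feature chi-squared statistics can be inspected directly with pyspark.mllib.stat:

>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.stat import Statistics
>>> points = sc.parallelize([
...     LabeledPoint(1.0, Vectors.dense([0.0, 0.0, 18.0, 1.0])),
...     LabeledPoint(0.0, Vectors.dense([0.0, 1.0, 12.0, 0.0])),
...     LabeledPoint(0.0, Vectors.dense([1.0, 0.0, 15.0, 0.1]))])
>>> results = Statistics.chiSqTest(points)   # one ChiSqTestResult per feature
>>> stats = [r.statistic for r in results]   # the selector keeps the numTopFeatures largest

According to the doctest above, the selector picks feature 3 for this data, i.e. the feature ranked first by this ordering.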
For the ChiSqSelector doctest, I'd just test for the first row using head().
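To make that concrete: checking only the first row, as the doctest already does, keeps the expected output to a single short line,

>>> model.transform(df).head().selectedFeatures
DenseVector([1.0])

whereas a collect()-based check such as [r.selectedFeatures for r in model.transform(df).collect()] would tie the expected output to every row and to row ordering, making the doctest longer and more brittle.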