Skip to content
Closed
98 changes: 97 additions & 1 deletion python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula',
'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel',
'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer',
'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel']
'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel',
'ChiSqSelector', 'ChiSqSelectorModel']


@inherit_doc
Expand Down Expand Up @@ -2237,6 +2238,101 @@ class RFormulaModel(JavaModel):
"""


@inherit_doc
class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol):
"""
.. note:: Experimental

Chi-Squared feature selection, which selects categorical features to use for predicting a
categorical label.

>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame(
... [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
... (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
... (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
... ["features", "label"])
>>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> model.transform(df).head().selectedFeatures
DenseVector([1.0])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this example, I'd just test for the first row using head()

>>> model.selectedFeatures
[3]

.. versionadded:: 2.0.0
"""

# Class-level placeholder so that the param appears in the generated doc;
# __init__ rebinds numTopFeatures on each instance with the instance as parent.
numTopFeatures = Param(
    Params._dummy(),
    "numTopFeatures",
    "Number of features that selector will select, ordered by statistics value "
    "descending. If the number of features is < numTopFeatures, then this will select "
    "all features.")

@keyword_only
def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"):
    """
    __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label")
    """
    super(ChiSqSelector, self).__init__()
    # Create the JVM-side estimator that this Python object wraps.
    self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
    # Rebind the param with this instance as its parent, replacing the class-level placeholder.
    num_top_features_doc = (
        "Number of features that selector will select, ordered by statistics value "
        "descending. If the number of features is < numTopFeatures, then this will "
        "select all features.")
    self.numTopFeatures = Param(self, "numTopFeatures", num_top_features_doc)
    # Forward the keyword arguments recorded on the decorated method to setParams.
    input_kwargs = self.__init__._input_kwargs
    self.setParams(**input_kwargs)

@keyword_only
@since("2.0.0")
def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
              labelCol="label"):
    """
    setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\
              labelCol="label")
    Sets params for this ChiSqSelector.
    """
    # The labelCol default is "label", matching __init__; it was previously
    # documented as "labels", which disagreed with the constructor.
    # Only the keyword arguments recorded in _input_kwargs are applied
    # (presumably those explicitly passed by the caller -- confirm against
    # the keyword_only decorator).
    kwargs = self.setParams._input_kwargs
    return self._set(**kwargs)

@since("2.0.0")
def setNumTopFeatures(self, value):
    """
    Sets the value of :py:attr:`numTopFeatures`.

    Returns this instance so that setter calls can be chained.
    """
    param = self.numTopFeatures
    self._paramMap[param] = value
    return self

@since("2.0.0")
def getNumTopFeatures(self):
    """
    Gets the value of :py:attr:`numTopFeatures` or its default value.
    """
    param = self.numTopFeatures
    return self.getOrDefault(param)

def _create_model(self, java_model):
    """
    Wraps the given Java model object in a :py:class:`ChiSqSelectorModel`.
    """
    model = ChiSqSelectorModel(java_model)
    return model


class ChiSqSelectorModel(JavaModel):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This model is loadable and saveable in Java. I don't see us doing this elsewhere in ml/ yet (although we do it in mllib/), but do we maybe want to use the JavaLoader & JavaSaveable base classes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Model persistence is important in PySpark, but there is no need to add it in this PR. @yanboliang has a JIRA for adding pipeline persistence in PySpark: https://issues.apache.org/jira/browse/SPARK-11939

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add the selectedFeatures method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

"""
.. note:: Experimental

Model fitted by ChiSqSelector.

.. versionadded:: 2.0.0
"""

@property
@since("2.0.0")
def selectedFeatures(self):
    """
    List of indices to select (filter), fetched from the underlying Java model.
    Must be ordered asc.
    """
    indices = self._call_java("selectedFeatures")
    return indices


if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
Expand Down