From a11558ea329b174531e4f3c3e4d95f875fbc5f5d Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Thu, 26 Nov 2015 22:23:16 +0800
Subject: [PATCH 1/4] add QuantileDiscretizer in Python

---
 python/pyspark/ml/feature.py | 88 +++++++++++++++++++++++++++++++++++-
 1 file changed, 87 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index b02d41b52ab2..bb3d1c57afaf 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -33,7 +33,7 @@
     'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer',
     'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer',
     'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer',
-    'Word2Vec', 'Word2VecModel']
+    'Word2Vec', 'Word2VecModel', 'QuantileDiscretizer']
 
 
 @inherit_doc
@@ -2093,6 +2093,92 @@ class RFormulaModel(JavaModel):
     """
 
 
+@inherit_doc
+class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol):
+    """
+    .. note:: Experimental
+
+    `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
+    categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
+    into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
+    covering all real values. This attempts to find numBuckets partitions based on a sample of data,
+    but it may find fewer depending on the data sample values.
+
+    >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
+    >>> discretizer = QuantileDiscretizer(inputCol="values", outputCol="buckets").setNumBuckets(3)
+    >>> bucketed = discretizer.fit(df).transform(df).collect()
+    >>> bucketed[0].buckets
+    0.0
+    >>> bucketed[1].buckets
+    1.0
+    >>> bucketed[2].buckets
+    1.0
+    >>> bucketed[3].buckets
+    2.0
+
+    .. versionadded:: 1.7.0
+    """
+
+    # a placeholder to make it appear in the generated doc
+    numBuckets = \
+        Param(Params._dummy(), "numBuckets",
+              "Maximum number of buckets (quantiles, or categories) into which data points are " +
+              "grouped. Must be >= 2.")
+
+    @keyword_only
+    def __init__(self, numBuckets=None, inputCol=None, outputCol=None):
+        """
+        __init__(self, numBuckets=None, inputCol=None, outputCol=None)
+        """
+        super(QuantileDiscretizer, self).__init__()
+        self._java_obj = \
+            self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer", self.uid)
+        # Maximum number of buckets (quantiles, or categories) into which data points are grouped.
+        # Must be >= 2.
+        # default: 2
+        self.numBuckets = \
+            Param(self, "numBuckets",
+                  "Maximum number of buckets (quantiles, or categories) into which data points " +
+                  "are grouped. Must be >= 2.")
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    class QuantileDiscretizerModel(JavaModel):
+        def getSplits(self):
+            return self._call_java("getSplits")
+
+    @keyword_only
+    @since("1.7.0")
+    def setParams(self, numBuckets=None, inputCol=None, outputCol=None):
+        """
+        setParams(self, numBuckets=None, inputCol=None, outputCol=None)
+        Sets params for this QuantileDiscretizer.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    @since("1.7.0")
+    def setNumBuckets(self, value):
+        """
+        Sets the value of :py:attr:`numBuckets`.
+        """
+        self._paramMap[self.numBuckets] = value
+        return self
+
+    @since("1.7.0")
+    def getNumBuckets(self):
+        """
+        Gets the value of numBuckets or its default value.
+        """
+        return self.getOrDefault(self.numBuckets)
+
+    def _create_model(self, java_model):
+        model = self.QuantileDiscretizerModel(java_model)
+        return Bucketizer(splits=model.getSplits(),
+                          inputCol=self.getOrDefault("inputCol"),
+                          outputCol=self.getOrDefault("outputCol"))
+
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext

From 670821bd0adbb58483e453d1b220179659d254f3 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Thu, 26 Nov 2015 23:49:43 +0800
Subject: [PATCH 2/4] add ChiSqSelector in Python

---
 python/pyspark/ml/feature.py | 89 ++++++++++++++++++++++++++++++++++++
 1 file changed, 89 insertions(+)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index bb3d1c57afaf..8afd1c64e221 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2179,6 +2179,95 @@ def _create_model(self, java_model):
                           outputCol=self.getOrDefault("outputCol"))
 
 
+@inherit_doc
+class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol):
+    """
+    .. note:: Experimental
+
+    Chi-Squared feature selection, which selects categorical features to use for predicting a
+    categorical label.
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> df = sqlContext.createDataFrame(
+    ...    [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
+    ...     (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
+    ...     (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
+    ...    ["features", "label"])
+    >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
+    >>> model = selector.fit(df)
+    >>> model.transform(df).collect()[0].selectedFeatures
+    DenseVector([1.0])
+    >>> model.transform(df).collect()[1].selectedFeatures
+    DenseVector([0.0])
+    >>> model.transform(df).collect()[2].selectedFeatures
+    DenseVector([0.1])
+
+    .. versionadded:: 1.7.0
+    """
+
+    # a placeholder to make it appear in the generated doc
+    numTopFeatures = \
+        Param(Params._dummy(), "numTopFeatures",
+              "Number of features that selector will select, ordered by statistics value " +
+              "descending. If the number of features is < numTopFeatures, then this will select " +
+              "all features.")
+
+    @keyword_only
+    def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"):
+        """
+        __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label")
+        """
+        super(ChiSqSelector, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
+        self.numTopFeatures = \
+            Param(self, "numTopFeatures",
+                  "Number of features that selector will select, ordered by statistics value " +
+                  "descending. If the number of features is < numTopFeatures, then this will " +
+                  "select all features.")
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("1.7.0")
+    def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
+                  labelCol="label"):
+        """
+        setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
+                  labelCol="label")
+        Sets params for this ChiSqSelector.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    @since("1.7.0")
+    def setNumTopFeatures(self, value):
+        """
+        Sets the value of :py:attr:`numTopFeatures`.
+        """
+        self._paramMap[self.numTopFeatures] = value
+        return self
+
+    @since("1.7.0")
+    def getNumTopFeatures(self):
+        """
+        Gets the value of numTopFeatures or its default value.
+        """
+        return self.getOrDefault(self.numTopFeatures)
+
+    def _create_model(self, java_model):
+        return ChiSqSelectorModel(java_model)
+
+
+class ChiSqSelectorModel(JavaModel):
+    """
+    .. note:: Experimental
+
+    Model fitted by ChiSqSelector.
+
+    .. versionadded:: 1.7.0
+    """
+
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext

From 05f3eddcf38d9211f65f2af2227cb9b249955208 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Thu, 26 Nov 2015 23:54:02 +0800
Subject: [PATCH 3/4] add class exports

---
 python/pyspark/ml/feature.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 8afd1c64e221..1905db20c522 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -33,7 +33,8 @@
     'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer',
     'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer',
     'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer',
-    'Word2Vec', 'Word2VecModel', 'QuantileDiscretizer']
+    'Word2Vec', 'Word2VecModel', 'QuantileDiscretizer', 'ChiSqSelector',
+    'ChiSqSelectorModel']
 
 
 @inherit_doc

From 3a33327122ae94d59403d807255273180528d9a9 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Thu, 3 Dec 2015 18:33:45 +0800
Subject: [PATCH 4/4] add Java-compatible splits getter

---
 .../scala/org/apache/spark/ml/feature/Bucketizer.scala | 9 ++++++++-
 python/pyspark/ml/feature.py                           | 4 +++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 324353a96afb..3fad31629908 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -19,8 +19,10 @@ package org.apache.spark.ml.feature
 
 import java.{util => ju}
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.SparkException
-import org.apache.spark.annotation.{Since, Experimental}
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
@@ -56,6 +58,11 @@ final class Bucketizer(override val uid: String)
     "otherwise, values outside the splits specified will be treated as errors.",
     Bucketizer.checkSplits)
 
+  /**
+   * Method for calling from Python code (PySpark).
+   */
+  def getJavaSplits: java.util.List[Double] = $(splits).toSeq.asJava
+
   /** @group getParam */
   def getSplits: Array[Double] = $(splits)
 
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 1905db20c522..30c51e6b07bf 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2144,9 +2144,11 @@ def __init__(self, numBuckets=None, inputCol=None, outputCol=None):
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
+    # The inner class is used as an extractor that extracts splits from the JavaModel generated by
+    # QuantileDiscretizer, then constructs a Bucketizer with the extracted splits.
     class QuantileDiscretizerModel(JavaModel):
         def getSplits(self):
-            return self._call_java("getSplits")
+            return self._call_java("getJavaSplits")
 
     @keyword_only
     @since("1.7.0")
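
Example usage of the two estimators added by this series (a minimal sketch, assuming a
live SQLContext named `sqlContext` as in the doctests above; the input values are taken
from those doctests):

    from pyspark.ml.feature import ChiSqSelector, QuantileDiscretizer
    from pyspark.mllib.linalg import Vectors

    # QuantileDiscretizer: bin a continuous column into at most 3 buckets.
    # fit() returns a Bucketizer built from the learned splits (see _create_model above).
    df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
    discretizer = QuantileDiscretizer(numBuckets=3, inputCol="values", outputCol="buckets")
    bucketizer = discretizer.fit(df)      # a Bucketizer transformer
    bucketizer.transform(df).show()       # appends the "buckets" column

    # ChiSqSelector: keep the numTopFeatures features most predictive of the label.
    data = sqlContext.createDataFrame(
        [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
         (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
         (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
        ["features", "label"])
    selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
    selector.fit(data).transform(data).show()   # appends "selectedFeatures"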