From d94dc68a8c1a5c082cf3de6c7e4d429bfd24d817 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Mar 2017 14:50:41 +0800 Subject: [PATCH 1/5] [SPARK-19852][PYSPARK][ML] Update Python API for StringIndexer setHandleInvalid This PR reflect the changes made in SPARK-17498 on pyspark to support a new option 'keep' in StringIndexer to handle unseen labels Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 92f8549e9cb9e..3ab9ab29c02a7 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1917,8 +1917,7 @@ def mean(self): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, - JavaMLWritable): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. @@ -1936,6 +1935,14 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] + >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"), + ... Row(id=2, label="e")], 2) + >>> dfKeep= spark.createDataFrame(testData2) + >>> tdKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep) + >>> itdKeep = inverter.transform(tdKeep) + >>> sorted(set([(i[0], str(i[1])) for i in itdKeep.select(itdKeep.id, itdKeep.label2).collect()]), + ... key=lambda x: x[0]) + [(0, 'a'), (6, 'd'), (6, 'e')] >>> stringIndexerPath = temp_path + "/string-indexer" >>> stringIndexer.save(stringIndexerPath) >>> loadedIndexer = StringIndexer.load(stringIndexerPath) @@ -1955,6 +1962,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, .. versionadded:: 1.4.0 """ + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle unseen labels. " + + "Options are 'skip' (filter out rows with unseen labels), " + + "error (throw an error), or 'keep' (put unseen labels in a special " + + "additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): """ @@ -1979,6 +1991,20 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): def _create_model(self, java_model): return StringIndexerModel(java_model) + @since("2.2.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + + @since("2.2.0") + def getHandleInvalid(self): + """ + Gets the value of :py:attr:`handleInvalid` or its default value. + """ + return self.getOrDefault(self.handleInvalid) + class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ From f1d9bcb3d615444a3c326f0b6b0f7999edecdf4f Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Mar 2017 15:21:47 +0800 Subject: [PATCH 2/5] fix compilation issues Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3ab9ab29c02a7..8b72cd6dece4c 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1938,9 +1938,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"), ... Row(id=2, label="e")], 2) >>> dfKeep= spark.createDataFrame(testData2) - >>> tdKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep) - >>> itdKeep = inverter.transform(tdKeep) - >>> sorted(set([(i[0], str(i[1])) for i in itdKeep.select(itdKeep.id, itdKeep.label2).collect()]), + >>> tdK = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep) + >>> itdK = inverter.transform(tdK) + >>> sorted(set([(i[0], str(i[1])) for i in itdK.select(itdK.id, itdK.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (6, 'd'), (6, 'e')] >>> stringIndexerPath = temp_path + "/string-indexer" @@ -1967,6 +1967,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja "error (throw an error), or 'keep' (put unseen labels in a special " + "additional bucket, at index numLabels).", typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): """ From affeeb770faf8fc3b5dc924e8e23704dee3f21ba Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Mar 2017 16:20:14 +0800 Subject: [PATCH 3/5] doctest Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 8b72cd6dece4c..dbebcb184317c 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1938,8 +1938,10 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"), ... Row(id=2, label="e")], 2) >>> dfKeep= spark.createDataFrame(testData2) - >>> tdK = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep) - >>> itdK = inverter.transform(tdK) + >>> modelKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf) + >>> tdK = modelKeep.transform(dfKeep) + >>> itdK = IndexToString(inputCol="indexed", outputCol="label2", + ... labels=modelKeep.labels).transform(tdK) >>> sorted(set([(i[0], str(i[1])) for i in itdK.select(itdK.id, itdK.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (6, 'd'), (6, 'e')] From b4bb765a672d53d7fdd8b0378c7a5762ec881078 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Mar 2017 16:50:09 +0800 Subject: [PATCH 4/5] update doctest Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index dbebcb184317c..f087af4cb8e35 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1940,11 +1940,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja >>> dfKeep= spark.createDataFrame(testData2) >>> modelKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf) >>> tdK = modelKeep.transform(dfKeep) - >>> itdK = IndexToString(inputCol="indexed", outputCol="label2", - ... labels=modelKeep.labels).transform(tdK) - >>> sorted(set([(i[0], str(i[1])) for i in itdK.select(itdK.id, itdK.label2).collect()]), + >>> sorted(set([(i[0], i[1]) for i in tdK.select(tdK.id, tdK.indexed).collect()]), ... key=lambda x: x[0]) - [(0, 'a'), (6, 'd'), (6, 'e')] + [(0, 0.0), (1, 3.0), (2, 3.0)] >>> stringIndexerPath = temp_path + "/string-indexer" >>> stringIndexer.save(stringIndexerPath) >>> loadedIndexer = StringIndexer.load(stringIndexerPath) From 1d2f28f2449691fe7efe3e7dfbfe577ad8a56c1f Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 17 Mar 2017 10:55:15 +0800 Subject: [PATCH 5/5] include changes made by SPARK-11569 Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f087af4cb8e35..b52d5bb773580 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1936,7 +1936,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja ... key=lambda x: x[0]) [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"), - ... Row(id=2, label="e")], 2) + ... Row(id=2, label=None)], 2) >>> dfKeep= spark.createDataFrame(testData2) >>> modelKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf) >>> tdK = modelKeep.transform(dfKeep) @@ -1962,10 +1962,10 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja .. versionadded:: 1.4.0 """ - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle unseen labels. " + - "Options are 'skip' (filter out rows with unseen labels), " + - "error (throw an error), or 'keep' (put unseen labels in a special " + - "additional bucket, at index numLabels).", + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "labels or NULL values). Options are 'skip' (filter out rows with " + + "invalid data), error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", typeConverter=TypeConverters.toString) @keyword_only