diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 92f8549e9cb9e..b52d5bb773580 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1917,8 +1917,7 @@ def mean(self): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, - JavaMLWritable): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. @@ -1936,6 +1935,14 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] + >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"), + ... Row(id=2, label=None)], 2) + >>> dfKeep= spark.createDataFrame(testData2) + >>> modelKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf) + >>> tdK = modelKeep.transform(dfKeep) + >>> sorted(set([(i[0], i[1]) for i in tdK.select(tdK.id, tdK.indexed).collect()]), + ... key=lambda x: x[0]) + [(0, 0.0), (1, 3.0), (2, 3.0)] >>> stringIndexerPath = temp_path + "/string-indexer" >>> stringIndexer.save(stringIndexerPath) >>> loadedIndexer = StringIndexer.load(stringIndexerPath) @@ -1955,6 +1962,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, .. versionadded:: 1.4.0 """ + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "labels or NULL values). 
Options are 'skip' (filter out rows with " + + "invalid data), 'error' (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): """ @@ -1979,6 +1992,20 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): def _create_model(self, java_model): return StringIndexerModel(java_model) + @since("2.2.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + + @since("2.2.0") + def getHandleInvalid(self): + """ + Gets the value of :py:attr:`handleInvalid` or its default value. + """ + return self.getOrDefault(self.handleInvalid) + class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """