@@ -2490,7 +2490,8 @@ def setParams(self, inputCols=None, outputCol=None):
24902490
24912491
24922492@inherit_doc
2493- class VectorIndexer (JavaEstimator , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
2493+ class VectorIndexer (JavaEstimator , HasInputCol , HasOutputCol , HasHandleInvalid , JavaMLReadable ,
2494+ JavaMLWritable ):
24942495 """
24952496 Class for indexing categorical feature columns in a dataset of `Vector`.
24962497
@@ -2525,7 +2526,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
25252526 do not recompute.
25262527 - Specify certain features to not index, either via a parameter or via existing metadata.
25272528 - Add warning if a categorical feature has only 1 category.
2528- - Add option for allowing unknown categories.
25292529
25302530 >>> from pyspark.ml.linalg import Vectors
25312531 >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
@@ -2556,6 +2556,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
25562556 True
25572557 >>> loadedModel.categoryMaps == model.categoryMaps
25582558 True
2559+ >>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])
2560+ >>> indexer.getHandleInvalid()
2561+ 'error'
2562+ >>> model3 = indexer.setHandleInvalid("skip").fit(df)
2563+ >>> model3.transform(dfWithInvalid).count()
2564+ 0
2565+ >>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
2566+ >>> model4.transform(dfWithInvalid).head().indexed
2567+ DenseVector([2.0, 1.0])
25592568
25602569 .. versionadded:: 1.4.0
25612570 """
@@ -2565,22 +2574,29 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
25652574 "(>= 2). If a feature is found to have > maxCategories values, then " +
25662575 "it is declared continuous." , typeConverter = TypeConverters .toInt )
25672576
    # Param declared directly on the class (pyspark convention: Params._dummy() as
    # parent; the real parent is bound per-instance by the Params machinery).
    # Mirrors the Scala VectorIndexer.handleInvalid options: 'skip' drops rows with
    # unseen/NULL categorical values, 'error' raises, and 'keep' maps them to an
    # extra bucket at index == numCategories for that feature.
    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " +
                          "(unseen labels or NULL values). Options are 'skip' (filter out " +
                          "rows with invalid data), 'error' (throw an error), or 'keep' (put " +
                          "invalid data in a special additional bucket, at index of the number " +
                          "of categories of the feature).",
                          typeConverter=TypeConverters.toString)
    @keyword_only
    def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
        """
        __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
        """
        # NOTE: the one-line docstring above is pyspark's signature-doc convention
        # (shown verbatim in the generated API docs); keep it in sync with setParams.
        super(VectorIndexer, self).__init__()
        # Create the JVM-side estimator this Python wrapper delegates to.
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
        # Defaults must match the keyword defaults in the signature; handleInvalid
        # defaults to 'error' (fail on unseen/NULL categorical values).
        self._setDefault(maxCategories=20, handleInvalid="error")
        # @keyword_only stashes the caller's explicit kwargs in _input_kwargs;
        # forward only those to setParams so unset params keep their defaults.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)
25782594
    @keyword_only
    @since("1.4.0")
    def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
        """
        setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
        Sets params for this VectorIndexer.
        """
        # @keyword_only captures only the kwargs the caller actually passed, so
        # params not mentioned in this call are left untouched rather than reset
        # to the signature defaults.
        kwargs = self._input_kwargs
0 commit comments