@@ -33,7 +33,7 @@ import org.apache.spark.util.collection.OpenHashMap
3333 * Base trait for [[StringIndexer ]] and [[StringIndexerModel ]].
3434 */
3535private [feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
36- with HasSkipInvalid {
36+ with HasHandleInvalid {
3737
3838 /** Validates and transforms the input schema. */
3939 protected def validateAndTransformSchema (schema : StructType ): StructType = {
@@ -66,8 +66,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
6666 def this () = this (Identifiable .randomUID(" strIdx" ))
6767
6868 /** @group setParam */
69- def setSkipInvalid (value : Boolean ): this .type = set(skipInvalid , value)
70- setDefault(skipInvalid, false )
69+ def setHandleInvalid (value : String ): this .type = set(handleInvalid , value)
70+ setDefault(handleInvalid, " error " )
7171
7272 /** @group setParam */
7373 def setInputCol (value : String ): this .type = set(inputCol, value)
@@ -115,8 +115,8 @@ class StringIndexerModel private[ml] (
115115 }
116116
117117 /** @group setParam */
118- def setSkipInvalid (value : Boolean ): this .type = set(skipInvalid , value)
119- setDefault(skipInvalid, false )
118+ def setHandleInvalid (value : String ): this .type = set(handleInvalid , value)
119+ setDefault(handleInvalid, " error " )
120120
121121 /** @group setParam */
122122 def setInputCol (value : String ): this .type = set(inputCol, value)
@@ -143,13 +143,14 @@ class StringIndexerModel private[ml] (
143143 val metadata = NominalAttribute .defaultAttr
144144 .withName(outputColName).withValues(labels).toMetadata()
145145 // If we are skipping invalid records, filter them out.
146- val filteredDataset = if (getSkipInvalid) {
147- val filterer = udf { label : String =>
148- labelToIndex.contains(label)
146+ val filteredDataset = (getHandleInvalid) match {
147+ case " skip" => {
148+ val filterer = udf { label : String =>
149+ labelToIndex.contains(label)
150+ }
151+ dataset.where(filterer(dataset($(inputCol))))
149152 }
150- dataset.where(filterer(dataset($(inputCol))))
151- } else {
152- dataset
153+ case _ => dataset
153154 }
154155 filteredDataset.select(col(" *" ),
155156 indexer(dataset($(inputCol)).cast(StringType )).as(outputColName, metadata))
0 commit comments