Skip to content

Commit 414e249

Browse files
committed
And switch to using handleInvalid instead of skipInvalid
1 parent 1e53f9b commit 414e249

File tree

4 files changed

+15
-14
lines changed

4 files changed

+15
-14
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import org.apache.spark.util.collection.OpenHashMap
3333
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
3434
*/
3535
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
36-
with HasSkipInvalid {
36+
with HasHandleInvalid {
3737

3838
/** Validates and transforms the input schema. */
3939
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -66,8 +66,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
6666
def this() = this(Identifiable.randomUID("strIdx"))
6767

6868
/** @group setParam */
69-
def setSkipInvalid(value: Boolean): this.type = set(skipInvalid, value)
70-
setDefault(skipInvalid, false)
69+
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
70+
setDefault(handleInvalid, "error")
7171

7272
/** @group setParam */
7373
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -115,8 +115,8 @@ class StringIndexerModel private[ml] (
115115
}
116116

117117
/** @group setParam */
118-
def setSkipInvalid(value: Boolean): this.type = set(skipInvalid, value)
119-
setDefault(skipInvalid, false)
118+
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
119+
setDefault(handleInvalid, "error")
120120

121121
/** @group setParam */
122122
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -143,13 +143,14 @@ class StringIndexerModel private[ml] (
143143
val metadata = NominalAttribute.defaultAttr
144144
.withName(outputColName).withValues(labels).toMetadata()
145145
// If we are skipping invalid records, filter them out.
146-
val filteredDataset = if (getSkipInvalid) {
147-
val filterer = udf { label: String =>
148-
labelToIndex.contains(label)
146+
val filteredDataset = (getHandleInvalid) match {
147+
case "skip" => {
148+
val filterer = udf { label: String =>
149+
labelToIndex.contains(label)
150+
}
151+
dataset.where(filterer(dataset($(inputCol))))
149152
}
150-
dataset.where(filterer(dataset($(inputCol))))
151-
} else {
152-
dataset
153+
case _ => dataset
153154
}
154155
filteredDataset.select(col("*"),
155156
indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ private[shared] object SharedParamsCodeGen {
5454
isValid = "ParamValidators.gtEq(1)"),
5555
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
5656
ParamDesc[String]("handleInvalid", "how to handle invalid entries",
57-
isValid = "ParamValidators.inArray(List(\"skip\", \"error\"))"),
57+
isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
5858
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
5959
" before fitting the model.", Some("true")),
6060
ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ private[ml] trait HasHandleInvalid extends Params {
241241
* Param for how to handle invalid entries.
242242
* @group param
243243
*/
244-
final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries", ParamValidators.inArray(List("skip", "error")))
244+
final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries", ParamValidators.inArray(Array("skip", "error")))
245245

246246
/** @group getParam */
247247
final def getHandleInvalid: String = $(handleInvalid)

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
6666
val indexerSkipInvalid = new StringIndexer()
6767
.setInputCol("label")
6868
.setOutputCol("labelIndex")
69-
.setSkipInvalid(true)
69+
.setHandleInvalid("skip")
7070
.fit(df)
7171
// Verify that we skip the c record
7272
val transformed = indexerSkipInvalid.transform(df2)

0 commit comments

Comments
 (0)