Skip to content

Commit bb209c8

Browse files
committed
address comments
1 parent 0bd9f66 commit bb209c8

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
8484

8585
require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
8686
(!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
87-
"Only allow to set either inputCol/outputCol, or inputCols/outputCols"
87+
"StringIndexer only supports setting either inputCol/outputCol or inputCols/outputCols."
8888
)
8989

9090
if (isSet(inputCol)) {

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,29 @@ class StringIndexerSuite
3737
val modelWithoutUid = new StringIndexerModel(Array("a", "b"))
3838
ParamsSuite.checkParams(model)
3939
ParamsSuite.checkParams(modelWithoutUid)
40+
41+
val stringIndexerSingleCol = new StringIndexer()
42+
.setInputCol("in").setOutputCol("out")
43+
val inOutCols1 = stringIndexerSingleCol.getInOutCols
44+
assert(inOutCols1._1 === Array("in"))
45+
assert(inOutCols1._2 === Array("out"))
46+
47+
val stringIndexerMultiCol = new StringIndexer()
48+
.setInputCols(Array("in1", "in2")).setOutputCols(Array("out1", "out2"))
49+
val inOutCols2 = stringIndexerMultiCol.getInOutCols
50+
assert(inOutCols2._1 === Array("in1", "in2"))
51+
assert(inOutCols2._2 === Array("out1", "out2"))
52+
53+
intercept[IllegalArgumentException] {
54+
new StringIndexer().setInputCol("in").setOutputCols(Array("out1", "out2")).getInOutCols
55+
}
56+
intercept[IllegalArgumentException] {
57+
new StringIndexer().setInputCols(Array("in1", "in2")).setOutputCol("out1").getInOutCols
58+
}
59+
intercept[IllegalArgumentException] {
60+
new StringIndexer().setInputCols(Array("in1", "in2"))
61+
.setOutputCols(Array("out1", "out2", "out3")).getInOutCols
62+
}
4063
}
4164

4265
test("StringIndexer") {

0 commit comments

Comments
 (0)