Skip to content

Commit 16347f4

Browse files
committed
Add more test cases.
1 parent 8d1536a commit 16347f4

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala

Lines changed: 5 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -150,23 +150,17 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
150150
OldLabeledPoint(label, OldVectors.fromML(features))
151151
}
152152
val selector = new feature.ChiSqSelector()
153-
$(selectorType) match {
154-
case OldChiSqSelector.KBest =>
155-
selector.setNumTopFeatures($(numTopFeatures))
156-
case OldChiSqSelector.Percentile =>
157-
selector.setPercentile($(percentile))
158-
case OldChiSqSelector.FPR =>
159-
selector.setAlpha($(alpha))
160-
case errorType =>
161-
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
162-
}
153+
.setSelectorType($(selectorType))
154+
.setNumTopFeatures($(numTopFeatures))
155+
.setPercentile($(percentile))
156+
.setAlpha($(alpha))
163157
val model = selector.fit(input)
164158
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
165159
}
166160

167161
@Since("1.6.0")
168162
override def transformSchema(schema: StructType): StructType = {
169-
val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 == $(selectorType))
163+
val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType))
170164
otherPairs.foreach { case (_, paramName: String) =>
171165
if (isSet(getParam(paramName))) {
172166
logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")

mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
5050
.toDF("label", "data", "preFilteredData")
5151

5252
val selector = new ChiSqSelector()
53+
.setSelectorType("kbest")
5354
.setNumTopFeatures(1)
5455
.setFeaturesCol("data")
5556
.setLabelCol("label")
@@ -60,12 +61,28 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
6061
assert(vec1 ~== vec2 absTol 1e-1)
6162
}
6263

63-
selector.setPercentile(0.34).fit(df).transform(df)
64-
.select("filtered", "preFilteredData").collect().foreach {
65-
case Row(vec1: Vector, vec2: Vector) =>
66-
assert(vec1 ~== vec2 absTol 1e-1)
67-
}
64+
selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df)
65+
.select("filtered", "preFilteredData").collect().foreach {
66+
case Row(vec1: Vector, vec2: Vector) =>
67+
assert(vec1 ~== vec2 absTol 1e-1)
68+
}
69+
70+
val preFilteredData2 = Seq(
71+
Vectors.dense(8.0, 7.0),
72+
Vectors.dense(0.0, 9.0),
73+
Vectors.dense(0.0, 9.0),
74+
Vectors.dense(8.0, 9.0)
75+
)
6876

77+
val df2 = sc.parallelize(data.zip(preFilteredData2))
78+
.map(x => (x._1.label, x._1.features, x._2))
79+
.toDF("label", "data", "preFilteredData")
80+
81+
selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2)
82+
.select("filtered", "preFilteredData").collect().foreach {
83+
case Row(vec1: Vector, vec2: Vector) =>
84+
assert(vec1 ~== vec2 absTol 1e-1)
85+
}
6986
}
7087

7188
test("ChiSqSelector read/write") {

0 commit comments

Comments
 (0)