Skip to content

Commit f7082ac

Browse files
committed
[SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement.
## What changes were proposed in this pull request? Several performance improvement for ```ChiSqSelector```: 1, Keep ```selectedFeatures``` ordered ascendent. ```ChiSqSelectorModel.transform``` need ```selectedFeatures``` ordered to make prediction. We should sort it when training model rather than making prediction, since users usually train model once and use the model to do prediction multiple times. 2, When training ```fpr``` type ```ChiSqSelectorModel```, it's not necessary to sort the ChiSq test result by statistic. ## How was this patch tested? Existing unit tests. Author: Yanbo Liang <[email protected]> Closes #15277 from yanboliang/spark-17704.
1 parent a19a1bb commit f7082ac

File tree

2 files changed

+30
-18
lines changed

2 files changed

+30
-18
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession}
3535
/**
3636
* Chi Squared selector model.
3737
*
38-
* @param selectedFeatures list of indices to select (filter).
38+
* @param selectedFeatures list of indices to select (filter). Must be ordered asc
3939
*/
4040
@Since("1.3.0")
4141
class ChiSqSelectorModel @Since("1.3.0") (
4242
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
4343

44+
require(isSorted(selectedFeatures), "Array has to be sorted asc")
45+
46+
protected def isSorted(array: Array[Int]): Boolean = {
47+
var i = 1
48+
val len = array.length
49+
while (i < len) {
50+
if (array(i) < array(i-1)) return false
51+
i += 1
52+
}
53+
true
54+
}
55+
4456
/**
4557
* Applies transformation on a vector.
4658
*
@@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") (
5769
* Preserves the order of filtered features the same as their indices are stored.
5870
* Might be moved to Vector as .slice
5971
* @param features vector
60-
* @param filterIndices indices of features to filter
72+
* @param filterIndices indices of features to filter, must be ordered asc
6173
*/
6274
private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
63-
val orderedIndices = filterIndices.sorted
6475
features match {
6576
case SparseVector(size, indices, values) =>
66-
val newSize = orderedIndices.length
77+
val newSize = filterIndices.length
6778
val newValues = new ArrayBuilder.ofDouble
6879
val newIndices = new ArrayBuilder.ofInt
6980
var i = 0
7081
var j = 0
7182
var indicesIdx = 0
7283
var filterIndicesIdx = 0
73-
while (i < indices.length && j < orderedIndices.length) {
84+
while (i < indices.length && j < filterIndices.length) {
7485
indicesIdx = indices(i)
75-
filterIndicesIdx = orderedIndices(j)
86+
filterIndicesIdx = filterIndices(j)
7687
if (indicesIdx == filterIndicesIdx) {
7788
newIndices += j
7889
newValues += values(i)
@@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
90101
Vectors.sparse(newSize, newIndices.result(), newValues.result())
91102
case DenseVector(values) =>
92103
val values = features.toArray
93-
Vectors.dense(orderedIndices.map(i => values(i)))
104+
Vectors.dense(filterIndices.map(i => values(i)))
94105
case other =>
95106
throw new UnsupportedOperationException(
96107
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
@@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
220231
@Since("1.3.0")
221232
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
222233
val chiSqTestResult = Statistics.chiSqTest(data)
223-
.zipWithIndex.sortBy { case (res, _) => -res.statistic }
224234
val features = selectorType match {
225-
case ChiSqSelector.KBest => chiSqTestResult
226-
.take(numTopFeatures)
227-
case ChiSqSelector.Percentile => chiSqTestResult
228-
.take((chiSqTestResult.length * percentile).toInt)
229-
case ChiSqSelector.FPR => chiSqTestResult
230-
.filter{ case (res, _) => res.pValue < alpha }
235+
case ChiSqSelector.KBest =>
236+
chiSqTestResult.zipWithIndex
237+
.sortBy { case (res, _) => -res.statistic }
238+
.take(numTopFeatures)
239+
case ChiSqSelector.Percentile =>
240+
chiSqTestResult.zipWithIndex
241+
.sortBy { case (res, _) => -res.statistic }
242+
.take((chiSqTestResult.length * percentile).toInt)
243+
case ChiSqSelector.FPR =>
244+
chiSqTestResult.zipWithIndex
245+
.filter{ case (res, _) => res.pValue < alpha }
231246
case errorType =>
232247
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
233248
}
234-
val indices = features.map { case (_, indices) => indices }
249+
val indices = features.map { case (_, indices) => indices }.sorted
235250
new ChiSqSelectorModel(indices)
236251
}
237252
}

project/MimaExcludes.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -817,9 +817,6 @@ object MimaExcludes {
817817
) ++ Seq(
818818
// [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
819819
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
820-
) ++ Seq(
821-
// [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test
822-
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted")
823820
) ++ Seq(
824821
// [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
825822
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")

0 commit comments

Comments
 (0)