@@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession}
3535/**
3636 * Chi Squared selector model.
3737 *
38- * @param selectedFeatures list of indices to select (filter).
38+ * @param selectedFeatures list of indices to select (filter). Must be sorted in ascending order.
3939 */
4040@ Since (" 1.3.0" )
4141class ChiSqSelectorModel @ Since (" 1.3.0" ) (
4242 @ Since (" 1.3.0" ) val selectedFeatures : Array [Int ]) extends VectorTransformer with Saveable {
4343
44+ require(isSorted(selectedFeatures), " Array has to be sorted asc" )
45+
/**
 * Checks whether an array of ints is in non-decreasing order.
 *
 * Equal neighbouring values are accepted. Empty and single-element arrays
 * count as sorted because they contain no adjacent pair to compare.
 *
 * @param array values to inspect
 * @return true when no element is smaller than the one before it
 */
protected def isSorted(array: Array[Int]): Boolean = {
  // Walk every adjacent pair; forall stops at the first pair that is out of order.
  (1 until array.length).forall(pos => array(pos - 1) <= array(pos))
}
55+
4456 /**
4557 * Applies transformation on a vector.
4658 *
@@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") (
5769 * Preserves the order of filtered features the same as their indices are stored.
5870 * Might be moved to Vector as .slice
5971 * @param features vector
60- * @param filterIndices indices of features to filter
72+ * @param filterIndices indices of features to filter; must be sorted in ascending order
6173 */
6274 private def compress (features : Vector , filterIndices : Array [Int ]): Vector = {
63- val orderedIndices = filterIndices.sorted
6475 features match {
6576 case SparseVector (size, indices, values) =>
66- val newSize = orderedIndices .length
77+ val newSize = filterIndices .length
6778 val newValues = new ArrayBuilder .ofDouble
6879 val newIndices = new ArrayBuilder .ofInt
6980 var i = 0
7081 var j = 0
7182 var indicesIdx = 0
7283 var filterIndicesIdx = 0
73- while (i < indices.length && j < orderedIndices .length) {
84+ while (i < indices.length && j < filterIndices .length) {
7485 indicesIdx = indices(i)
75- filterIndicesIdx = orderedIndices (j)
86+ filterIndicesIdx = filterIndices (j)
7687 if (indicesIdx == filterIndicesIdx) {
7788 newIndices += j
7889 newValues += values(i)
@@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
90101 Vectors .sparse(newSize, newIndices.result(), newValues.result())
91102 case DenseVector (values) =>
92103 val values = features.toArray
93- Vectors .dense(orderedIndices .map(i => values(i)))
104+ Vectors .dense(filterIndices .map(i => values(i)))
94105 case other =>
95106 throw new UnsupportedOperationException (
96107 s " Only sparse and dense vectors are supported but got ${other.getClass}. " )
@@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
220231 @ Since (" 1.3.0" )
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
  // Run the chi-squared test on the data and pair every per-feature result
  // with its position in the returned collection.
  val indexedResults = Statistics.chiSqTest(data).zipWithIndex

  // Results ordered from the largest test statistic to the smallest.
  // Only the KBest and Percentile strategies need this ranking.
  def rankedByStatistic = indexedResults.sortBy { case (result, _) => -result.statistic }

  val chosen = selectorType match {
    case ChiSqSelector.KBest =>
      rankedByStatistic.take(numTopFeatures)
    case ChiSqSelector.Percentile =>
      // Fractional counts are truncated toward zero.
      val keepCount = (indexedResults.length * percentile).toInt
      rankedByStatistic.take(keepCount)
    case ChiSqSelector.FPR =>
      // Keep every feature whose p-value is below alpha; no ranking involved.
      indexedResults.filter { case (result, _) => result.pValue < alpha }
    case errorType =>
      throw new IllegalStateException(s" Unknown ChiSqSelector Type: $errorType")
  }

  // Drop the test results and hand the model its feature indices in ascending order.
  val selectedIndices = chosen.map { case (_, featureIndex) => featureIndex }.sorted
  new ChiSqSelectorModel(selectedIndices)
}
237252}
0 commit comments