Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession}
/**
* Chi Squared selector model.
*
* @param selectedFeatures list of indices to select (filter).
* @param selectedFeatures list of indices to select (filter). Must be sorted in ascending order.
*/
@Since("1.3.0")
class ChiSqSelectorModel @Since("1.3.0") (
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {

require(isSorted(selectedFeatures), "Array has to be sorted asc")

/**
 * Checks whether the given array is in ascending (non-decreasing) order.
 *
 * Equal neighbouring values are accepted; only a strictly smaller successor
 * makes the check fail. Empty and single-element arrays are considered sorted.
 *
 * @param array values to inspect
 * @return true if no element is smaller than its predecessor
 */
protected def isSorted(array: Array[Int]): Boolean = {
var ordered = true
var pos = 1
// Stop scanning as soon as the first out-of-order pair is seen.
while (ordered && pos < array.length) {
ordered = array(pos - 1) <= array(pos)
pos += 1
}
ordered
}

/**
* Applies transformation on a vector.
*
Expand All @@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") (
* Preserves the order of filtered features the same as their indices are stored.
* Might be moved to Vector as .slice
* @param features vector
* @param filterIndices indices of features to filter
* @param filterIndices indices of features to filter, must be sorted in ascending order
*/
private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
val orderedIndices = filterIndices.sorted
features match {
case SparseVector(size, indices, values) =>
val newSize = orderedIndices.length
val newSize = filterIndices.length
val newValues = new ArrayBuilder.ofDouble
val newIndices = new ArrayBuilder.ofInt
var i = 0
var j = 0
var indicesIdx = 0
var filterIndicesIdx = 0
while (i < indices.length && j < orderedIndices.length) {
while (i < indices.length && j < filterIndices.length) {
indicesIdx = indices(i)
filterIndicesIdx = orderedIndices(j)
filterIndicesIdx = filterIndices(j)
if (indicesIdx == filterIndicesIdx) {
newIndices += j
newValues += values(i)
Expand All @@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
Vectors.sparse(newSize, newIndices.result(), newValues.result())
case DenseVector(values) =>
val values = features.toArray
Vectors.dense(orderedIndices.map(i => values(i)))
Vectors.dense(filterIndices.map(i => values(i)))
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
Expand Down Expand Up @@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
@Since("1.3.0")
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
val chiSqTestResult = Statistics.chiSqTest(data)
.zipWithIndex.sortBy { case (res, _) => -res.statistic }
val features = selectorType match {
case ChiSqSelector.KBest => chiSqTestResult
.take(numTopFeatures)
case ChiSqSelector.Percentile => chiSqTestResult
.take((chiSqTestResult.length * percentile).toInt)
case ChiSqSelector.FPR => chiSqTestResult
.filter{ case (res, _) => res.pValue < alpha }
case ChiSqSelector.KBest =>
chiSqTestResult.zipWithIndex
.sortBy { case (res, _) => -res.statistic }
.take(numTopFeatures)
case ChiSqSelector.Percentile =>
chiSqTestResult.zipWithIndex
.sortBy { case (res, _) => -res.statistic }
.take((chiSqTestResult.length * percentile).toInt)
case ChiSqSelector.FPR =>
chiSqTestResult.zipWithIndex
.filter{ case (res, _) => res.pValue < alpha }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For fpr type, it's not necessary to compute .sortBy { case (res, _) => -res.statistic }.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's true. Originally I thought it would make sense to present the statistic to the model because it has no way to recover it, and further made sense to rank by the statistic anyway, but subsequent changes make that irrelevant.

case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
}
val indices = features.map { case (_, indices) => indices }
val indices = features.map { case (_, indices) => indices }.sorted
new ChiSqSelectorModel(indices)
}
}
Expand Down
3 changes: 0 additions & 3 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -817,9 +817,6 @@ object MimaExcludes {
) ++ Seq(
// [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
) ++ Seq(
// [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted")
) ++ Seq(
// [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")
Expand Down