Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.feature.ChiSqSelectorType
import org.apache.spark.mllib.feature.{ChiSqSelector => OldChiSqSelector}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.rdd.RDD
Expand All @@ -44,7 +44,9 @@ private[feature] trait ChiSqSelectorParams extends Params
/**
* Number of features that selector will select (ordered by statistic value descending). If the
* number of features is less than numTopFeatures, then this will select all features.
* Only applicable when selectorType = "kbest".
* The default value of numTopFeatures is 50.
*
* @group param
*/
final val numTopFeatures = new IntParam(this, "numTopFeatures",
Expand All @@ -56,6 +58,11 @@ private[feature] trait ChiSqSelectorParams extends Params
/** @group getParam */
def getNumTopFeatures: Int = $(numTopFeatures)

/**
* Percentile of features that the selector will select, ordered by statistic value descending.
* Only applicable when selectorType = "percentile".
* Default value is 0.1.
*/
final val percentile = new DoubleParam(this, "percentile",
"Percentile of features that selector will select, ordered by statistics value descending.",
ParamValidators.inRange(0, 1))
Expand All @@ -64,38 +71,40 @@ private[feature] trait ChiSqSelectorParams extends Params
/** @group getParam */
def getPercentile: Double = $(percentile)

final val alpha = new DoubleParam(this, "alpha",
"The highest p-value for features to be kept.",
/**
* The highest p-value for features to be kept.
* Only applicable when selectorType = "fpr".
* Default value is 0.05.
*/
final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.",
ParamValidators.inRange(0, 1))
setDefault(alpha -> 0.05)

/** @group getParam */
def getAlpha: Double = $(alpha)

/**
* The ChiSqSelector supports KBest, Percentile, FPR selection,
* which is the same as ChiSqSelectorType defined in MLLIB.
* when call setNumTopFeatures, the selectorType is set to KBest
* when call setPercentile, the selectorType is set to Percentile
* when call setAlpha, the selectorType is set to FPR
* The selector type of the ChiSqSelector.
* Supported options: "kbest" (default), "percentile" and "fpr".
*/
final val selectorType = new Param[String](this, "selectorType",
"ChiSqSelector Type: KBest, Percentile, FPR")
setDefault(selectorType -> ChiSqSelectorType.KBest.toString)
"The selector type of the ChisqSelector. " +
"Supported options: kbest (default), percentile and fpr.",
ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray))
setDefault(selectorType -> OldChiSqSelector.KBest)

/** @group getParam */
def getChiSqSelectorType: String = $(selectorType)
def getSelectorType: String = $(selectorType)
}

/**
* Chi-Squared feature selection, which selects categorical features to use for predicting a
* categorical label.
* The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
* `KBest` chooses the `k` top features according to a chi-squared test.
* `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
* `FPR` chooses all features whose false positive rate meets some threshold.
* By default, the selection method is `KBest`, the default number of top features is 50.
* User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
* The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
* `kbest` chooses the `k` top features according to a chi-squared test.
* `percentile` is similar but chooses a fraction of all features instead of a fixed number.
* `fpr` chooses all features whose false positive rate meets some threshold.
* By default, the selection method is `kbest`, the default number of top features is 50.
*/
@Since("1.6.0")
final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String)
Expand All @@ -104,24 +113,21 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
@Since("1.6.0")
def this() = this(Identifiable.randomUID("chiSqSelector"))

/** @group setParam */
@Since("2.1.0")
def setSelectorType(value: String): this.type = set(selectorType, value)

/** @group setParam */
@Since("1.6.0")
def setNumTopFeatures(value: Int): this.type = {
set(selectorType, ChiSqSelectorType.KBest.toString)
set(numTopFeatures, value)
}
def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)

/** @group setParam */
@Since("2.1.0")
def setPercentile(value: Double): this.type = {
set(selectorType, ChiSqSelectorType.Percentile.toString)
set(percentile, value)
}
def setPercentile(value: Double): this.type = set(percentile, value)

/** @group setParam */
@Since("2.1.0")
def setAlpha(value: Double): this.type = {
set(selectorType, ChiSqSelectorType.FPR.toString)
set(alpha, value)
}
def setAlpha(value: Double): this.type = set(alpha, value)

/** @group setParam */
@Since("1.6.0")
Expand All @@ -143,23 +149,23 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
case Row(label: Double, features: Vector) =>
OldLabeledPoint(label, OldVectors.fromML(features))
}
var selector = new feature.ChiSqSelector()
ChiSqSelectorType.withName($(selectorType)) match {
case ChiSqSelectorType.KBest =>
selector.setNumTopFeatures($(numTopFeatures))
case ChiSqSelectorType.Percentile =>
selector.setPercentile($(percentile))
case ChiSqSelectorType.FPR =>
selector.setAlpha($(alpha))
case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
}
val selector = new feature.ChiSqSelector()
.setSelectorType($(selectorType))
.setNumTopFeatures($(numTopFeatures))
.setPercentile($(percentile))
.setAlpha($(alpha))
val model = selector.fit(input)
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
}

@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType))
otherPairs.foreach { case (_, paramName: String) =>
if (isSet(getParam(paramName))) {
logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
}
}
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.checkNumericType(schema, $(labelCol))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -629,35 +629,23 @@ private[python] class PythonMLLibAPI extends Serializable {
}

/**
* Java stub for ChiSqSelector.fit() when the seletion type is KBest. This stub returns a
* Java stub for ChiSqSelector.fit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on
* exit; see the Py4J documentation.
*/
def fitChiSqSelectorKBest(numTopFeatures: Int,
data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
new ChiSqSelector().setNumTopFeatures(numTopFeatures).fit(data.rdd)
}

/**
* Java stub for ChiSqSelector.fit() when the selection type is Percentile. This stub returns a
* handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on
* exit; see the Py4J documentation.
*/
def fitChiSqSelectorPercentile(percentile: Double,
data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
new ChiSqSelector().setPercentile(percentile).fit(data.rdd)
}

/**
* Java stub for ChiSqSelector.fit() when the selection type is FPR. This stub returns a
* handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on
* exit; see the Py4J documentation.
*/
def fitChiSqSelectorFPR(alpha: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
new ChiSqSelector().setAlpha(alpha).fit(data.rdd)
def fitChiSqSelector(
    selectorType: String,
    numTopFeatures: Int,
    percentile: Double,
    alpha: Double,
    data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
  // Configure the selector first. All three tuning values are always forwarded,
  // alongside the selector type, exactly as they were received from the caller.
  val configuredSelector = new ChiSqSelector()
    .setSelectorType(selectorType)
    .setNumTopFeatures(numTopFeatures)
    .setPercentile(percentile)
    .setAlpha(alpha)

  // Fit on the Scala RDD wrapped by the Java RDD handle.
  val labeledPoints = data.rdd
  configuredSelector.fit(labeledPoints)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SparkSession}

@Since("2.1.0")
private[spark] object ChiSqSelectorType extends Enumeration {
type SelectorType = Value
val KBest, Percentile, FPR = Value
}

/**
* Chi Squared selector model.
*
Expand Down Expand Up @@ -166,19 +160,18 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {

/**
* Creates a ChiSquared feature selector.
* The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
* `KBest` chooses the `k` top features according to a chi-squared test.
* `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
* `FPR` chooses all features whose false positive rate meets some threshold.
* By default, the selection method is `KBest`, the default number of top features is 50.
* User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
* The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
* `kbest` chooses the `k` top features according to a chi-squared test.
* `percentile` is similar but chooses a fraction of all features instead of a fixed number.
* `fpr` chooses all features whose false positive rate meets some threshold.
* By default, the selection method is `kbest`, the default number of top features is 50.
*/
@Since("1.3.0")
class ChiSqSelector @Since("2.1.0") () extends Serializable {
var numTopFeatures: Int = 50
var percentile: Double = 0.1
var alpha: Double = 0.05
var selectorType = ChiSqSelectorType.KBest
var selectorType = ChiSqSelector.KBest

/**
* This is the same as calling this() and then setNumTopFeatures(numTopFeatures)
Expand All @@ -192,28 +185,27 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
@Since("1.6.0")
def setNumTopFeatures(value: Int): this.type = {
numTopFeatures = value
selectorType = ChiSqSelectorType.KBest
this
}

@Since("2.1.0")
def setPercentile(value: Double): this.type = {
require(0.0 <= value && value <= 1.0, "Percentile must be in [0,1]")
percentile = value
selectorType = ChiSqSelectorType.Percentile
this
}

@Since("2.1.0")
def setAlpha(value: Double): this.type = {
require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]")
alpha = value
selectorType = ChiSqSelectorType.FPR
this
}

@Since("2.1.0")
def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = {
def setSelectorType(value: String): this.type = {
  // Reject unknown selector types up front. Membership is tested on the Set itself:
  // converting it to a Seq first would only add a copy and a linear scan.
  require(ChiSqSelector.supportedSelectorTypes.contains(value),
    s"ChiSqSelector Type: $value was not supported.")
  selectorType = value
  // Return this instance so setter calls can be chained.
  this
}
Expand All @@ -230,11 +222,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
val chiSqTestResult = Statistics.chiSqTest(data)
.zipWithIndex.sortBy { case (res, _) => -res.statistic }
val features = selectorType match {
case ChiSqSelectorType.KBest => chiSqTestResult
case ChiSqSelector.KBest => chiSqTestResult
.take(numTopFeatures)
case ChiSqSelectorType.Percentile => chiSqTestResult
case ChiSqSelector.Percentile => chiSqTestResult
.take((chiSqTestResult.length * percentile).toInt)
case ChiSqSelectorType.FPR => chiSqTestResult
case ChiSqSelector.FPR => chiSqTestResult
.filter{ case (res, _) => res.pValue < alpha }
case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
Expand All @@ -244,3 +236,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
}
}

@Since("2.1.0")
object ChiSqSelector {

  /** Selector type name `kbest`. */
  private[spark] val KBest: String = "kbest"

  /** Selector type name `percentile`. */
  private[spark] val Percentile: String = "percentile"

  /** Selector type name `fpr`. */
  private[spark] val FPR: String = "fpr"

  /**
   * Supported selector types, each paired with the name of the param associated with it
   * (`kbest` with `numTopFeatures`, `percentile` with `percentile`, `fpr` with `alpha`).
   */
  private[spark] val supportedTypeAndParamPairs: Set[(String, String)] = Set(
    KBest -> "numTopFeatures",
    Percentile -> "percentile",
    FPR -> "alpha")

  /** Names of all supported selector types, taken from the pairs above. */
  private[spark] val supportedSelectorTypes: Set[String] =
    supportedTypeAndParamPairs.map { case (typeName, _) => typeName }
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
.toDF("label", "data", "preFilteredData")

val selector = new ChiSqSelector()
.setSelectorType("kbest")
.setNumTopFeatures(1)
.setFeaturesCol("data")
.setLabelCol("label")
Expand All @@ -60,12 +61,28 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
assert(vec1 ~== vec2 absTol 1e-1)
}

selector.setPercentile(0.34).fit(df).transform(df)
.select("filtered", "preFilteredData").collect().foreach {
case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
}
selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df)
.select("filtered", "preFilteredData").collect().foreach {
case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
}

val preFilteredData2 = Seq(
Vectors.dense(8.0, 7.0),
Vectors.dense(0.0, 9.0),
Vectors.dense(0.0, 9.0),
Vectors.dense(8.0, 9.0)
)

val df2 = sc.parallelize(data.zip(preFilteredData2))
.map(x => (x._1.label, x._1.features, x._2))
.toDF("label", "data", "preFilteredData")

selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2)
.select("filtered", "preFilteredData").collect().foreach {
case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
}
}

test("ChiSqSelector read/write") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(Array(4.0))),
LabeledPoint(1.0, Vectors.dense(Array(4.0))),
LabeledPoint(2.0, Vectors.dense(Array(9.0))))
val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData)
val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added ML test case.

val filteredData = labeledDiscreteData.map { lp =>
LabeledPoint(lp.label, model.transform(lp.features))
}.collect().toSet
Expand Down
Loading