@@ -30,10 +30,8 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
3030import org .apache .spark .storage .StorageLevel
3131
3232/**
33- * :: AlphaComponent ::
3433 * Params for logistic regression.
3534 */
36- @ AlphaComponent
3735private [classification] trait LogisticRegressionParams extends ClassifierParams
3836 with HasRegParam with HasMaxIter with HasThreshold with HasScoreCol {
3937
@@ -53,9 +51,11 @@ private[classification] trait LogisticRegressionParams extends ClassifierParams
5351
5452
5553/**
54+ * :: AlphaComponent ::
5655 * Logistic regression.
5756 * Currently, this class only supports binary classification.
5857 */
58+ @ AlphaComponent
5959class LogisticRegression extends Classifier [LogisticRegression , LogisticRegressionModel ]
6060 with LogisticRegressionParams {
6161
@@ -106,14 +106,19 @@ class LogisticRegressionModel private[ml] (
106106 with ProbabilisticClassificationModel
107107 with LogisticRegressionParams {
108108
109+ setThreshold(0.5 )
110+
109111 def setThreshold (value : Double ): this .type = {
110112 this .threshold_internal = value
111113 set(threshold, value)
112114 }
113115 def setScoreCol (value : String ): this .type = set(scoreCol, value)
114116
115- /** Store for faster test-time prediction. */
116- private var threshold_internal : Double = this .getThreshold
117+ /**
118+ * Store for faster test-time prediction.
119+ * Initialized to threshold in fittingParamMap if exists, else default threshold.
120+ */
121+ private var threshold_internal : Double = fittingParamMap.get(threshold).getOrElse(getThreshold)
117122
118123 private val margin : Vector => Double = (features) => {
119124 BLAS .dot(features, weights) + intercept
0 commit comments