Skip to content

Commit 54b7b31

Browse files
committed
Fixed issue with logreg threshold being set correctly
1 parent 0617d61 commit 54b7b31

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,8 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
3030
import org.apache.spark.storage.StorageLevel
3131

3232
/**
33-
* :: AlphaComponent ::
3433
* Params for logistic regression.
3534
*/
36-
@AlphaComponent
3735
private[classification] trait LogisticRegressionParams extends ClassifierParams
3836
with HasRegParam with HasMaxIter with HasThreshold with HasScoreCol {
3937

@@ -53,9 +51,11 @@ private[classification] trait LogisticRegressionParams extends ClassifierParams
5351

5452

5553
/**
54+
* :: AlphaComponent ::
5655
* Logistic regression.
5756
* Currently, this class only supports binary classification.
5857
*/
58+
@AlphaComponent
5959
class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressionModel]
6060
with LogisticRegressionParams {
6161

@@ -106,14 +106,19 @@ class LogisticRegressionModel private[ml] (
106106
with ProbabilisticClassificationModel
107107
with LogisticRegressionParams {
108108

109+
setThreshold(0.5)
110+
109111
def setThreshold(value: Double): this.type = {
110112
this.threshold_internal = value
111113
set(threshold, value)
112114
}
113115
def setScoreCol(value: String): this.type = set(scoreCol, value)
114116

115-
/** Store for faster test-time prediction. */
116-
private var threshold_internal: Double = this.getThreshold
117+
/**
118+
* Store for faster test-time prediction.
119+
* Initialized to threshold in fittingParamMap if exists, else default threshold.
120+
*/
121+
private var threshold_internal: Double = fittingParamMap.get(threshold).getOrElse(getThreshold)
117122

118123
private val margin: Vector => Double = (features) => {
119124
BLAS.dot(features, weights) + intercept

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import org.apache.spark.storage.StorageLevel
2828
/**
2929
* Params for linear regression.
3030
*/
31-
@AlphaComponent
3231
private[regression] trait LinearRegressionParams extends RegressorParams
3332
with HasRegParam with HasMaxIter
3433

0 commit comments

Comments
 (0)