Skip to content

Commit bd9663c

Browse files
author
Omede Firouz
committed
[MLLIB] Add fit intercept api to ml logisticregression
I have the fit intercept enabled by default for logistic regression, I wonder what others think here. I understand that it enables allocation by default which is undesirable, but one needs to have a very strong reason for not having an intercept term enabled so it is the safer default from a statistical sense. Explicitly modeling the intercept by adding a column of all 1s does not work. I believe the reason is that since the API for LogisticRegressionWithLBFGS forces column normalization, and a column of all 1s has 0 variance so dividing by 0 kills it.
1 parent 0e2753f commit bd9663c

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel
3131
* Params for logistic regression.
3232
*/
3333
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
34-
with HasRegParam with HasMaxIter with HasThreshold
34+
with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold
3535

3636

3737
/**
@@ -46,6 +46,7 @@ class LogisticRegression
4646
with LogisticRegressionParams {
4747

4848
setRegParam(0.1)
49+
setFitIntercept(true)
4950
setMaxIter(100)
5051
setThreshold(0.5)
5152

@@ -55,6 +56,9 @@ class LogisticRegression
5556
/** @group setParam */
5657
def setMaxIter(value: Int): this.type = set(maxIter, value)
5758

59+
/** @group setParam */
60+
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
61+
5862
/** @group setParam */
5963
def setThreshold(value: Double): this.type = set(threshold, value)
6064

@@ -71,6 +75,7 @@ class LogisticRegression
7175
lr.optimizer
7276
.setRegParam(paramMap(regParam))
7377
.setNumIterations(paramMap(maxIter))
78+
.addIntercept(paramMap(fitIntercept))
7479
val oldModel = lr.run(oldDataset)
7580
val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept)
7681

mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,17 @@ private[ml] trait HasProbabilityCol extends Params {
106106
def getProbabilityCol: String = get(probabilityCol)
107107
}
108108

109+
private[ml] trait HasFitIntercept extends Params {
110+
/**
111+
* param for fitting the intercept term
112+
* @group param
113+
*/
114+
val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "fits the intercept term or not")
115+
116+
/** @group getParam */
117+
def getFitIntercept: Boolean = get(fitIntercept)
118+
}
119+
109120
private[ml] trait HasThreshold extends Params {
110121
/**
111122
* param for threshold in (binary) prediction

0 commit comments

Comments
 (0)