@@ -26,7 +26,7 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
@@ -41,7 +41,8 @@ import org.apache.spark.util.StatCounter
* Params for linear regression.
*/
private[regression] trait LinearRegressionParams extends PredictorParams
-with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
+with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
+with HasFitIntercept

/**
* :: Experimental ::
@@ -72,6 +73,14 @@ class LinearRegression(override val uid: String)
def setRegParam(value: Double): this.type = set(regParam, value)
setDefault(regParam -> 0.0)

/**
* Set whether to fit an intercept term.
* Default is true.
* @group setParam
*/
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)

Review comment (Member): remove extra line.
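For context, a minimal usage sketch of the new parameter (a sketch only: `training` is a hypothetical DataFrame with the usual "label"/"features" columns, not part of this diff):

  val lr = new LinearRegression()
    .setFitIntercept(false) // force the intercept to 0.0
    .setRegParam(0.1)
  val model = lr.fit(training)
  // model.intercept is exactly 0.0; any true intercept in the data is absorbed by the weights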

/**
* Set the ElasticNet mixing parameter.
* For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
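* (Concretely, per the effectiveL1RegParam / effectiveL2RegParam lines later in train():
* the effective L1 strength is elasticNetParam * regParam and the effective L2 strength
* is (1 - elasticNetParam) * regParam.)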
@@ -123,6 +132,7 @@ class LinearRegression(override val uid: String)
val numFeatures = summarizer.mean.size
val yMean = statCounter.mean
val yStd = math.sqrt(statCounter.variance)
// TODO: see glmnet5.m (around L761); it may have relevant details for this case.

// If the yStd is zero, then the intercept is yMean with zero weights;
// as a result, training is not needed.
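// Illustrative sketch (not part of this diff): if every label equals, say, 4.2, then
// yStd == 0 and an intercept-only model is returned immediately, i.e.
//   weights = Vectors.sparse(numFeatures, Seq.empty), intercept = yMean (= 4.2),
// and the optimizer below is skipped entirely.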
@@ -142,7 +152,7 @@
val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam

-val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
featuresStd, featuresMean, effectiveL2RegParam)

val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
@@ -180,7 +190,7 @@ class LinearRegression(override val uid: String)
// The intercept in R's GLMNET is computed using closed form after the coefficients are
// converged. See the following discussion for detail.
// http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
-val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
+val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
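// Worked example of the closed form above (hypothetical numbers, not from this diff):
//   yMean = 6.3, weights = (4.7, 7.2), featuresMean = (0.9, -1.3)
//   fitIntercept = true  => intercept = 6.3 - (4.7 * 0.9 + 7.2 * (-1.3)) = 11.43
//   fitIntercept = false => intercept = 0.0 by construction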
if (handlePersistence) instances.unpersist()

// TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
@@ -232,13 +242,18 @@ class LinearRegressionModel private[ml] (
* See this discussion for detail.
* http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
*
* When training with the intercept enabled,
* the objective function in the scaled space is given by
* {{{
* L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
* }}}
* where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i,
* \bar{y} is the mean of label, and \hat{y} is the standard deviation of label.
*
* When fitting with the intercept disabled (that is, the intercept is forced to 0.0),
* we can use the same equation, except that \bar{y} and \bar{x_i} are set to 0 instead
* of the respective means.
*
* This can be rewritten as
* {{{
* L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y}
@@ -253,6 +268,7 @@ class LinearRegressionModel private[ml] (
* \sum_i w_i^\prime x_i - y / \hat{y} + offset
* }}}
*
* Note that the effective weights and offset don't depend on training dataset,
* so they can be precomputed.
*
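As a non-authoritative sketch of that precomputation (mirroring the aggregator's constructor arguments; assumes weights: Vector, fitIntercept: Boolean, featuresStd / featuresMean: Array[Double], and labelStd / labelMean: Double are in scope):

  // Effective weights w_i' = w_i / \hat{x_i}; features with zero std contribute nothing.
  val effectiveWeights = weights.toArray.zipWithIndex.map { case (w, i) =>
    if (featuresStd(i) != 0.0) w / featuresStd(i) else 0.0
  }
  // Constant offset from the derivation above; collapses to 0.0 when fitIntercept is false.
  val offset =
    if (fitIntercept) {
      labelMean / labelStd - effectiveWeights.zipWithIndex.map {
        case (w, i) => w * featuresMean(i)
      }.sum
    } else 0.0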
@@ -299,6 +315,7 @@ private class LeastSquaresAggregator(
weights: Vector,
labelStd: Double,
labelMean: Double,
fitIntercept: Boolean,
featuresStd: Array[Double],
featuresMean: Array[Double]) extends Serializable {

@@ -319,7 +336,7 @@
}
i += 1
}
-(weightsArray, -sum + labelMean / labelStd, weightsArray.length)
+(weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length)
}

private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
@@ -402,6 +419,7 @@ private class LeastSquaresCostFun(
data: RDD[(Double, Vector)],
labelStd: Double,
labelMean: Double,
fitIntercept: Boolean,
featuresStd: Array[Double],
featuresMean: Array[Double],
effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
@@ -410,7 +428,7 @@
val w = Vectors.fromBreeze(weights)

val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
-labelMean, featuresStd, featuresMean))(
+labelMean, fitIntercept, featuresStd, featuresMean))(
seqOp = (c, v) => (c, v) match {
case (aggregator, (label, features)) => aggregator.add(label, features)
},
@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, Row}
class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {

@transient var dataset: DataFrame = _
@transient var datasetWithoutIntercept: DataFrame = _

/**
* In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -34,14 +35,24 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
*
* import org.apache.spark.mllib.util.LinearDataGenerator
* val data =
-* sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2)
-* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path")
+* sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
+*   Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
+* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
+*   .saveAsTextFile("path")
*/
override def beforeAll(): Unit = {
super.beforeAll()
dataset = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
/**
* datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating
* training a model without an intercept.
*/

Review comment (Member): is this too long?
Reply (Contributor, author): 99 chars :)
datasetWithoutIntercept = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
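// For reference: `dataset` above is drawn from y = 6.3 + 4.7 * x1 + 7.2 * x2 + noise,
// while `datasetWithoutIntercept` uses y = 4.7 * x1 + 7.2 * x2 + noise, so a model fit
// with fitIntercept = false should recover weights near (4.7, 7.2) only on the latter.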

Review comment (Member): Why do you need datasetNR? They are the same as dataset. Just the script generating the dataset is wrong. :)
Reply (Member): Oh, I got it. Let's call it datasetWithoutIntercept or add a proper comment.
Reply (Member): BTW, for correctness testing, datasetNR is not required.
Reply (Contributor, author): Sounds good, I'll rename it. I just wanted a test case where the model without an intercept could potentially fit better.

}

test("linear regression with intercept without regularization") {
@@ -78,6 +89,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}

test("linear regression without intercept without regularization") {
val trainer = (new LinearRegression).setFitIntercept(false)
val model = trainer.fit(dataset)
val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)

/**
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
* intercept = FALSE))
* > weights
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) .
* as.numeric.data.V2. 6.995908
* as.numeric.data.V3. 5.275131
*/
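// For reference, the mapping to glmnet used throughout these tests: alpha corresponds
// to elasticNetParam, lambda to regParam, and intercept = FALSE to setFitIntercept(false).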
val weightsR = Array(6.995908, 5.275131)

assert(model.intercept ~== 0 relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
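// Note: these weights differ from the generating values (4.7, 7.2) because `dataset` was
// generated with a true intercept of 6.3; with the intercept forced to zero, the weights
// absorb part of it.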
/**
* Then again with the data with no intercept:
* > weightsWithoutIntercept
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) .
* as.numeric.data3.V2. 4.70011
* as.numeric.data3.V3. 7.19943
*/
val weightsWithoutInterceptR = Array(4.70011, 7.19943)

assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3)
assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3)
assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
}

Review comment (Member): remove the extra line.

test("linear regression with intercept with L1 regularization") {
val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
val model = trainer.fit(dataset)
@@ -87,11 +134,11 @@
* > weights
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
-* (Intercept) 6.311546
-* as.numeric.data.V2. 2.123522
-* as.numeric.data.V3. 4.605651
+* (Intercept) 6.24300
+* as.numeric.data.V2. 4.024821
+* as.numeric.data.V3. 6.679841
*/
-val interceptR = 6.243000
+val interceptR = 6.24300
val weightsR = Array(4.024821, 6.679841)

assert(model.intercept ~== interceptR relTol 1E-3)
@@ -106,6 +153,36 @@
}
}

test("linear regression without intercept with L1 regularization") {
val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
.setFitIntercept(false)
val model = trainer.fit(dataset)

/**
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
* intercept=FALSE))
* > weights
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) .
* as.numeric.data.V2. 6.299752
* as.numeric.data.V3. 4.772913
*/
val interceptR = 0.0
val weightsR = Array(6.299752, 4.772913)

assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}

Review comment (Member): remove new line.

test("linear regression with intercept with L2 regularization") {
val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
val model = trainer.fit(dataset)
@@ -134,6 +211,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}

test("linear regression without intercept with L2 regularization") {
val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
.setFitIntercept(false)
val model = trainer.fit(dataset)

/**
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
* intercept = FALSE))
* > weights
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) .
* as.numeric.data.V2. 5.522875
* as.numeric.data.V3. 4.214502
*/
val interceptR = 0.0
val weightsR = Array(5.522875, 4.214502)

assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}

test("linear regression with intercept with ElasticNet regularization") {
val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
val model = trainer.fit(dataset)
@@ -161,4 +268,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
Review comment (Member): new line

test("linear regression without intercept with ElasticNet regularization") {
val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
.setFitIntercept(false)
val model = trainer.fit(dataset)

/**
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
* intercept=FALSE))
* > weights
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) .
* as.numeric.dataM.V2. 5.673348
* as.numeric.dataM.V3. 4.322251
*/
val interceptR = 0.0
val weightsR = Array(5.673348, 4.322251)

assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
}
5 changes: 5 additions & 0 deletions project/MimaExcludes.scala
@@ -50,6 +50,11 @@ object MimaExcludes {
// Removing a testing method from a private class
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
// While these classes are private, MiMa is still not happy about the changes,
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.ml.regression.LeastSquaresAggregator.this"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
// SQL execution is considered private.
excludePackage("org.apache.spark.sql.execution")
)