-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-7888] Be able to disable intercept in linear regression in ml package #6927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
006246c
91ffc0a
5e84a0b
e2140ba
0b0c8c0
7015b9f
3bb9ee1
319bd3f
f34971c
ae5baa8
4016fac
0ad384c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, Row} | |
| class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { | ||
|
|
||
| @transient var dataset: DataFrame = _ | ||
| @transient var datasetWithoutIntercept: DataFrame = _ | ||
|
|
||
| /** | ||
| * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML | ||
|
|
@@ -34,14 +35,24 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| * | ||
| * import org.apache.spark.mllib.util.LinearDataGenerator | ||
| * val data = | ||
| * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2) | ||
| * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path") | ||
| * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), | ||
| * Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) | ||
| * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) | ||
| * .saveAsTextFile("path") | ||
| */ | ||
| override def beforeAll(): Unit = { | ||
| super.beforeAll() | ||
| dataset = sqlContext.createDataFrame( | ||
| sc.parallelize(LinearDataGenerator.generateLinearInput( | ||
| 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) | ||
| /** | ||
| * datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this too long?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 99 chars :) |
||
| * training a model without an intercept. | ||
| */ | ||
| datasetWithoutIntercept = sqlContext.createDataFrame( | ||
| sc.parallelize(LinearDataGenerator.generateLinearInput( | ||
| 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you need
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I got it. let's call it
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BTW, for correctness testing,
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds good, I'll rename it. I just wanted to have a test case where the model without an intercept would potentially fit better. |
||
| } | ||
|
|
||
| test("linear regression with intercept without regularization") { | ||
|
|
@@ -78,6 +89,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| } | ||
| } | ||
|
|
||
test("linear regression without intercept without regularization") {
  // A single estimator with intercept fitting disabled is fit on both fixtures:
  // `dataset` and `datasetWithoutIntercept`.
  val trainer = (new LinearRegression).setFitIntercept(false)
  val model = trainer.fit(dataset)
  val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)

  /*
   * Reference values computed with R's glmnet on `dataset`:
   *
   * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
   *   intercept = FALSE))
   * > weights
   * 3 x 1 sparse Matrix of class "dgCMatrix"
   *                           s0
   * (Intercept)         .
   * as.numeric.data.V2. 6.995908
   * as.numeric.data.V3. 5.275131
   */
  val weightsR = Array(6.995908, 5.275131)

  // The expected intercept is zero, so it is checked with an absolute tolerance;
  // a relative tolerance is not meaningful when the reference value is zero.
  assert(model.intercept ~== 0.0 absTol 1E-3)
  assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
  assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

  /*
   * Then again with the data with no intercept:
   * > weightsWithoutIntercept
   * 3 x 1 sparse Matrix of class "dgCMatrix"
   *                           s0
   * (Intercept)          .
   * as.numeric.data3.V2. 4.70011
   * as.numeric.data3.V3. 7.19943
   */
  val weightsWithoutInterceptR = Array(4.70011, 7.19943)

  assert(modelWithoutIntercept.intercept ~== 0.0 absTol 1E-3)
  assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3)
  assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
}
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove the extra line. |
||
| test("linear regression with intercept with L1 regularization") { | ||
| val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) | ||
| val model = trainer.fit(dataset) | ||
|
|
@@ -87,11 +134,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| * > weights | ||
| * 3 x 1 sparse Matrix of class "dgCMatrix" | ||
| * s0 | ||
| * (Intercept) 6.311546 | ||
| * as.numeric.data.V2. 2.123522 | ||
| * as.numeric.data.V3. 4.605651 | ||
| * (Intercept) 6.24300 | ||
| * as.numeric.data.V2. 4.024821 | ||
| * as.numeric.data.V3. 6.679841 | ||
| */ | ||
| val interceptR = 6.243000 | ||
| val interceptR = 6.24300 | ||
| val weightsR = Array(4.024821, 6.679841) | ||
|
|
||
| assert(model.intercept ~== interceptR relTol 1E-3) | ||
|
|
@@ -106,6 +153,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| } | ||
| } | ||
|
|
||
test("linear regression without intercept with L1 regularization") {
  // Pure L1 penalty (elasticNetParam = 1.0) with intercept fitting disabled.
  val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
    .setFitIntercept(false)
  val model = trainer.fit(dataset)

  /*
   * Reference values computed with R's glmnet:
   *
   * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
   *   intercept=FALSE))
   * > weights
   * 3 x 1 sparse Matrix of class "dgCMatrix"
   *                           s0
   * (Intercept)         .
   * as.numeric.data.V2. 6.299752
   * as.numeric.data.V3. 4.772913
   */
  val interceptR = 0.0
  val weightsR = Array(6.299752, 4.772913)

  // interceptR is zero, so compare with an absolute tolerance; a relative
  // tolerance is not meaningful when the reference value is zero.
  assert(model.intercept ~== interceptR absTol 1E-3)
  assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
  assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

  // Every prediction from transform() must match the linear form recomputed
  // by hand from the fitted weights and intercept.
  model.transform(dataset).select("features", "prediction").collect().foreach {
    case Row(features: DenseVector, prediction1: Double) =>
      val prediction2 =
        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      assert(prediction1 ~== prediction2 relTol 1E-5)
  }
}
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove new line |
||
| test("linear regression with intercept with L2 regularization") { | ||
| val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) | ||
| val model = trainer.fit(dataset) | ||
|
|
@@ -134,6 +211,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| } | ||
| } | ||
|
|
||
test("linear regression without intercept with L2 regularization") {
  // Pure L2 penalty (elasticNetParam = 0.0) with intercept fitting disabled.
  val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
    .setFitIntercept(false)
  val model = trainer.fit(dataset)

  /*
   * Reference values computed with R's glmnet:
   *
   * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
   *   intercept = FALSE))
   * > weights
   * 3 x 1 sparse Matrix of class "dgCMatrix"
   *                           s0
   * (Intercept)         .
   * as.numeric.data.V2. 5.522875
   * as.numeric.data.V3. 4.214502
   */
  val interceptR = 0.0
  val weightsR = Array(5.522875, 4.214502)

  // interceptR is zero, so compare with an absolute tolerance; a relative
  // tolerance is not meaningful when the reference value is zero.
  assert(model.intercept ~== interceptR absTol 1E-3)
  assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
  assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

  // Every prediction from transform() must match the linear form recomputed
  // by hand from the fitted weights and intercept.
  model.transform(dataset).select("features", "prediction").collect().foreach {
    case Row(features: DenseVector, prediction1: Double) =>
      val prediction2 =
        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      assert(prediction1 ~== prediction2 relTol 1E-5)
  }
}
|
|
||
| test("linear regression with intercept with ElasticNet regularization") { | ||
| val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) | ||
| val model = trainer.fit(dataset) | ||
|
|
@@ -161,4 +268,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| assert(prediction1 ~== prediction2 relTol 1E-5) | ||
| } | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. new line |
||
|
|
||
test("linear regression without intercept with ElasticNet regularization") {
  // Mixed L1/L2 penalty (elasticNetParam = 0.3) with intercept fitting disabled.
  val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
    .setFitIntercept(false)
  val model = trainer.fit(dataset)

  /*
   * Reference values computed with R's glmnet:
   *
   * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
   *   intercept=FALSE))
   * > weights
   * 3 x 1 sparse Matrix of class "dgCMatrix"
   *                            s0
   * (Intercept)          .
   * as.numeric.dataM.V2. 5.673348
   * as.numeric.dataM.V3. 4.322251
   */
  val interceptR = 0.0
  val weightsR = Array(5.673348, 4.322251)

  // interceptR is zero, so compare with an absolute tolerance; a relative
  // tolerance is not meaningful when the reference value is zero.
  assert(model.intercept ~== interceptR absTol 1E-3)
  assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
  assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

  // Every prediction from transform() must match the linear form recomputed
  // by hand from the fitted weights and intercept.
  model.transform(dataset).select("features", "prediction").collect().foreach {
    case Row(features: DenseVector, prediction1: Double) =>
      val prediction2 =
        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      assert(prediction1 ~== prediction2 relTol 1E-5)
  }
}
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove extra line.