@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, Row}
2626class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
2727
2828 @ transient var dataset : DataFrame = _
29+ @ transient var datasetWithoutIntercept : DataFrame = _
2930
3031 /**
3132 * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -34,14 +35,24 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
3435 *
3536 * import org.apache.spark.mllib.util.LinearDataGenerator
3637 * val data =
37- * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2)
38- * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path")
38+ * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
39+ * Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
40+ * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
41+ * .saveAsTextFile("path")
3942 */
override def beforeAll(): Unit = {
  super.beforeAll()

  // Builds a two-partition DataFrame from LinearDataGenerator output. Every generator
  // argument except the first is fixed, so the two fixtures below differ only in that
  // leading value (presumably the true intercept of the generated data).
  def generate(firstArg: Double): DataFrame = {
    val points = LinearDataGenerator.generateLinearInput(
      firstArg, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1)
    sqlContext.createDataFrame(sc.parallelize(points, 2))
  }

  dataset = generate(6.3)

  // datasetWithoutIntercept is not required for correctness testing, but it is useful for
  // illustrating how a model trained without an intercept behaves.
  datasetWithoutIntercept = generate(0.0)
}
4657
4758 test(" linear regression with intercept without regularization" ) {
@@ -78,6 +89,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
7889 }
7990 }
8091
test("linear regression without intercept without regularization") {
  // The same no-intercept trainer is fit on both fixtures: `dataset` and
  // `datasetWithoutIntercept`. Expected weights come from R's glmnet (transcripts below).
  val trainer = (new LinearRegression).setFitIntercept(false)
  val model = trainer.fit(dataset)
  val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)

  /**
   * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
   *   intercept = FALSE))
   * > weights
   *  3 x 1 sparse Matrix of class "dgCMatrix"
   *                            s0
   * (Intercept)         .
   * as.numeric.data.V2. 6.995908
   * as.numeric.data.V3. 5.275131
   */
  val weightsR = Array(6.995908, 5.275131)

  // The expected intercept is zero, so compare with an absolute tolerance: a relative
  // tolerance scales with the magnitude of the operands and is meaningless around 0.
  assert(model.intercept ~== 0.0 absTol 1E-3)
  assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
  assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

  /**
   * Then again with the data with no intercept:
   * > weightsWithoutIntercept
   * 3 x 1 sparse Matrix of class "dgCMatrix"
   *                             s0
   * (Intercept)           .
   * as.numeric.data3.V2. 4.70011
   * as.numeric.data3.V3. 7.19943
   */
  val weightsWithoutInterceptR = Array(4.70011, 7.19943)

  assert(modelWithoutIntercept.intercept ~== 0.0 absTol 1E-3)
  assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3)
  assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
}
127+
81128 test(" linear regression with intercept with L1 regularization" ) {
82129 val trainer = (new LinearRegression ).setElasticNetParam(1.0 ).setRegParam(0.57 )
83130 val model = trainer.fit(dataset)
@@ -87,11 +134,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
87134 * > weights
88135 * 3 x 1 sparse Matrix of class "dgCMatrix"
89136 * s0
90- * (Intercept) 6.311546
91- * as.numeric.data.V2. 2.123522
92- * as.numeric.data.V3. 4.605651
137+ * (Intercept) 6.24300
138+ * as.numeric.data.V2. 4.024821
139+ * as.numeric.data.V3. 6.679841
93140 */
94- val interceptR = 6.243000
141+ val interceptR = 6.24300
95142 val weightsR = Array (4.024821 , 6.679841 )
96143
97144 assert(model.intercept ~== interceptR relTol 1E-3 )
@@ -106,6 +153,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
106153 }
107154 }
108155
test("linear regression without intercept with L1 regularization") {
  // Elastic-net parameter 1.0 with regParam 0.57 and no fitted intercept; the expected
  // weights come from R's glmnet with the matching alpha/lambda (transcript below).
  val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
    .setFitIntercept(false)
  val model = trainer.fit(dataset)

  /**
   * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
   *   intercept=FALSE))
   * > weights
   *  3 x 1 sparse Matrix of class "dgCMatrix"
   *                            s0
   * (Intercept)          .
   * as.numeric.data.V2. 6.299752
   * as.numeric.data.V3. 4.772913
   */
  val interceptR = 0.0
  val weightsR = Array(6.299752, 4.772913)

  // interceptR is zero, so use an absolute tolerance: a relative tolerance is meaningless
  // when the reference value is 0.
  assert(model.intercept ~== interceptR absTol 1E-3)
  assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
  assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

  // Recompute each prediction by hand from the fitted weights and intercept and check
  // that transform() produces the same value.
  model.transform(dataset).select("features", "prediction").collect().foreach {
    case Row(features: DenseVector, prediction1: Double) =>
      val prediction2 =
        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      assert(prediction1 ~== prediction2 relTol 1E-5)
  }
}
185+
109186 test(" linear regression with intercept with L2 regularization" ) {
110187 val trainer = (new LinearRegression ).setElasticNetParam(0.0 ).setRegParam(2.3 )
111188 val model = trainer.fit(dataset)
@@ -134,6 +211,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
134211 }
135212 }
136213
test("linear regression without intercept with L2 regularization") {
  // Elastic-net parameter 0.0 with regParam 2.3 and no fitted intercept; the expected
  // weights come from R's glmnet with the matching alpha/lambda (transcript below).
  val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
    .setFitIntercept(false)
  val model = trainer.fit(dataset)

  /**
   * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
   *   intercept = FALSE))
   * > weights
   *  3 x 1 sparse Matrix of class "dgCMatrix"
   *                            s0
   * (Intercept)         .
   * as.numeric.data.V2. 5.522875
   * as.numeric.data.V3. 4.214502
   */
  val interceptR = 0.0
  val weightsR = Array(5.522875, 4.214502)

  // interceptR is zero, so use an absolute tolerance: a relative tolerance is meaningless
  // when the reference value is 0.
  assert(model.intercept ~== interceptR absTol 1E-3)
  assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
  assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

  // Recompute each prediction by hand from the fitted weights and intercept and check
  // that transform() produces the same value.
  model.transform(dataset).select("features", "prediction").collect().foreach {
    case Row(features: DenseVector, prediction1: Double) =>
      val prediction2 =
        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      assert(prediction1 ~== prediction2 relTol 1E-5)
  }
}
243+
137244 test(" linear regression with intercept with ElasticNet regularization" ) {
138245 val trainer = (new LinearRegression ).setElasticNetParam(0.3 ).setRegParam(1.6 )
139246 val model = trainer.fit(dataset)
@@ -161,4 +268,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
161268 assert(prediction1 ~== prediction2 relTol 1E-5 )
162269 }
163270 }
271+
test("linear regression without intercept with ElasticNet regularization") {
  // Elastic-net parameter 0.3 with regParam 1.6 and no fitted intercept; the expected
  // weights come from R's glmnet with the matching alpha/lambda (transcript below).
  val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
    .setFitIntercept(false)
  val model = trainer.fit(dataset)

  /**
   * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
   *   intercept=FALSE))
   * > weights
   * 3 x 1 sparse Matrix of class "dgCMatrix"
   *                             s0
   * (Intercept)          .
   * as.numeric.dataM.V2. 5.673348
   * as.numeric.dataM.V3. 4.322251
   */
  val interceptR = 0.0
  val weightsR = Array(5.673348, 4.322251)

  // interceptR is zero, so use an absolute tolerance: a relative tolerance is meaningless
  // when the reference value is 0.
  assert(model.intercept ~== interceptR absTol 1E-3)
  assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
  assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

  // Recompute each prediction by hand from the fitted weights and intercept and check
  // that transform() produces the same value.
  model.transform(dataset).select("features", "prediction").collect().foreach {
    case Row(features: DenseVector, prediction1: Double) =>
      val prediction2 =
        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      assert(prediction1 ~== prediction2 relTol 1E-5)
  }
}
164301}
0 commit comments