Skip to content

Commit 7015b9f

Browse files
committed
Our code performs the same as R, except that we need more than one data point, but that seems reasonable.
1 parent 0b0c8c0 commit 7015b9f

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -37,9 +37,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
3737
* val data =
3838
* sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2)
3939
* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1).saveAsTextFile("path")
40-
* val dataNR =
41-
* sc.parallelize(LinearDataGenerator.generateLinearInput(0.0, Array(4.7, 7.2), 10000, 42), 2)
42-
* dataNR.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1).saveAsTextFile("pathNR")
40+
* val dataM =
41+
* sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), Array(0.9, -1.3),
42+
* Array(0.7, 1.2), 10000, 42, 0.1), 2)
43+
* dataM.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1).saveAsTextFile("pathM")
4344
*/
4445
override def beforeAll(): Unit = {
4546
super.beforeAll()
@@ -49,6 +50,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
4950
datasetNR = sqlContext.createDataFrame(
5051
sc.parallelize(LinearDataGenerator.generateLinearInput(
5152
0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
53+
5254
}
5355

5456
test("linear regression with intercept without regularization") {
@@ -102,10 +104,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
102104
* 3 x 1 sparse Matrix of class "dgCMatrix"
103105
* s0
104106
* (Intercept) .
105-
* as.numeric.data.V2. 4.648385
106-
* as.numeric.data.V3. 7.462729
107+
* as.numeric.data.V2. 6.995908
108+
* as.numeric.data.V3. 5.275131
107109
*/
108-
val weightsR = Array(4.648385, 7.462729)
110+
val weightsR = Array(6.995908, 5.275131)
109111

110112
assert(model.intercept ~== 0 relTol 1E-3)
111113
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
@@ -116,10 +118,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
116118
* 3 x 1 sparse Matrix of class "dgCMatrix"
117119
* s0
118120
* (Intercept) .
119-
* as.numeric.dataNR.V2. 4.701019
120-
* as.numeric.dataNR.V3. 7.198280
121+
* as.numeric.data3.V2. 4.70011
122+
* as.numeric.data3.V3. 7.19943
121123
*/
122-
val weightsRNR = Array(4.701019, 7.198280)
124+
val weightsRNR = Array(4.70011, 7.19943)
123125

124126
assert(modelNR.intercept ~== 0 relTol 1E-3)
125127
assert(modelNR.weights(0) ~== weightsRNR(0) relTol 1E-3)

0 commit comments

Comments
 (0)