Skip to content

Commit 314b562

Browse files
committed
Fix test issues
1 parent c05a948 commit 314b562

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 15 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -52,19 +52,19 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark
5252
datasetGaussianIdentity = sqlContext.createDataFrame(
5353
sc.parallelize(generateGeneralizedLinearRegressionInput(
5454
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
55-
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.01,
55+
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
5656
family = "gaussian", link = "identity"), 2))
5757

5858
datasetGaussianLog = sqlContext.createDataFrame(
5959
sc.parallelize(generateGeneralizedLinearRegressionInput(
6060
intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
61-
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.01,
61+
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
6262
family = "gaussian", link = "log"), 2))
6363

6464
datasetGaussianInverse = sqlContext.createDataFrame(
6565
sc.parallelize(generateGeneralizedLinearRegressionInput(
6666
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
67-
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.01,
67+
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
6868
family = "gaussian", link = "inverse"), 2))
6969

7070
datasetBinomial = {
@@ -74,45 +74,46 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark
7474
val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
7575

7676
val testData =
77-
generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, seed)
77+
generateMultinomialLogisticInput(coefficients, xMean, xVariance,
78+
addIntercept = true, nPoints, seed)
7879

79-
sqlContext.createDataFrame(sc.parallelize(testData, 4))
80+
sqlContext.createDataFrame(sc.parallelize(testData, 2))
8081
}
8182

8283
datasetPoissonLog = sqlContext.createDataFrame(
8384
sc.parallelize(generateGeneralizedLinearRegressionInput(
8485
intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
85-
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.01,
86+
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
8687
family = "poisson", link = "log"), 2))
8788

8889
datasetPoissonIdentity = sqlContext.createDataFrame(
8990
sc.parallelize(generateGeneralizedLinearRegressionInput(
9091
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
91-
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.01,
92+
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
9293
family = "poisson", link = "identity"), 2))
9394

9495
datasetPoissonSqrt = sqlContext.createDataFrame(
9596
sc.parallelize(generateGeneralizedLinearRegressionInput(
9697
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
97-
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.01,
98+
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
9899
family = "poisson", link = "sqrt"), 2))
99100

100101
datasetGammaInverse = sqlContext.createDataFrame(
101102
sc.parallelize(generateGeneralizedLinearRegressionInput(
102103
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
103-
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.01,
104+
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
104105
family = "gamma", link = "inverse"), 2))
105106

106107
datasetGammaIdentity = sqlContext.createDataFrame(
107108
sc.parallelize(generateGeneralizedLinearRegressionInput(
108109
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
109-
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.01,
110+
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
110111
family = "gamma", link = "identity"), 2))
111112

112113
datasetGammaLog = sqlContext.createDataFrame(
113114
sc.parallelize(generateGeneralizedLinearRegressionInput(
114115
intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
115-
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.01,
116+
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
116117
family = "gamma", link = "log"), 2))
117118
}
118119

@@ -132,14 +133,12 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark
132133
assert(glr.getWeightCol === "")
133134
assert(glr.getRegParam === 0.0)
134135
assert(glr.getSolver == "irls")
136+
// TODO: Construct model directly instead of via fitting.
135137
val model = glr.setFamily("gaussian").setLink("identity")
136138
.fit(datasetGaussianIdentity)
137139

138140
// copied model must have the same parent.
139141
MLTestingUtils.checkCopy(model)
140-
model.transform(datasetGaussianIdentity)
141-
.select("label", "prediction")
142-
.collect()
143142

144143
assert(model.getFeaturesCol === "features")
145144
assert(model.getPredictionCol === "prediction")
@@ -467,7 +466,7 @@ object GeneralizedLinearRegressionSuite {
467466
xVariance: Array[Double],
468467
nPoints: Int,
469468
seed: Int,
470-
eps: Double,
469+
noiseLevel: Double,
471470
family: String,
472471
link: String): Seq[LabeledPoint] = {
473472

@@ -491,7 +490,7 @@ object GeneralizedLinearRegressionSuite {
491490
case "sqrt" => math.pow(eta, 2.0)
492491
case "inverse" => 1.0 / eta
493492
}
494-
val label = mu + eps * (generator.nextValue() - mean)
493+
val label = mu + noiseLevel * (generator.nextValue() - mean)
495494
// Return LabeledPoints with DenseVector
496495
LabeledPoint(label, features)
497496
}

0 commit comments

Comments
 (0)