Skip to content

Commit 86d6526

Browse files
LewuatheDB Tsai
authored andcommitted
[SPARK-11207] [ML] Add test cases for solver selection of LinearRegres…
…sion as followup. This is the follow up work of SPARK-10668. * Fix miner style issues. * Add test case for checking whether solver is selected properly. Author: Lewuathe <[email protected]> Author: lewuathe <[email protected]> Closes #9180 from Lewuathe/SPARK-11207.
1 parent eb59b94 commit 86d6526

File tree

2 files changed

+144
-82
lines changed

2 files changed

+144
-82
lines changed

mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,11 @@ object LinearDataGenerator {
7777
nPoints: Int,
7878
seed: Int,
7979
eps: Double = 0.1): Seq[LabeledPoint] = {
80-
generateLinearInput(intercept, weights,
81-
Array.fill[Double](weights.length)(0.0),
82-
Array.fill[Double](weights.length)(1.0 / 3.0),
83-
nPoints, seed, eps)}
80+
generateLinearInput(intercept, weights, Array.fill[Double](weights.length)(0.0),
81+
Array.fill[Double](weights.length)(1.0 / 3.0), nPoints, seed, eps)
82+
}
8483

8584
/**
86-
*
8785
* @param intercept Data intercept
8886
* @param weights Weights to be applied.
8987
* @param xMean the mean of the generated features. Lots of time, if the features are not properly
@@ -104,24 +102,66 @@ object LinearDataGenerator {
104102
nPoints: Int,
105103
seed: Int,
106104
eps: Double): Seq[LabeledPoint] = {
105+
generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps, 0.0)
106+
}
107+
107108

109+
/**
110+
* @param intercept Data intercept
111+
* @param weights Weights to be applied.
112+
* @param xMean the mean of the generated features. Lots of time, if the features are not properly
113+
* standardized, the algorithm with poor implementation will have difficulty
114+
* to converge.
115+
* @param xVariance the variance of the generated features.
116+
* @param nPoints Number of points in sample.
117+
* @param seed Random seed
118+
* @param eps Epsilon scaling factor.
119+
* @param sparsity The ratio of zero elements. If it is 0.0, LabeledPoints with
120+
* DenseVector is returned.
121+
* @return Seq of input.
122+
*/
123+
@Since("1.6.0")
124+
def generateLinearInput(
125+
intercept: Double,
126+
weights: Array[Double],
127+
xMean: Array[Double],
128+
xVariance: Array[Double],
129+
nPoints: Int,
130+
seed: Int,
131+
eps: Double,
132+
sparsity: Double): Seq[LabeledPoint] = {
133+
require(0.0 <= sparsity && sparsity <= 1.0)
108134
val rnd = new Random(seed)
109135
val x = Array.fill[Array[Double]](nPoints)(
110136
Array.fill[Double](weights.length)(rnd.nextDouble()))
111137

138+
val sparseRnd = new Random(seed)
112139
x.foreach { v =>
113140
var i = 0
114141
val len = v.length
115142
while (i < len) {
116-
v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
143+
if (sparseRnd.nextDouble() < sparsity) {
144+
v(i) = 0.0
145+
} else {
146+
v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
147+
}
117148
i += 1
118149
}
119150
}
120151

121152
val y = x.map { xi =>
122153
blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian()
123154
}
124-
y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
155+
156+
y.zip(x).map { p =>
157+
if (sparsity == 0.0) {
158+
// Return LabeledPoints with DenseVector
159+
LabeledPoint(p._1, Vectors.dense(p._2))
160+
} else {
161+
// Return LabeledPoints with SparseVector
162+
LabeledPoint(p._1, Vectors.dense(p._2).toSparse)
163+
}
164+
}
125165
}
126166

127167
/**

0 commit comments

Comments
 (0)