Skip to content

Commit 99ced09

Browse files
committed
Reorg variables of WeightedLeastSquares.
1 parent 78d740a commit 99ced09

File tree

2 files changed

+79
-71
lines changed

2 files changed

+79
-71
lines changed

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 72 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -101,23 +101,19 @@ private[ml] class WeightedLeastSquares(
101101
summary.validate()
102102
logInfo(s"Number of instances: ${summary.count}.")
103103
val k = if (fitIntercept) summary.k + 1 else summary.k
104+
val numFeatures = summary.k
104105
val triK = summary.triK
105106
val wSum = summary.wSum
106-
val bBar = summary.bBar
107-
val bbBar = summary.bbBar
108-
val aBar = summary.aBar
109-
val aStd = summary.aStd
110-
val abBar = summary.abBar
111-
val aaBar = summary.aaBar
112-
val numFeatures = abBar.size
107+
113108
val rawBStd = summary.bStd
109+
val rawBBar = summary.bBar
114110
// if b is constant (rawBStd is zero), then b cannot be scaled. In this case
115-
// setting bStd=abs(bBar) ensures that b is not scaled anymore in l-bfgs algorithm.
116-
val bStd = if (rawBStd == 0.0) math.abs(bBar) else rawBStd
111+
// setting bStd = abs(rawBBar) ensures that b is no longer scaled in the L-BFGS algorithm.
112+
val bStd = if (rawBStd == 0.0) math.abs(rawBBar) else rawBStd
117113

118114
if (rawBStd == 0) {
119-
if (fitIntercept || bBar == 0.0) {
120-
if (bBar == 0.0) {
115+
if (fitIntercept || rawBBar == 0.0) {
116+
if (rawBBar == 0.0) {
121117
logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " +
122118
s"and the intercept will all be zero; as a result, training is not needed.")
123119
} else {
@@ -126,7 +122,7 @@ private[ml] class WeightedLeastSquares(
126122
s"training is not needed.")
127123
}
128124
val coefficients = new DenseVector(Array.ofDim(numFeatures))
129-
val intercept = bBar
125+
val intercept = rawBBar
130126
val diagInvAtWA = new DenseVector(Array(0D))
131127
return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA, Array(0D))
132128
} else {
@@ -137,61 +133,68 @@ private[ml] class WeightedLeastSquares(
137133
}
138134
}
139135

140-
// scale aBar to standardized space in-place
141-
val aBarValues = aBar.values
142-
var j = 0
143-
while (j < numFeatures) {
144-
if (aStd(j) == 0.0) {
145-
aBarValues(j) = 0.0
146-
} else {
147-
aBarValues(j) /= aStd(j)
148-
}
149-
j += 1
150-
}
136+
val bBar = summary.bBar / bStd
137+
val bbBar = summary.bbBar / (bStd * bStd)
151138

152-
// scale abBar to standardized space in-place
153-
val abBarValues = abBar.values
154-
val aStdValues = aStd.values
155-
j = 0
156-
while (j < numFeatures) {
157-
if (aStdValues(j) == 0.0) {
158-
abBarValues(j) = 0.0
159-
} else {
160-
abBarValues(j) /= (aStdValues(j) * bStd)
139+
val aStd = summary.aStd
140+
val aBar = {
141+
val _aBar = summary.aBar
142+
var i = 0
143+
// scale aBar to standardized space in-place
144+
while (i < numFeatures) {
145+
if (aStd(i) == 0.0) {
146+
_aBar.values(i) = 0.0
147+
} else {
148+
_aBar.values(i) /= aStd(i)
149+
}
150+
i += 1
161151
}
162-
j += 1
152+
_aBar
163153
}
164-
165-
// scale aaBar to standardized space in-place
166-
val aaBarValues = aaBar.values
167-
j = 0
168-
var p = 0
169-
while (j < numFeatures) {
170-
val aStdJ = aStdValues(j)
154+
val abBar = {
155+
val _abBar = summary.abBar
171156
var i = 0
172-
while (i <= j) {
173-
val aStdI = aStdValues(i)
174-
if (aStdJ == 0.0 || aStdI == 0.0) {
175-
aaBarValues(p) = 0.0
157+
// scale abBar to standardized space in-place
158+
while (i < numFeatures) {
159+
if (aStd(i) == 0.0) {
160+
_abBar.values(i) = 0.0
176161
} else {
177-
aaBarValues(p) /= (aStdI * aStdJ)
162+
_abBar.values(i) /= (aStd(i) * bStd)
178163
}
179-
p += 1
180164
i += 1
181165
}
182-
j += 1
166+
_abBar
167+
}
168+
val aaBar = {
169+
val _aaBar = summary.aaBar
170+
var j = 0
171+
var p = 0
172+
// scale aaBar to standardized space in-place
173+
while (j < numFeatures) {
174+
val aStdJ = aStd.values(j)
175+
var i = 0
176+
while (i <= j) {
177+
val aStdI = aStd.values(i)
178+
if (aStdJ == 0.0 || aStdI == 0.0) {
179+
_aaBar.values(p) = 0.0
180+
} else {
181+
_aaBar.values(p) /= (aStdI * aStdJ)
182+
}
183+
p += 1
184+
i += 1
185+
}
186+
j += 1
187+
}
188+
_aaBar
183189
}
184-
185-
val bBarStd = bBar / bStd
186-
val bbBarStd = bbBar / (bStd * bStd)
187190

188191
val effectiveRegParam = regParam / bStd
189192
val effectiveL1RegParam = elasticNetParam * effectiveRegParam
190193
val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam
191194

192195
// add L2 regularization to diagonals
193196
var i = 0
194-
j = 2
197+
var j = 2
195198
while (i < triK) {
196199
var lambda = effectiveL2RegParam
197200
if (!standardizeFeatures) {
@@ -205,12 +208,13 @@ private[ml] class WeightedLeastSquares(
205208
if (!standardizeLabel) {
206209
lambda *= bStd
207210
}
208-
aaBarValues(i) += lambda
211+
aaBar.values(i) += lambda
209212
i += j
210213
j += 1
211214
}
215+
212216
val aa = getAtA(aaBar.values, aBar.values)
213-
val ab = getAtB(abBar.values, bBarStd)
217+
val ab = getAtB(abBar.values, bBar)
214218

215219
val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0 &&
216220
regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) {
@@ -222,7 +226,7 @@ private[ml] class WeightedLeastSquares(
222226
if (standardizeFeatures) {
223227
effectiveL1RegParam
224228
} else {
225-
if (aStdValues(index) != 0.0) effectiveL1RegParam / aStdValues(index) else 0.0
229+
if (aStd.values(index) != 0.0) effectiveL1RegParam / aStd.values(index) else 0.0
226230
}
227231
}
228232
})
@@ -237,22 +241,23 @@ private[ml] class WeightedLeastSquares(
237241
val solution = solver match {
238242
case cholesky: CholeskySolver =>
239243
try {
240-
cholesky.solve(bBarStd, bbBarStd, ab, aa, aBar)
244+
cholesky.solve(bBar, bbBar, ab, aa, aBar)
241245
} catch {
242246
// if Auto solver is used and Cholesky fails due to singular AtA, then fall back to
243-
// quasi-newton solver
247+
// Quasi-Newton solver.
244248
case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto =>
245249
logWarning("Cholesky solver failed due to singular covariance matrix. " +
246250
"Retrying with Quasi-Newton solver.")
247251
// ab and aa were modified in place, so reconstruct them
248252
val _aa = getAtA(aaBar.values, aBar.values)
249-
val _ab = getAtB(abBar.values, bBarStd)
253+
val _ab = getAtB(abBar.values, bBar)
250254
val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None)
251-
newSolver.solve(bBarStd, bbBarStd, _ab, _aa, aBar)
255+
newSolver.solve(bBar, bbBar, _ab, _aa, aBar)
252256
}
253257
case qn: QuasiNewtonSolver =>
254-
qn.solve(bBarStd, bbBarStd, ab, aa, aBar)
258+
qn.solve(bBar, bbBar, ab, aa, aBar)
255259
}
260+
256261
val (coefficientArray, intercept) = if (fitIntercept) {
257262
(solution.coefficients.slice(0, solution.coefficients.length - 1),
258263
solution.coefficients.last * bStd)
@@ -264,14 +269,18 @@ private[ml] class WeightedLeastSquares(
264269
var q = 0
265270
val len = coefficientArray.length
266271
while (q < len) {
267-
coefficientArray(q) *= { if (aStdValues(q) != 0.0) bStd / aStdValues(q) else 0.0 }
272+
coefficientArray(q) *= { if (aStd.values(q) != 0.0) bStd / aStd.values(q) else 0.0 }
268273
q += 1
269274
}
270275

271276
// aaInv is a packed upper triangular matrix; here we extract the elements on its diagonal
272277
val diagInvAtWA = solution.aaInv.map { inv =>
273278
new DenseVector((1 to k).map { i =>
274-
val multiplier = if (i == k && fitIntercept) 1.0 else aStdValues(i - 1) * aStdValues(i - 1)
279+
val multiplier = if (i == k && fitIntercept) {
280+
1.0
281+
} else {
282+
aStd.values(i - 1) * aStd.values(i - 1)
283+
}
275284
inv(i + (i - 1) * i / 2 - 1) / (wSum * multiplier)
276285
}.toArray)
277286
}.getOrElse(new DenseVector(Array(0D)))
@@ -280,7 +289,7 @@ private[ml] class WeightedLeastSquares(
280289
solution.objectiveHistory.getOrElse(Array(0D)))
281290
}
282291

283-
/** Construct A^T^ A from summary statistics. */
292+
/** Construct A^T^ A (append bias if necessary). */
284293
private def getAtA(aaBar: Array[Double], aBar: Array[Double]): DenseVector = {
285294
if (fitIntercept) {
286295
new DenseVector(Array.concat(aaBar, aBar, Array(1.0)))
@@ -289,7 +298,7 @@ private[ml] class WeightedLeastSquares(
289298
}
290299
}
291300

292-
/** Construct A^T^ b from summary statistics. */
301+
/** Construct A^T^ b (append bias if necessary). */
293302
private def getAtB(abBar: Array[Double], bBar: Double): DenseVector = {
294303
if (fitIntercept) {
295304
new DenseVector(Array.concat(abBar, Array(bBar)))

mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -361,14 +361,13 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
361361
for (fitIntercept <- Seq(false, true);
362362
standardization <- Seq(false, true);
363363
(lambda, alpha) <- Seq((0.0, 0.0), (0.5, 0.0), (0.5, 0.5), (0.5, 1.0))) {
364-
for (solver <- Seq(WeightedLeastSquares.Auto, WeightedLeastSquares.Cholesky)) {
365-
val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha,
366-
standardizeFeatures = standardization, standardizeLabel = true,
367-
solverType = WeightedLeastSquares.QuasiNewton)
368-
val model = wls.fit(constantFeaturesInstances)
369-
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
370-
assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6)
371-
}
364+
val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha,
365+
standardizeFeatures = standardization, standardizeLabel = true,
366+
solverType = WeightedLeastSquares.QuasiNewton)
367+
val model = wls.fit(constantFeaturesInstances)
368+
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
369+
assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6)
370+
372371
idx += 1
373372
}
374373
}

0 commit comments

Comments
 (0)