Commit 312ea3f

[SPARK-17748][FOLLOW-UP][ML] Reorg variables of WeightedLeastSquares.
## What changes were proposed in this pull request?

This is follow-up work of #15394. Reorganize some variables of ```WeightedLeastSquares``` and fix one minor issue in ```WeightedLeastSquaresSuite```.

## How was this patch tested?

Existing tests.

Author: Yanbo Liang <[email protected]>

Closes #15621 from yanboliang/spark-17748.
1 parent 4bee954 commit 312ea3f
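
The pattern behind the reorg in the diff below: each scaled summary statistic is now built inside its own block-scoped initializer instead of through shared mutable loop counters spread across the method body. A minimal, self-contained sketch of that idiom (names such as rawBar and std are hypothetical, not the actual Spark fields):

```scala
// Illustrative sketch of the block-scoped initialization idiom used in the reorg.
// The scaled statistic is produced and named in a single expression, so no mutable
// loop counter or half-initialized array leaks into the enclosing method body.
object ReorgIdiomSketch {
  def main(args: Array[String]): Unit = {
    val rawBar = Array(2.0, 4.0, 6.0)  // hypothetical raw feature means
    val std = Array(1.0, 0.0, 2.0)     // hypothetical feature standard deviations

    val scaledBar: Array[Double] = {
      val _bar = rawBar.clone()
      var i = 0
      while (i < _bar.length) {
        // features with zero standard deviation are zeroed rather than divided
        _bar(i) = if (std(i) == 0.0) 0.0 else _bar(i) / std(i)
        i += 1
      }
      _bar
    }

    println(scaledBar.mkString(", "))  // prints: 2.0, 0.0, 3.0
  }
}
```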

File tree: 2 files changed (+86 / -68 lines)

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 79 additions & 60 deletions
@@ -101,23 +101,19 @@ private[ml] class WeightedLeastSquares(
     summary.validate()
     logInfo(s"Number of instances: ${summary.count}.")
     val k = if (fitIntercept) summary.k + 1 else summary.k
+    val numFeatures = summary.k
     val triK = summary.triK
     val wSum = summary.wSum
-    val bBar = summary.bBar
-    val bbBar = summary.bbBar
-    val aBar = summary.aBar
-    val aStd = summary.aStd
-    val abBar = summary.abBar
-    val aaBar = summary.aaBar
-    val numFeatures = abBar.size
+
     val rawBStd = summary.bStd
+    val rawBBar = summary.bBar
     // if b is constant (rawBStd is zero), then b cannot be scaled. In this case
-    // setting bStd=abs(bBar) ensures that b is not scaled anymore in l-bfgs algorithm.
-    val bStd = if (rawBStd == 0.0) math.abs(bBar) else rawBStd
+    // setting bStd=abs(rawBBar) ensures that b is not scaled anymore in l-bfgs algorithm.
+    val bStd = if (rawBStd == 0.0) math.abs(rawBBar) else rawBStd

     if (rawBStd == 0) {
-      if (fitIntercept || bBar == 0.0) {
-        if (bBar == 0.0) {
+      if (fitIntercept || rawBBar == 0.0) {
+        if (rawBBar == 0.0) {
           logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " +
             s"and the intercept will all be zero; as a result, training is not needed.")
         } else {
@@ -126,7 +122,7 @@
             s"training is not needed.")
         }
         val coefficients = new DenseVector(Array.ofDim(numFeatures))
-        val intercept = bBar
+        val intercept = rawBBar
         val diagInvAtWA = new DenseVector(Array(0D))
         return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA, Array(0D))
       } else {
@@ -137,65 +133,82 @@
       }
     }

-    // scale aBar to standardized space in-place
-    val aBarValues = aBar.values
-    var j = 0
-    while (j < numFeatures) {
-      if (aStd(j) == 0.0) {
-        aBarValues(j) = 0.0
-      } else {
-        aBarValues(j) /= aStd(j)
-      }
-      j += 1
-    }
+    val bBar = summary.bBar / bStd
+    val bbBar = summary.bbBar / (bStd * bStd)

-    // scale abBar to standardized space in-place
-    val abBarValues = abBar.values
+    val aStd = summary.aStd
     val aStdValues = aStd.values
-    j = 0
-    while (j < numFeatures) {
-      if (aStdValues(j) == 0.0) {
-        abBarValues(j) = 0.0
-      } else {
-        abBarValues(j) /= (aStdValues(j) * bStd)
+
+    val aBar = {
+      val _aBar = summary.aBar
+      val _aBarValues = _aBar.values
+      var i = 0
+      // scale aBar to standardized space in-place
+      while (i < numFeatures) {
+        if (aStdValues(i) == 0.0) {
+          _aBarValues(i) = 0.0
+        } else {
+          _aBarValues(i) /= aStdValues(i)
+        }
+        i += 1
       }
-      j += 1
+      _aBar
     }
+    val aBarValues = aBar.values

-    // scale aaBar to standardized space in-place
-    val aaBarValues = aaBar.values
-    j = 0
-    var p = 0
-    while (j < numFeatures) {
-      val aStdJ = aStdValues(j)
+    val abBar = {
+      val _abBar = summary.abBar
+      val _abBarValues = _abBar.values
       var i = 0
-      while (i <= j) {
-        val aStdI = aStdValues(i)
-        if (aStdJ == 0.0 || aStdI == 0.0) {
-          aaBarValues(p) = 0.0
+      // scale abBar to standardized space in-place
+      while (i < numFeatures) {
+        if (aStdValues(i) == 0.0) {
+          _abBarValues(i) = 0.0
         } else {
-          aaBarValues(p) /= (aStdI * aStdJ)
+          _abBarValues(i) /= (aStdValues(i) * bStd)
         }
-        p += 1
         i += 1
       }
-      j += 1
+      _abBar
     }
+    val abBarValues = abBar.values

-    val bBarStd = bBar / bStd
-    val bbBarStd = bbBar / (bStd * bStd)
+    val aaBar = {
+      val _aaBar = summary.aaBar
+      val _aaBarValues = _aaBar.values
+      var j = 0
+      var p = 0
+      // scale aaBar to standardized space in-place
+      while (j < numFeatures) {
+        val aStdJ = aStdValues(j)
+        var i = 0
+        while (i <= j) {
+          val aStdI = aStdValues(i)
+          if (aStdJ == 0.0 || aStdI == 0.0) {
+            _aaBarValues(p) = 0.0
+          } else {
+            _aaBarValues(p) /= (aStdI * aStdJ)
+          }
+          p += 1
+          i += 1
+        }
+        j += 1
+      }
+      _aaBar
+    }
+    val aaBarValues = aaBar.values

     val effectiveRegParam = regParam / bStd
     val effectiveL1RegParam = elasticNetParam * effectiveRegParam
     val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam

     // add L2 regularization to diagonals
     var i = 0
-    j = 2
+    var j = 2
     while (i < triK) {
       var lambda = effectiveL2RegParam
       if (!standardizeFeatures) {
-        val std = aStd(j - 2)
+        val std = aStdValues(j - 2)
         if (std != 0.0) {
           lambda /= (std * std)
         } else {
@@ -209,8 +222,9 @@
       i += j
       j += 1
     }
-    val aa = getAtA(aaBar.values, aBar.values)
-    val ab = getAtB(abBar.values, bBarStd)
+
+    val aa = getAtA(aaBarValues, aBarValues)
+    val ab = getAtB(abBarValues, bBar)

     val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0 &&
       regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) {
@@ -237,22 +251,23 @@
     val solution = solver match {
       case cholesky: CholeskySolver =>
         try {
-          cholesky.solve(bBarStd, bbBarStd, ab, aa, aBar)
+          cholesky.solve(bBar, bbBar, ab, aa, aBar)
         } catch {
           // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to
-          // quasi-newton solver
+          // Quasi-Newton solver.
           case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto =>
             logWarning("Cholesky solver failed due to singular covariance matrix. " +
               "Retrying with Quasi-Newton solver.")
             // ab and aa were modified in place, so reconstruct them
-            val _aa = getAtA(aaBar.values, aBar.values)
-            val _ab = getAtB(abBar.values, bBarStd)
+            val _aa = getAtA(aaBarValues, aBarValues)
+            val _ab = getAtB(abBarValues, bBar)
             val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None)
-            newSolver.solve(bBarStd, bbBarStd, _ab, _aa, aBar)
+            newSolver.solve(bBar, bbBar, _ab, _aa, aBar)
         }
       case qn: QuasiNewtonSolver =>
-        qn.solve(bBarStd, bbBarStd, ab, aa, aBar)
+        qn.solve(bBar, bbBar, ab, aa, aBar)
     }
+
     val (coefficientArray, intercept) = if (fitIntercept) {
       (solution.coefficients.slice(0, solution.coefficients.length - 1),
         solution.coefficients.last * bStd)
@@ -271,7 +286,11 @@
     // aaInv is a packed upper triangular matrix, here we get all elements on diagonal
     val diagInvAtWA = solution.aaInv.map { inv =>
       new DenseVector((1 to k).map { i =>
-        val multiplier = if (i == k && fitIntercept) 1.0 else aStdValues(i - 1) * aStdValues(i - 1)
+        val multiplier = if (i == k && fitIntercept) {
+          1.0
+        } else {
+          aStdValues(i - 1) * aStdValues(i - 1)
+        }
         inv(i + (i - 1) * i / 2 - 1) / (wSum * multiplier)
       }.toArray)
     }.getOrElse(new DenseVector(Array(0D)))
@@ -280,7 +299,7 @@
       solution.objectiveHistory.getOrElse(Array(0D)))
   }

-  /** Construct A^T^ A from summary statistics. */
+  /** Construct A^T^ A (append bias if necessary). */
   private def getAtA(aaBar: Array[Double], aBar: Array[Double]): DenseVector = {
     if (fitIntercept) {
       new DenseVector(Array.concat(aaBar, aBar, Array(1.0)))
@@ -289,7 +308,7 @@
     }
   }

-  /** Construct A^T^ b from summary statistics. */
+  /** Construct A^T^ b (append bias if necessary). */
  private def getAtB(abBar: Array[Double], bBar: Double): DenseVector = {
    if (fitIntercept) {
      new DenseVector(Array.concat(abBar, Array(bBar)))
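
As a side note (an illustrative check, not part of the patch): the diagInvAtWA hunk above relies on the packed upper-triangular layout, where the diagonal entry of column i (1-based) sits at offset i + (i - 1) * i / 2 - 1 of the packed array. A small sketch that verifies this offset under that assumption:

```scala
// Sketch only: confirms that i + (i - 1) * i / 2 - 1 indexes the diagonal of a
// column-major packed upper-triangular array, as used by the diagInvAtWA code above.
object PackedDiagonalSketch {
  def main(args: Array[String]): Unit = {
    val k = 4
    val packed = Array.ofDim[Double](k * (k + 1) / 2)
    var p = 0
    // pack the upper triangle column by column, marking diagonal entries with 1.0
    for (j <- 1 to k; i <- 1 to j) {
      packed(p) = if (i == j) 1.0 else 0.0
      p += 1
    }
    // every diagonal entry is found at the documented offset
    (1 to k).foreach(i => assert(packed(i + (i - 1) * i / 2 - 1) == 1.0))
    println("packed diagonal offsets verified")
  }
}
```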

mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala

Lines changed: 7 additions & 8 deletions
@@ -361,14 +361,13 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
     for (fitIntercept <- Seq(false, true);
          standardization <- Seq(false, true);
          (lambda, alpha) <- Seq((0.0, 0.0), (0.5, 0.0), (0.5, 0.5), (0.5, 1.0))) {
-      for (solver <- Seq(WeightedLeastSquares.Auto, WeightedLeastSquares.Cholesky)) {
-        val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha,
-          standardizeFeatures = standardization, standardizeLabel = true,
-          solverType = WeightedLeastSquares.QuasiNewton)
-        val model = wls.fit(constantFeaturesInstances)
-        val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
-        assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6)
-      }
+      val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha,
+        standardizeFeatures = standardization, standardizeLabel = true,
+        solverType = WeightedLeastSquares.QuasiNewton)
+      val model = wls.fit(constantFeaturesInstances)
+      val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+      assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6)
+
       idx += 1
     }
   }
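
Reading only from the diff above (the PR discussion is not reproduced here), the minor issue in the suite appears to be the removed inner loop: its solver loop variable was never used because solverType was hard-coded to WeightedLeastSquares.QuasiNewton, so every configuration was fit and asserted twice with identical settings. A reduced sketch of the redundant shape that was deleted:

```scala
// Reduced sketch of the deleted test shape (not the actual suite code): the loop
// variable `solver` is bound but never read, so both iterations exercise exactly
// the same QuasiNewton-backed configuration.
object RedundantLoopSketch {
  def main(args: Array[String]): Unit = {
    for (solver <- Seq("Auto", "Cholesky")) {   // hypothetical stand-ins for the solver objects
      val solverTypeUsed = "QuasiNewton"        // hard-coded inside the loop; `solver` is ignored
      println(s"fitting with $solverTypeUsed")  // runs twice with identical settings
    }
  }
}
```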
