Skip to content

Commit 4b76649

Browse files
committed
Make explicit pointer to values in DenseVector.
1 parent 99ced09 commit 4b76649

File tree

1 file changed

+29
-19
lines changed

1 file changed

+29
-19
lines changed

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -137,48 +137,57 @@ private[ml] class WeightedLeastSquares(
137137
val bbBar = summary.bbBar / (bStd * bStd)
138138

139139
val aStd = summary.aStd
140+
val aStdValues = aStd.values
141+
140142
val aBar = {
141143
val _aBar = summary.aBar
144+
val _aBarValues = _aBar.values
142145
var i = 0
143146
// scale aBar to standardized space in-place
144147
while (i < numFeatures) {
145-
if (aStd(i) == 0.0) {
146-
_aBar.values(i) = 0.0
148+
if (aStdValues(i) == 0.0) {
149+
_aBarValues(i) = 0.0
147150
} else {
148-
_aBar.values(i) /= aStd(i)
151+
_aBarValues(i) /= aStdValues(i)
149152
}
150153
i += 1
151154
}
152155
_aBar
153156
}
157+
val aBarValues = aBar.values
158+
154159
val abBar = {
155160
val _abBar = summary.abBar
161+
val _abBarValues = _abBar.values
156162
var i = 0
157163
// scale abBar to standardized space in-place
158164
while (i < numFeatures) {
159-
if (aStd(i) == 0.0) {
160-
_abBar.values(i) = 0.0
165+
if (aStdValues(i) == 0.0) {
166+
_abBarValues(i) = 0.0
161167
} else {
162-
_abBar.values(i) /= (aStd(i) * bStd)
168+
_abBarValues(i) /= (aStdValues(i) * bStd)
163169
}
164170
i += 1
165171
}
166172
_abBar
167173
}
174+
val abBarValues = abBar.values
175+
168176
val aaBar = {
169177
val _aaBar = summary.aaBar
178+
val _aaBarValues = _aaBar.values
170179
var j = 0
171180
var p = 0
172181
// scale aaBar to standardized space in-place
173182
while (j < numFeatures) {
174-
val aStdJ = aStd.values(j)
183+
val aStdJ = aStdValues(j)
175184
var i = 0
176185
while (i <= j) {
177-
val aStdI = aStd.values(i)
186+
val aStdI = aStdValues(i)
178187
if (aStdJ == 0.0 || aStdI == 0.0) {
179-
_aaBar.values(p) = 0.0
188+
_aaBarValues(p) = 0.0
180189
} else {
181-
_aaBar.values(p) /= (aStdI * aStdJ)
190+
_aaBarValues(p) /= (aStdI * aStdJ)
182191
}
183192
p += 1
184193
i += 1
@@ -187,6 +196,7 @@ private[ml] class WeightedLeastSquares(
187196
}
188197
_aaBar
189198
}
199+
val aaBarValues = aaBar.values
190200

191201
val effectiveRegParam = regParam / bStd
192202
val effectiveL1RegParam = elasticNetParam * effectiveRegParam
@@ -198,7 +208,7 @@ private[ml] class WeightedLeastSquares(
198208
while (i < triK) {
199209
var lambda = effectiveL2RegParam
200210
if (!standardizeFeatures) {
201-
val std = aStd(j - 2)
211+
val std = aStdValues(j - 2)
202212
if (std != 0.0) {
203213
lambda /= (std * std)
204214
} else {
@@ -208,13 +218,13 @@ private[ml] class WeightedLeastSquares(
208218
if (!standardizeLabel) {
209219
lambda *= bStd
210220
}
211-
aaBar.values(i) += lambda
221+
aaBarValues(i) += lambda
212222
i += j
213223
j += 1
214224
}
215225

216-
val aa = getAtA(aaBar.values, aBar.values)
217-
val ab = getAtB(abBar.values, bBar)
226+
val aa = getAtA(aaBarValues, aBarValues)
227+
val ab = getAtB(abBarValues, bBar)
218228

219229
val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0 &&
220230
regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) {
@@ -226,7 +236,7 @@ private[ml] class WeightedLeastSquares(
226236
if (standardizeFeatures) {
227237
effectiveL1RegParam
228238
} else {
229-
if (aStd.values(index) != 0.0) effectiveL1RegParam / aStd.values(index) else 0.0
239+
if (aStdValues(index) != 0.0) effectiveL1RegParam / aStdValues(index) else 0.0
230240
}
231241
}
232242
})
@@ -249,8 +259,8 @@ private[ml] class WeightedLeastSquares(
249259
logWarning("Cholesky solver failed due to singular covariance matrix. " +
250260
"Retrying with Quasi-Newton solver.")
251261
// ab and aa were modified in place, so reconstruct them
252-
val _aa = getAtA(aaBar.values, aBar.values)
253-
val _ab = getAtB(abBar.values, bBar)
262+
val _aa = getAtA(aaBarValues, aBarValues)
263+
val _ab = getAtB(abBarValues, bBar)
254264
val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None)
255265
newSolver.solve(bBar, bbBar, _ab, _aa, aBar)
256266
}
@@ -269,7 +279,7 @@ private[ml] class WeightedLeastSquares(
269279
var q = 0
270280
val len = coefficientArray.length
271281
while (q < len) {
272-
coefficientArray(q) *= { if (aStd.values(q) != 0.0) bStd / aStd.values(q) else 0.0 }
282+
coefficientArray(q) *= { if (aStdValues(q) != 0.0) bStd / aStdValues(q) else 0.0 }
273283
q += 1
274284
}
275285

@@ -279,7 +289,7 @@ private[ml] class WeightedLeastSquares(
279289
val multiplier = if (i == k && fitIntercept) {
280290
1.0
281291
} else {
282-
aStd.values(i - 1) * aStd.values(i - 1)
292+
aStdValues(i - 1) * aStdValues(i - 1)
283293
}
284294
inv(i + (i - 1) * i / 2 - 1) / (wSum * multiplier)
285295
}.toArray)

0 commit comments

Comments
 (0)