@@ -101,23 +101,19 @@ private[ml] class WeightedLeastSquares(
101101 summary.validate()
102102 logInfo(s " Number of instances: ${summary.count}. " )
103103 val k = if (fitIntercept) summary.k + 1 else summary.k
104+ val numFeatures = summary.k
104105 val triK = summary.triK
105106 val wSum = summary.wSum
106- val bBar = summary.bBar
107- val bbBar = summary.bbBar
108- val aBar = summary.aBar
109- val aStd = summary.aStd
110- val abBar = summary.abBar
111- val aaBar = summary.aaBar
112- val numFeatures = abBar.size
107+
113108 val rawBStd = summary.bStd
109+ val rawBBar = summary.bBar
114110 // if b is constant (rawBStd is zero), then b cannot be scaled. In this case
115- // setting bStd=abs(bBar ) ensures that b is not scaled anymore in l-bfgs algorithm.
116- val bStd = if (rawBStd == 0.0 ) math.abs(bBar ) else rawBStd
111+ // setting bStd=abs(rawBBar ) ensures that b is not scaled anymore in l-bfgs algorithm.
112+ val bStd = if (rawBStd == 0.0 ) math.abs(rawBBar ) else rawBStd
117113
118114 if (rawBStd == 0 ) {
119- if (fitIntercept || bBar == 0.0 ) {
120- if (bBar == 0.0 ) {
115+ if (fitIntercept || rawBBar == 0.0 ) {
116+ if (rawBBar == 0.0 ) {
121117 logWarning(s " Mean and standard deviation of the label are zero, so the coefficients " +
122118 s " and the intercept will all be zero; as a result, training is not needed. " )
123119 } else {
@@ -126,7 +122,7 @@ private[ml] class WeightedLeastSquares(
126122 s " training is not needed. " )
127123 }
128124 val coefficients = new DenseVector (Array .ofDim(numFeatures))
129- val intercept = bBar
125+ val intercept = rawBBar
130126 val diagInvAtWA = new DenseVector (Array (0D ))
131127 return new WeightedLeastSquaresModel (coefficients, intercept, diagInvAtWA, Array (0D ))
132128 } else {
@@ -137,61 +133,68 @@ private[ml] class WeightedLeastSquares(
137133 }
138134 }
139135
140- // scale aBar to standardized space in-place
141- val aBarValues = aBar.values
142- var j = 0
143- while (j < numFeatures) {
144- if (aStd(j) == 0.0 ) {
145- aBarValues(j) = 0.0
146- } else {
147- aBarValues(j) /= aStd(j)
148- }
149- j += 1
150- }
136+ val bBar = summary.bBar / bStd
137+ val bbBar = summary.bbBar / (bStd * bStd)
151138
152- // scale abBar to standardized space in-place
153- val abBarValues = abBar.values
154- val aStdValues = aStd.values
155- j = 0
156- while (j < numFeatures) {
157- if (aStdValues(j) == 0.0 ) {
158- abBarValues(j) = 0.0
159- } else {
160- abBarValues(j) /= (aStdValues(j) * bStd)
139+ val aStd = summary.aStd
140+ val aBar = {
141+ val _aBar = summary.aBar
142+ var i = 0
143+ // scale aBar to standardized space in-place
144+ while (i < numFeatures) {
145+ if (aStd(i) == 0.0 ) {
146+ _aBar.values(i) = 0.0
147+ } else {
148+ _aBar.values(i) /= aStd(i)
149+ }
150+ i += 1
161151 }
162- j += 1
152+ _aBar
163153 }
164-
165- // scale aaBar to standardized space in-place
166- val aaBarValues = aaBar.values
167- j = 0
168- var p = 0
169- while (j < numFeatures) {
170- val aStdJ = aStdValues(j)
154+ val abBar = {
155+ val _abBar = summary.abBar
171156 var i = 0
172- while (i <= j) {
173- val aStdI = aStdValues(i)
174- if (aStdJ == 0.0 || aStdI == 0.0 ) {
175- aaBarValues(p ) = 0.0
157+ // scale abBar to standardized space in-place
158+ while (i < numFeatures) {
159+ if (aStd(i) == 0.0 ) {
160+ _abBar.values(i ) = 0.0
176161 } else {
177- aaBarValues(p ) /= (aStdI * aStdJ )
162+ _abBar.values(i ) /= (aStd(i) * bStd )
178163 }
179- p += 1
180164 i += 1
181165 }
182- j += 1
166+ _abBar
167+ }
168+ val aaBar = {
169+ val _aaBar = summary.aaBar
170+ var j = 0
171+ var p = 0
172+ // scale aaBar to standardized space in-place
173+ while (j < numFeatures) {
174+ val aStdJ = aStd.values(j)
175+ var i = 0
176+ while (i <= j) {
177+ val aStdI = aStd.values(i)
178+ if (aStdJ == 0.0 || aStdI == 0.0 ) {
179+ _aaBar.values(p) = 0.0
180+ } else {
181+ _aaBar.values(p) /= (aStdI * aStdJ)
182+ }
183+ p += 1
184+ i += 1
185+ }
186+ j += 1
187+ }
188+ _aaBar
183189 }
184-
185- val bBarStd = bBar / bStd
186- val bbBarStd = bbBar / (bStd * bStd)
187190
188191 val effectiveRegParam = regParam / bStd
189192 val effectiveL1RegParam = elasticNetParam * effectiveRegParam
190193 val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam
191194
192195 // add L2 regularization to diagonals
193196 var i = 0
194- j = 2
197+ var j = 2
195198 while (i < triK) {
196199 var lambda = effectiveL2RegParam
197200 if (! standardizeFeatures) {
@@ -205,12 +208,13 @@ private[ml] class WeightedLeastSquares(
205208 if (! standardizeLabel) {
206209 lambda *= bStd
207210 }
208- aaBarValues (i) += lambda
211+ aaBar.values (i) += lambda
209212 i += j
210213 j += 1
211214 }
215+
212216 val aa = getAtA(aaBar.values, aBar.values)
213- val ab = getAtB(abBar.values, bBarStd )
217+ val ab = getAtB(abBar.values, bBar )
214218
215219 val solver = if ((solverType == WeightedLeastSquares .Auto && elasticNetParam != 0.0 &&
216220 regParam != 0.0 ) || (solverType == WeightedLeastSquares .QuasiNewton )) {
@@ -222,7 +226,7 @@ private[ml] class WeightedLeastSquares(
222226 if (standardizeFeatures) {
223227 effectiveL1RegParam
224228 } else {
225- if (aStdValues (index) != 0.0 ) effectiveL1RegParam / aStdValues (index) else 0.0
229+ if (aStd.values (index) != 0.0 ) effectiveL1RegParam / aStd.values (index) else 0.0
226230 }
227231 }
228232 })
@@ -237,22 +241,23 @@ private[ml] class WeightedLeastSquares(
237241 val solution = solver match {
238242 case cholesky : CholeskySolver =>
239243 try {
240- cholesky.solve(bBarStd, bbBarStd , ab, aa, aBar)
244+ cholesky.solve(bBar, bbBar , ab, aa, aBar)
241245 } catch {
242246 // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to
243- // quasi-newton solver
247+ // Quasi-Newton solver.
244248 case _ : SingularMatrixException if solverType == WeightedLeastSquares .Auto =>
245249 logWarning(" Cholesky solver failed due to singular covariance matrix. " +
246250 " Retrying with Quasi-Newton solver." )
247251 // ab and aa were modified in place, so reconstruct them
248252 val _aa = getAtA(aaBar.values, aBar.values)
249- val _ab = getAtB(abBar.values, bBarStd )
253+ val _ab = getAtB(abBar.values, bBar )
250254 val newSolver = new QuasiNewtonSolver (fitIntercept, maxIter, tol, None )
251- newSolver.solve(bBarStd, bbBarStd , _ab, _aa, aBar)
255+ newSolver.solve(bBar, bbBar , _ab, _aa, aBar)
252256 }
253257 case qn : QuasiNewtonSolver =>
254- qn.solve(bBarStd, bbBarStd , ab, aa, aBar)
258+ qn.solve(bBar, bbBar , ab, aa, aBar)
255259 }
260+
256261 val (coefficientArray, intercept) = if (fitIntercept) {
257262 (solution.coefficients.slice(0 , solution.coefficients.length - 1 ),
258263 solution.coefficients.last * bStd)
@@ -264,14 +269,18 @@ private[ml] class WeightedLeastSquares(
264269 var q = 0
265270 val len = coefficientArray.length
266271 while (q < len) {
267- coefficientArray(q) *= { if (aStdValues (q) != 0.0 ) bStd / aStdValues (q) else 0.0 }
272+ coefficientArray(q) *= { if (aStd.values (q) != 0.0 ) bStd / aStd.values (q) else 0.0 }
268273 q += 1
269274 }
270275
271276 // aaInv is a packed upper triangular matrix, here we get all elements on diagonal
272277 val diagInvAtWA = solution.aaInv.map { inv =>
273278 new DenseVector ((1 to k).map { i =>
274- val multiplier = if (i == k && fitIntercept) 1.0 else aStdValues(i - 1 ) * aStdValues(i - 1 )
279+ val multiplier = if (i == k && fitIntercept) {
280+ 1.0
281+ } else {
282+ aStd.values(i - 1 ) * aStd.values(i - 1 )
283+ }
275284 inv(i + (i - 1 ) * i / 2 - 1 ) / (wSum * multiplier)
276285 }.toArray)
277286 }.getOrElse(new DenseVector (Array (0D )))
@@ -280,7 +289,7 @@ private[ml] class WeightedLeastSquares(
280289 solution.objectiveHistory.getOrElse(Array (0D )))
281290 }
282291
283- /** Construct A^T^ A from summary statistics . */
292+ /** Construct A^T^ A, appending the bias terms (aBar and 1.0) when fitIntercept is set. */
284293 private def getAtA (aaBar : Array [Double ], aBar : Array [Double ]): DenseVector = {
285294 if (fitIntercept) {
286295 new DenseVector (Array .concat(aaBar, aBar, Array (1.0 )))
@@ -289,7 +298,7 @@ private[ml] class WeightedLeastSquares(
289298 }
290299 }
291300
292- /** Construct A^T^ b from summary statistics . */
301+ /** Construct A^T^ b, appending the bias term (bBar) when fitIntercept is set. */
293302 private def getAtB (abBar : Array [Double ], bBar : Double ): DenseVector = {
294303 if (fitIntercept) {
295304 new DenseVector (Array .concat(abBar, Array (bBar)))
0 commit comments