@@ -101,23 +101,19 @@ private[ml] class WeightedLeastSquares(
101101 summary.validate()
102102 logInfo(s " Number of instances: ${summary.count}. " )
103103 val k = if (fitIntercept) summary.k + 1 else summary.k
104+ val numFeatures = summary.k
104105 val triK = summary.triK
105106 val wSum = summary.wSum
106- val bBar = summary.bBar
107- val bbBar = summary.bbBar
108- val aBar = summary.aBar
109- val aStd = summary.aStd
110- val abBar = summary.abBar
111- val aaBar = summary.aaBar
112- val numFeatures = abBar.size
107+
113108 val rawBStd = summary.bStd
109+ val rawBBar = summary.bBar
114110 // if b is constant (rawBStd is zero), then b cannot be scaled. In this case
115- // setting bStd=abs(bBar ) ensures that b is not scaled anymore in l-bfgs algorithm.
116- val bStd = if (rawBStd == 0.0 ) math.abs(bBar ) else rawBStd
111+ // setting bStd=abs(rawBBar ) ensures that b is not scaled anymore in l-bfgs algorithm.
112+ val bStd = if (rawBStd == 0.0 ) math.abs(rawBBar ) else rawBStd
117113
118114 if (rawBStd == 0 ) {
119- if (fitIntercept || bBar == 0.0 ) {
120- if (bBar == 0.0 ) {
115+ if (fitIntercept || rawBBar == 0.0 ) {
116+ if (rawBBar == 0.0 ) {
121117 logWarning(s " Mean and standard deviation of the label are zero, so the coefficients " +
122118 s " and the intercept will all be zero; as a result, training is not needed. " )
123119 } else {
@@ -126,7 +122,7 @@ private[ml] class WeightedLeastSquares(
126122 s " training is not needed. " )
127123 }
128124 val coefficients = new DenseVector (Array .ofDim(numFeatures))
129- val intercept = bBar
125+ val intercept = rawBBar
130126 val diagInvAtWA = new DenseVector (Array (0D ))
131127 return new WeightedLeastSquaresModel (coefficients, intercept, diagInvAtWA, Array (0D ))
132128 } else {
@@ -137,65 +133,82 @@ private[ml] class WeightedLeastSquares(
137133 }
138134 }
139135
140- // scale aBar to standardized space in-place
141- val aBarValues = aBar.values
142- var j = 0
143- while (j < numFeatures) {
144- if (aStd(j) == 0.0 ) {
145- aBarValues(j) = 0.0
146- } else {
147- aBarValues(j) /= aStd(j)
148- }
149- j += 1
150- }
136+ val bBar = summary.bBar / bStd
137+ val bbBar = summary.bbBar / (bStd * bStd)
151138
152- // scale abBar to standardized space in-place
153- val abBarValues = abBar.values
139+ val aStd = summary.aStd
154140 val aStdValues = aStd.values
155- j = 0
156- while (j < numFeatures) {
157- if (aStdValues(j) == 0.0 ) {
158- abBarValues(j) = 0.0
159- } else {
160- abBarValues(j) /= (aStdValues(j) * bStd)
141+
142+ val aBar = {
143+ val _aBar = summary.aBar
144+ val _aBarValues = _aBar.values
145+ var i = 0
146+ // scale aBar to standardized space in-place
147+ while (i < numFeatures) {
148+ if (aStdValues(i) == 0.0 ) {
149+ _aBarValues(i) = 0.0
150+ } else {
151+ _aBarValues(i) /= aStdValues(i)
152+ }
153+ i += 1
161154 }
162- j += 1
155+ _aBar
163156 }
157+ val aBarValues = aBar.values
164158
165- // scale aaBar to standardized space in-place
166- val aaBarValues = aaBar.values
167- j = 0
168- var p = 0
169- while (j < numFeatures) {
170- val aStdJ = aStdValues(j)
159+ val abBar = {
160+ val _abBar = summary.abBar
161+ val _abBarValues = _abBar.values
171162 var i = 0
172- while (i <= j) {
173- val aStdI = aStdValues(i)
174- if (aStdJ == 0.0 || aStdI == 0.0 ) {
175- aaBarValues(p ) = 0.0
163+ // scale abBar to standardized space in-place
164+ while (i < numFeatures) {
165+ if (aStdValues(i) == 0.0 ) {
166+ _abBarValues(i ) = 0.0
176167 } else {
177- aaBarValues(p ) /= (aStdI * aStdJ )
168+ _abBarValues(i ) /= (aStdValues(i) * bStd )
178169 }
179- p += 1
180170 i += 1
181171 }
182- j += 1
172+ _abBar
183173 }
174+ val abBarValues = abBar.values
184175
185- val bBarStd = bBar / bStd
186- val bbBarStd = bbBar / (bStd * bStd)
176+ val aaBar = {
177+ val _aaBar = summary.aaBar
178+ val _aaBarValues = _aaBar.values
179+ var j = 0
180+ var p = 0
181+ // scale aaBar to standardized space in-place
182+ while (j < numFeatures) {
183+ val aStdJ = aStdValues(j)
184+ var i = 0
185+ while (i <= j) {
186+ val aStdI = aStdValues(i)
187+ if (aStdJ == 0.0 || aStdI == 0.0 ) {
188+ _aaBarValues(p) = 0.0
189+ } else {
190+ _aaBarValues(p) /= (aStdI * aStdJ)
191+ }
192+ p += 1
193+ i += 1
194+ }
195+ j += 1
196+ }
197+ _aaBar
198+ }
199+ val aaBarValues = aaBar.values
187200
188201 val effectiveRegParam = regParam / bStd
189202 val effectiveL1RegParam = elasticNetParam * effectiveRegParam
190203 val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam
191204
192205 // add L2 regularization to diagonals
193206 var i = 0
194- j = 2
207+ var j = 2
195208 while (i < triK) {
196209 var lambda = effectiveL2RegParam
197210 if (! standardizeFeatures) {
198- val std = aStd (j - 2 )
211+ val std = aStdValues (j - 2 )
199212 if (std != 0.0 ) {
200213 lambda /= (std * std)
201214 } else {
@@ -209,8 +222,9 @@ private[ml] class WeightedLeastSquares(
209222 i += j
210223 j += 1
211224 }
212- val aa = getAtA(aaBar.values, aBar.values)
213- val ab = getAtB(abBar.values, bBarStd)
225+
226+ val aa = getAtA(aaBarValues, aBarValues)
227+ val ab = getAtB(abBarValues, bBar)
214228
215229 val solver = if ((solverType == WeightedLeastSquares .Auto && elasticNetParam != 0.0 &&
216230 regParam != 0.0 ) || (solverType == WeightedLeastSquares .QuasiNewton )) {
@@ -237,22 +251,23 @@ private[ml] class WeightedLeastSquares(
237251 val solution = solver match {
238252 case cholesky : CholeskySolver =>
239253 try {
240- cholesky.solve(bBarStd, bbBarStd , ab, aa, aBar)
254+ cholesky.solve(bBar, bbBar , ab, aa, aBar)
241255 } catch {
242256 // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to
243- // quasi-newton solver
257+ // Quasi-Newton solver.
244258 case _ : SingularMatrixException if solverType == WeightedLeastSquares .Auto =>
245259 logWarning(" Cholesky solver failed due to singular covariance matrix. " +
246260 " Retrying with Quasi-Newton solver." )
247261 // ab and aa were modified in place, so reconstruct them
248- val _aa = getAtA(aaBar.values, aBar.values )
249- val _ab = getAtB(abBar.values, bBarStd )
262+ val _aa = getAtA(aaBarValues, aBarValues )
263+ val _ab = getAtB(abBarValues, bBar )
250264 val newSolver = new QuasiNewtonSolver (fitIntercept, maxIter, tol, None )
251- newSolver.solve(bBarStd, bbBarStd , _ab, _aa, aBar)
265+ newSolver.solve(bBar, bbBar , _ab, _aa, aBar)
252266 }
253267 case qn : QuasiNewtonSolver =>
254- qn.solve(bBarStd, bbBarStd , ab, aa, aBar)
268+ qn.solve(bBar, bbBar , ab, aa, aBar)
255269 }
270+
256271 val (coefficientArray, intercept) = if (fitIntercept) {
257272 (solution.coefficients.slice(0 , solution.coefficients.length - 1 ),
258273 solution.coefficients.last * bStd)
@@ -271,7 +286,11 @@ private[ml] class WeightedLeastSquares(
271286 // aaInv is a packed upper triangular matrix, here we get all elements on diagonal
272287 val diagInvAtWA = solution.aaInv.map { inv =>
273288 new DenseVector ((1 to k).map { i =>
274- val multiplier = if (i == k && fitIntercept) 1.0 else aStdValues(i - 1 ) * aStdValues(i - 1 )
289+ val multiplier = if (i == k && fitIntercept) {
290+ 1.0
291+ } else {
292+ aStdValues(i - 1 ) * aStdValues(i - 1 )
293+ }
275294 inv(i + (i - 1 ) * i / 2 - 1 ) / (wSum * multiplier)
276295 }.toArray)
277296 }.getOrElse(new DenseVector (Array (0D )))
@@ -280,7 +299,7 @@ private[ml] class WeightedLeastSquares(
280299 solution.objectiveHistory.getOrElse(Array (0D )))
281300 }
282301
283- /** Construct A^T^ A from summary statistics . */
302+ /** Construct A^T^ A (append bias if necessary). */
284303 private def getAtA (aaBar : Array [Double ], aBar : Array [Double ]): DenseVector = {
285304 if (fitIntercept) {
286305 new DenseVector (Array .concat(aaBar, aBar, Array (1.0 )))
@@ -289,7 +308,7 @@ private[ml] class WeightedLeastSquares(
289308 }
290309 }
291310
292- /** Construct A^T^ b from summary statistics . */
311+ /** Construct A^T^ b (append bias if necessary). */
293312 private def getAtB (abBar : Array [Double ], bBar : Double ): DenseVector = {
294313 if (fitIntercept) {
295314 new DenseVector (Array .concat(abBar, Array (bBar)))
0 commit comments