@@ -19,7 +19,7 @@ package org.apache.spark.mllib.optimization
1919
2020import scala .math ._
2121
22- import breeze .linalg .{norm => brzNorm }
22+ import breeze .linalg .{norm => brzNorm , axpy => brzAxpy , Vector => BV }
2323
2424import org .apache .spark .mllib .linalg .{Vectors , Vector }
2525
@@ -70,7 +70,9 @@ class SimpleUpdater extends Updater {
7070 iter : Int ,
7171 regParam : Double ): (Vector , Double ) = {
7272 val thisIterStepSize = stepSize / math.sqrt(iter)
73- val brzWeights = weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize
73+ val brzWeights : BV [Double ] = weightsOld.toBreeze.toDenseVector
74+ brzAxpy(- thisIterStepSize, gradient.toBreeze, brzWeights)
75+
7476 (Vectors .fromBreeze(brzWeights), 0 )
7577 }
7678}
@@ -102,7 +104,8 @@ class L1Updater extends Updater {
102104 regParam : Double ): (Vector , Double ) = {
103105 val thisIterStepSize = stepSize / math.sqrt(iter)
104106 // Take gradient step
105- val brzWeights = weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize
107+ val brzWeights : BV [Double ] = weightsOld.toBreeze.toDenseVector
108+ brzAxpy(- thisIterStepSize, gradient.toBreeze, brzWeights)
106109 // Apply proximal operator (soft thresholding)
107110 val shrinkageVal = regParam * thisIterStepSize
108111 var i = 0
@@ -133,8 +136,9 @@ class SquaredL2Updater extends Updater {
133136 // w' = w - thisIterStepSize * (gradient + regParam * w)
134137 // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
135138 val thisIterStepSize = stepSize / math.sqrt(iter)
136- val brzWeights = weightsOld.toBreeze * (1.0 - thisIterStepSize * regParam) -
137- (gradient.toBreeze * thisIterStepSize)
139+ val brzWeights : BV [Double ] = weightsOld.toBreeze.toDenseVector
140+ brzWeights :*= (1.0 - thisIterStepSize * regParam)
141+ brzAxpy(- thisIterStepSize, gradient.toBreeze, brzWeights)
138142 val norm = brzNorm(brzWeights, 2.0 )
139143
140144 (Vectors .fromBreeze(brzWeights), 0.5 * regParam * norm * norm)
0 commit comments