
Commit e981396

use axpy in Updater
1 parent db808a1 commit e981396

File tree

2 files changed: +21 −10 lines changed

mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

Lines changed: 12 additions & 5 deletions
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.optimization
 
 import scala.collection.mutable.ArrayBuffer
 
+import breeze.linalg.{Vector => BV, DenseVector => BDV}
+
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
@@ -157,11 +159,16 @@ object GradientDescent extends Logging {
     for (i <- 1 to numIterations) {
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
-      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i).map {
-        case (y, features) =>
-          val (grad, loss) = gradient.compute(features, y, weights)
-          (grad.toBreeze, loss)
-      }.reduce((a, b) => (a._1 += b._1, a._2 + b._2))
+      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
+        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
+            val (g, l) = gradient.compute(features, label, weights)
+            (grad += g.toBreeze, loss + l)
+          },
+          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+            (grad1 += grad2, loss1 + loss2)
+          }
+        )
 
       /**
        * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
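Why this change: the old map + reduce materialized one Breeze gradient vector per sample before summing, while aggregate folds each partition into a single pre-allocated accumulator (the zero value BDV.zeros[Double](weights.size)) via an in-place seqOp, and combOp then merges the per-partition sums. Below is a minimal self-contained sketch of the same accumulate-in-place pattern on a local collection, with hypothetical data and a stand-in for Gradient.compute (Breeze only, no Spark):

import breeze.linalg.{DenseVector => BDV}

object AggregateSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical mini-batch of (label, features) pairs, dimension 3.
    val batch = Seq(
      (1.0, BDV(1.0, 0.0, 2.0)),
      (0.0, BDV(0.5, 1.0, 0.0)))

    // Stand-in for Gradient.compute: pretend gradient = features, loss = label.
    def compute(features: BDV[Double], label: Double): (BDV[Double], Double) =
      (features, label)

    // Same shape as the seqOp above: one mutable accumulator, updated in place.
    val zero = (BDV.zeros[Double](3), 0.0)
    val (gradientSum, lossSum) = batch.foldLeft(zero) {
      case ((grad, loss), (label, features)) =>
        val (g, l) = compute(features, label)
        (grad += g, loss + l) // += mutates grad; no per-record vector allocation
    }
    println(s"gradientSum = $gradientSum, lossSum = $lossSum")
    // On an RDD, a combOp such as (grad1 += grad2, loss1 + loss2) would
    // additionally merge the per-partition accumulators.
  }
}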

mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala

Lines changed: 9 additions & 5 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.optimization
 
 import scala.math._
 
-import breeze.linalg.{norm => brzNorm}
+import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV}
 
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 
@@ -70,7 +70,9 @@ class SimpleUpdater extends Updater {
       iter: Int,
       regParam: Double): (Vector, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
-    val brzWeights = weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize
+    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+    brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
+
     (Vectors.fromBreeze(brzWeights), 0)
   }
 }
@@ -102,7 +104,8 @@ class L1Updater extends Updater {
       regParam: Double): (Vector, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
     // Take gradient step
-    val brzWeights = weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize
+    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+    brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
     // Apply proximal operator (soft thresholding)
     val shrinkageVal = regParam * thisIterStepSize
     var i = 0
@@ -133,8 +136,9 @@ class SquaredL2Updater extends Updater {
     // w' = w - thisIterStepSize * (gradient + regParam * w)
     // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
     val thisIterStepSize = stepSize / math.sqrt(iter)
-    val brzWeights = weightsOld.toBreeze * (1.0 - thisIterStepSize * regParam) -
-      (gradient.toBreeze * thisIterStepSize)
+    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+    brzWeights :*= (1.0 - thisIterStepSize * regParam)
+    brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
     val norm = brzNorm(brzWeights, 2.0)
 
     (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm)
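Why axpy: brzAxpy(a, x, y) is Breeze's alias for the BLAS axpy primitive and computes y += a * x in place, so each updater now copies the old weights once (toDenseVector) and mutates that copy, instead of allocating a fresh temporary for every intermediate of an expression like weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize. A minimal sketch of the SimpleUpdater step outside Spark, with hypothetical numbers (Breeze only):

import breeze.linalg.{axpy => brzAxpy, DenseVector => BDV}

object AxpySketch {
  def main(args: Array[String]): Unit = {
    val weightsOld = BDV(0.5, -0.2, 1.0) // hypothetical current weights
    val gradient   = BDV(0.1, 0.3, -0.4) // hypothetical gradient
    val stepSize = 1.0
    val iter = 4
    val thisIterStepSize = stepSize / math.sqrt(iter) // = 0.5

    // Copy once, then update in place: brzWeights += (-step) * gradient.
    val brzWeights = weightsOld.copy
    brzAxpy(-thisIterStepSize, gradient, brzWeights)

    println(brzWeights) // DenseVector(0.45, -0.35, 1.2)
  }
}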
