Skip to content

Commit d72c679

Browse files
author
DB Tsai
committed
Using Breeze's states to get the loss.
1 parent cd4ed29 commit d72c679

File tree

1 file changed

+15
-11
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/optimization

1 file changed

+15
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,25 @@ object LBFGS extends Logging {
171171
val miniBatchSize = numExamples * miniBatchFraction
172172

173173
val costFun =
174-
new CostFun(data, gradient, updater, regParam, miniBatchFraction, lossHistory, miniBatchSize)
174+
new CostFun(data, gradient, updater, regParam, miniBatchFraction, miniBatchSize)
175175

176176
val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
177177

178-
val weights = Vectors.fromBreeze(
179-
lbfgs.minimize(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector))
178+
val states = lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
180179

181-
logInfo("LBFGS.runMiniBatchSGD finished. Last 10 losses %s".format(
180+
/**
181+
* NOTE: lossSum and loss is computed using the weights from the previous iteration
182+
* and regVal is the regularization value computed in the previous iteration as well.
183+
*/
184+
var state = states.next()
185+
while(states.hasNext) {
186+
lossHistory.append(state.value)
187+
state = states.next()
188+
}
189+
lossHistory.append(state.value)
190+
val weights = Vectors.fromBreeze(state.x)
191+
192+
logInfo("LBFGS.runMiniBatchLBFGS finished. Last 10 losses %s".format(
182193
lossHistory.takeRight(10).mkString(", ")))
183194

184195
(weights, lossHistory.toArray)
@@ -194,7 +205,6 @@ object LBFGS extends Logging {
194205
updater: Updater,
195206
regParam: Double,
196207
miniBatchFraction: Double,
197-
lossHistory: ArrayBuffer[Double],
198208
miniBatchSize: Double) extends DiffFunction[BDV[Double]] {
199209

200210
private var i = 0
@@ -248,12 +258,6 @@ object LBFGS extends Logging {
248258
// gradientTotal = gradientSum / miniBatchSize + gradientTotal
249259
axpy(1.0 / miniBatchSize, gradientSum, gradientTotal)
250260

251-
/**
252-
* NOTE: lossSum and loss is computed using the weights from the previous iteration
253-
* and regVal is the regularization value computed in the previous iteration as well.
254-
*/
255-
lossHistory.append(loss)
256-
257261
i += 1
258262

259263
(loss, gradientTotal)

0 commit comments

Comments
 (0)