@@ -171,14 +171,25 @@ object LBFGS extends Logging {
171171 val miniBatchSize = numExamples * miniBatchFraction
172172
173173 val costFun =
174- new CostFun(data, gradient, updater, regParam, miniBatchFraction, lossHistory, miniBatchSize)
174+ new CostFun(data, gradient, updater, regParam, miniBatchFraction, miniBatchSize)
175175
176176 val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
177177
178- val weights = Vectors.fromBreeze(
179- lbfgs.minimize(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector))
178+ val states = lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
180179
181- logInfo("LBFGS.runMiniBatchSGD finished. Last 10 losses %s".format(
180+ /**
181+ * NOTE: lossSum and loss is computed using the weights from the previous iteration
182+ * and regVal is the regularization value computed in the previous iteration as well.
183+ */
184+ var state = states.next()
185+ while (states.hasNext) {
186+ lossHistory.append(state.value)
187+ state = states.next()
188+ }
189+ lossHistory.append(state.value)
190+ val weights = Vectors.fromBreeze(state.x)
191+
192+ logInfo("LBFGS.runMiniBatchLBFGS finished. Last 10 losses %s".format(
182193 lossHistory.takeRight(10).mkString(", ")))
183194
184195 (weights, lossHistory.toArray)
@@ -194,7 +205,6 @@ object LBFGS extends Logging {
194205 updater: Updater,
195206 regParam: Double,
196207 miniBatchFraction: Double,
197- lossHistory: ArrayBuffer[Double],
198208 miniBatchSize: Double) extends DiffFunction[BDV[Double]] {
199209
200210 private var i = 0
@@ -248,12 +258,6 @@ object LBFGS extends Logging {
248258 // gradientTotal = gradientSum / miniBatchSize + gradientTotal
249259 axpy(1.0 / miniBatchSize, gradientSum, gradientTotal)
250260
251- /**
252- * NOTE: lossSum and loss is computed using the weights from the previous iteration
253- * and regVal is the regularization value computed in the previous iteration as well.
254- */
255- lossHistory.append(loss)
256-
257261 i += 1
258262
259263 (loss, gradientTotal)
0 commit comments