Skip to content

Commit cecf02c

Browse files
committed
Changed `setConvergenceTol` to specify tolerance with a parameter of type Double. For the reason and the problem caused by an Int parameter, please check https://issues.apache.org/jira/browse/SPARK-2163. Added a test in LBFGSSuite for validating that optimizing via class LBFGS produces the same results as calling runLBFGS from object LBFGS. Keep the indentations and styles correct.
1 parent 273afcb commit cecf02c

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
/**
 * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
 * Smaller value will lead to higher accuracy with the cost of more iterations.
 *
 * The parameter must be a Double, not an Int: an Int parameter would truncate
 * fractional tolerances such as 1e-4 down to 0, effectively disabling the
 * convergence check (see SPARK-2163).
 *
 * @param tolerance relative convergence tolerance; must be positive
 * @return this optimizer, for fluent chaining
 */
def setConvergenceTol(tolerance: Double): this.type = {
  this.convergenceTol = tolerance
  this
}

mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,4 +195,38 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
195195
assert(lossLBFGS3.length == 6)
196196
assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol)
197197
}
198+
199+
test("Optimize via class LBFGS.") {
  val l2RegParam = 0.2

  // Start from non-zero weights so the loss in the first iteration is comparable.
  val initialWeights = Vectors.dense(0.3, 0.12)
  val tol = 1e-12
  val lbfgsMaxIterations = 10

  val optimizer = new LBFGS(gradient, squaredL2Updater)
    .setNumCorrections(numCorrections)
    .setConvergenceTol(tol)
    .setMaxNumIterations(lbfgsMaxIterations)
    .setRegParam(l2RegParam)

  val lbfgsWeights = optimizer.optimize(dataRDD, initialWeights)

  // Reference solution computed with plain mini-batch gradient descent.
  val gdIterations = 50
  val gdStepSize = 1.0
  val (gdWeights, _) = GradientDescent.runMiniBatchSGD(
    dataRDD,
    gradient,
    squaredL2Updater,
    gdStepSize,
    gdIterations,
    l2RegParam,
    miniBatchFrac,
    initialWeights)

  // For the class-based LBFGS API (optimize method) only the final weights are checked.
  assert(compareDouble(lbfgsWeights(0), gdWeights(0), 0.02) &&
    compareDouble(lbfgsWeights(1), gdWeights(1), 0.02),
    "The weight differences between LBFGS and GD should be within 2%.")
}
198232
}

0 commit comments

Comments
 (0)