From 984b18e21396eae84656e15da3539ff3b5f3bf4a Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 4 Apr 2014 17:06:50 -0700 Subject: [PATCH] L-BFGS Optimizer based on Breeze's implementation. Also fixed indentation issue in GradientDescent optimizer. --- .../mllib/optimization/GradientDescent.scala | 28 +- .../spark/mllib/optimization/LBFGS.scala | 263 ++++++++++++++++++ .../spark/mllib/optimization/LBFGSSuite.scala | 203 ++++++++++++++ 3 files changed, 480 insertions(+), 14 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index f60417f21d4b9..c75909bac9248 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -34,8 +34,8 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} */ @DeveloperApi class GradientDescent(private var gradient: Gradient, private var updater: Updater) - extends Optimizer with Logging -{ + extends Optimizer with Logging { + private var stepSize: Double = 1.0 private var numIterations: Int = 100 private var regParam: Double = 0.0 @@ -139,26 +139,26 @@ object GradientDescent extends Logging { * stochastic loss computed for every iteration. */ def runMiniBatchSGD( - data: RDD[(Double, Vector)], - gradient: Gradient, - updater: Updater, - stepSize: Double, - numIterations: Int, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Vector): (Vector, Array[Double]) = { + data: RDD[(Double, Vector)], + gradient: Gradient, + updater: Updater, + stepSize: Double, + numIterations: Int, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Vector): (Vector, Array[Double]) = { val stochasticLossHistory = new ArrayBuffer[Double](numIterations) - val nexamples: Long = data.count() - val miniBatchSize = nexamples * miniBatchFraction + val numExamples = data.count() + val miniBatchSize = numExamples * miniBatchFraction // Initialize weights as a column vector var weights = Vectors.dense(initialWeights.toArray) /** - * For the first iteration, the regVal will be initialized as sum of sqrt of - * weights if it's L2 update; for L1 update; the same logic is followed. + * For the first iteration, the regVal will be initialized as sum of weight squares + * if it's L2 updater; for L1 updater, the same logic is followed. */ var regVal = updater.compute( weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala new file mode 100644 index 0000000000000..969a0c5f7c953 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.optimization
+
+import scala.collection.mutable.ArrayBuffer
+
+import breeze.linalg.{DenseVector => BDV, axpy}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
+
+/**
+ * :: DeveloperApi ::
+ * Class used to solve an optimization problem using Limited-memory BFGS.
+ * Reference: [[http://en.wikipedia.org/wiki/Limited-memory_BFGS]]
+ * @param gradient Gradient function to be used.
+ * @param updater Updater to be used to update weights after every iteration.
+ */
+@DeveloperApi
+class LBFGS(private var gradient: Gradient, private var updater: Updater)
+  extends Optimizer with Logging {
+
+  private var numCorrections = 10
+  private var convergenceTol = 1E-4
+  private var maxNumIterations = 100
+  private var regParam = 0.0
+  private var miniBatchFraction = 1.0
+
+  /**
+   * Set the number of corrections used in the LBFGS update. Default 10.
+   * Values of numCorrections less than 3 are not recommended; large values
+   * of numCorrections will result in excessive computing time.
+   * 3 < numCorrections < 10 is recommended.
+   * Restriction: numCorrections > 0
+   */
+  def setNumCorrections(corrections: Int): this.type = {
+    assert(corrections > 0)
+    this.numCorrections = corrections
+    this
+  }
+
+  /**
+   * Set fraction of data to be used for each L-BFGS iteration. Default 1.0.
+   */
+  def setMiniBatchFraction(fraction: Double): this.type = {
+    this.miniBatchFraction = fraction
+    this
+  }
+
+  /**
+   * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
+   * A smaller value will lead to higher accuracy at the cost of more iterations.
+   */
+  def setConvergenceTol(tolerance: Double): this.type = {
+    this.convergenceTol = tolerance
+    this
+  }
+
+  /**
+   * Set the maximum number of iterations for L-BFGS. Default 100.
+   */
+  def setMaxNumIterations(iters: Int): this.type = {
+    this.maxNumIterations = iters
+    this
+  }
+
+  /**
+   * Set the regularization parameter. Default 0.0.
+   */
+  def setRegParam(regParam: Double): this.type = {
+    this.regParam = regParam
+    this
+  }
+
+  /**
+   * Set the gradient function (of the loss function of one single data example)
+   * to be used for L-BFGS.
+   */
+  def setGradient(gradient: Gradient): this.type = {
+    this.gradient = gradient
+    this
+  }
+
+  /**
+   * Set the updater function to actually perform a gradient step in a given direction.
+   * The updater is responsible for performing the update from the regularization term as well,
+   * and therefore determines what kind of regularization is used, if any.
+   */
+  def setUpdater(updater: Updater): this.type = {
+    this.updater = updater
+    this
+  }
+
+  override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
+    val (weights, _) = LBFGS.runMiniBatchLBFGS(
+      data,
+      gradient,
+      updater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFraction,
+      initialWeights)
+    weights
+  }
+
+}
+
+/**
+ * :: DeveloperApi ::
+ * Top-level method to run L-BFGS.
+ */
+@DeveloperApi
+object LBFGS extends Logging {
+  /**
+   * Run Limited-memory BFGS (L-BFGS) in parallel using mini batches.
+   * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data
+   * in order to compute a gradient estimate.
+   * Sampling and averaging the subgradients over this subset are performed using one standard
+   * Spark map-reduce in each iteration.
+   *
+   * @param data - Input data for L-BFGS. RDD of the set of data examples, each of
+   *               the form (label, [feature values]).
+   * @param gradient - Gradient object (used to compute the gradient of the loss function of
+   *                   one single data example)
+   * @param updater - Updater function to actually perform a gradient step in a given direction.
+   * @param numCorrections - The number of corrections used in the L-BFGS update.
+   * @param convergenceTol - The convergence tolerance of iterations for L-BFGS
+   * @param maxNumIterations - Maximum number of iterations that L-BFGS will run.
+   * @param regParam - Regularization parameter
+   * @param miniBatchFraction - Fraction of the input data set that should be used for
+   *                            one iteration of L-BFGS. Default value 1.0.
+   *
+   * @return A tuple containing two elements. The first element is a column matrix containing
+   *         weights for every feature, and the second element is an array containing the loss
+   *         computed for every iteration.
+   */
+  def runMiniBatchLBFGS(
+      data: RDD[(Double, Vector)],
+      gradient: Gradient,
+      updater: Updater,
+      numCorrections: Int,
+      convergenceTol: Double,
+      maxNumIterations: Int,
+      regParam: Double,
+      miniBatchFraction: Double,
+      initialWeights: Vector): (Vector, Array[Double]) = {
+
+    val lossHistory = new ArrayBuffer[Double](maxNumIterations)
+
+    val numExamples = data.count()
+    val miniBatchSize = numExamples * miniBatchFraction
+
+    val costFun =
+      new CostFun(data, gradient, updater, regParam, miniBatchFraction, lossHistory, miniBatchSize)
+
+    val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
+
+    val weights = Vectors.fromBreeze(
+      lbfgs.minimize(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector))
+
+    logInfo("LBFGS.runMiniBatchLBFGS finished. Last 10 losses %s".format(
+      lossHistory.takeRight(10).mkString(", ")))
+
+    (weights, lossHistory.toArray)
+  }
+
+  /**
+   * CostFun implements Breeze's DiffFunction[T], which returns the loss and gradient
+   * at a particular point (weights). It's used in Breeze's convex optimization routines.
+   */
+  private class CostFun(
+      data: RDD[(Double, Vector)],
+      gradient: Gradient,
+      updater: Updater,
+      regParam: Double,
+      miniBatchFraction: Double,
+      lossHistory: ArrayBuffer[Double],
+      miniBatchSize: Double) extends DiffFunction[BDV[Double]] {
+
+    private var i = 0
+
+    override def calculate(weights: BDV[Double]) = {
+      // Take local copies to avoid serializing the CostFun object, which is not serializable.
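+      // Referencing `data` and `gradient` through local vals keeps the task closure from
+      // capturing `this`: CostFun does not extend Serializable, so shipping it to the
+      // executors would fail.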
+      val localData = data
+      val localGradient = gradient
+
+      val (gradientSum, lossSum) = localData.sample(false, miniBatchFraction, 42 + i)
+        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
+            val l = localGradient.compute(
+              features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
+            (grad, loss + l)
+          },
+          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+            (grad1 += grad2, loss1 + loss2)
+          })
+
+      /**
+       * regVal is the sum of the squared weights if it's the L2 updater;
+       * other updaters follow the same logic.
+       */
+      val regVal = updater.compute(
+        Vectors.fromBreeze(weights),
+        Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
+
+      val loss = lossSum / miniBatchSize + regVal
+      /**
+       * This computes the gradient of the regularization term using the updater.
+       *
+       * Given the input parameters, the updater basically does the following,
+       *
+       * w' = w - thisIterStepSize * (gradient + regGradient(w))
+       * Note that regGradient is a function of w.
+       *
+       * If we set gradient = 0, thisIterStepSize = 1, then
+       *
+       * regGradient(w) = w - w'
+       *
+       * TODO: We need to clean this up by separating the regularization logic out
+       * of the updater into a regularizer.
+       */
+      // The following gradientTotal is actually the regularization part of the gradient.
+      // The gradientSum computed from the data will be added to it in the next step.
+      val gradientTotal = weights - updater.compute(
+        Vectors.fromBreeze(weights),
+        Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze
+
+      // gradientTotal = gradientSum / miniBatchSize + gradientTotal
+      axpy(1.0 / miniBatchSize, gradientSum, gradientTotal)
+
+      /**
+       * NOTE: lossSum and loss are computed using the weights from the previous iteration,
+       * and regVal is the regularization value computed in the previous iteration as well.
+       */
+      lossHistory.append(loss)
+
+      i += 1
+
+      (loss, gradientTotal)
+    }
+  }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
new file mode 100644
index 0000000000000..f33770aed30bd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.optimization
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+  val nPoints = 10000
+  val A = 2.0
+  val B = -1.5
+
+  val initialB = -1.0
+  val initialWeights = Array(initialB)
+
+  val gradient = new LogisticGradient()
+  val numCorrections = 10
+  val miniBatchFrac = 1.0
+
+  val simpleUpdater = new SimpleUpdater()
+  val squaredL2Updater = new SquaredL2Updater()
+
+  // Add an extra variable consisting of all 1.0's for the intercept.
+  val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
+  val data = testData.map { case LabeledPoint(label, features) =>
+    label -> Vectors.dense(1.0, features.toArray: _*)
+  }
+
+  lazy val dataRDD = sc.parallelize(data, 2).cache()
+
+  def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
+    math.abs(x - y) / (math.abs(y) + 1e-15) < tol
+  }
+
+  test("LBFGS loss should be decreasing and match the result of Gradient Descent.") {
+    val regParam = 0
+
+    val initialWeightsWithIntercept = Vectors.dense(1.0, initialWeights: _*)
+    val convergenceTol = 1e-12
+    val maxNumIterations = 10
+
+    val (_, loss) = LBFGS.runMiniBatchLBFGS(
+      dataRDD,
+      gradient,
+      simpleUpdater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // Since the cost function is convex, the loss is guaranteed to decrease monotonically
+    // with the L-BFGS optimizer.
+    // (SGD doesn't guarantee this; its loss fluctuates during optimization.)
+    assert((loss, loss.tail).zipped.forall(_ > _), "loss should be monotonically decreasing.")
+
+    val stepSize = 1.0
+    // GD converges more slowly, so it requires more iterations.
+    val numGDIterations = 50
+    val (_, lossGD) = GradientDescent.runMiniBatchSGD(
+      dataRDD,
+      gradient,
+      simpleUpdater,
+      stepSize,
+      numGDIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // GD converges much more slowly than L-BFGS. To get within a 1% difference,
+    // GD requires about 90 iterations. No matter how much we increase
+    // the number of GD iterations here, lossGD will always be
+    // larger than lossLBFGS. This is based on observation, not theoretically guaranteed.
+    assert(Math.abs((lossGD.last - loss.last) / loss.last) < 0.02,
+      "LBFGS should match GD result within 2% difference.")
+  }
+
+  test("LBFGS and Gradient Descent with L2 regularization should get the same result.") {
+    val regParam = 0.2
+
+    // Prepare non-zero initial weights to compare the loss in the first iteration.
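+    // (With all-zero weights the L2 regularization term would be zero; non-zero initial
+    // weights also exercise the regVal computation in both optimizers.)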
+    val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12)
+    val convergenceTol = 1e-12
+    val maxNumIterations = 10
+
+    val (weightLBFGS, lossLBFGS) = LBFGS.runMiniBatchLBFGS(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    val numGDIterations = 50
+    val stepSize = 1.0
+    val (weightGD, lossGD) = GradientDescent.runMiniBatchSGD(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      stepSize,
+      numGDIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    assert(compareDouble(lossGD(0), lossLBFGS(0)),
+      "The first losses of LBFGS and GD should be the same.")
+
+    // The 2% difference here is based on observation, but is not theoretically guaranteed.
+    assert(compareDouble(lossGD.last, lossLBFGS.last, 0.02),
+      "The last losses of LBFGS and GD should be within 2% difference.")
+
+    assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) &&
+      compareDouble(weightLBFGS(1), weightGD(1), 0.02),
+      "The weight differences between LBFGS and GD should be within 2%.")
+  }
+
+  test("The convergence criteria should work as we expect.") {
+    val regParam = 0.0
+
+    /**
+     * For the first run, we set the convergenceTol to 0.0, so that the algorithm
+     * runs up to maxNumIterations, which is 8 here.
+     */
+    val initialWeightsWithIntercept = Vectors.dense(0.0, 0.0)
+    val maxNumIterations = 8
+    var convergenceTol = 0.0
+
+    val (_, lossLBFGS1) = LBFGS.runMiniBatchLBFGS(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // Note that the first loss is computed with the initial weights,
+    // so the total number of losses will be the number of iterations + 1.
+    assert(lossLBFGS1.length == 9)
+
+    convergenceTol = 0.1
+    val (_, lossLBFGS2) = LBFGS.runMiniBatchLBFGS(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // Based on observation, lossLBFGS2 runs for 3 iterations; this is not theoretically guaranteed.
+    assert(lossLBFGS2.length == 4)
+    assert((lossLBFGS2(2) - lossLBFGS2(3)) / lossLBFGS2(2) < convergenceTol)
+
+    convergenceTol = 0.01
+    val (_, lossLBFGS3) = LBFGS.runMiniBatchLBFGS(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // With a smaller convergenceTol, it takes more steps.
+    assert(lossLBFGS3.length > lossLBFGS2.length)
+
+    // Based on observation, lossLBFGS3 runs for 5 iterations; this is not theoretically guaranteed.
+    assert(lossLBFGS3.length == 6)
+    assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol)
+  }
+}
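
Usage note (not part of the patch): the sketch below shows one way the new optimizer could be wired up against the LogisticGradient and SquaredL2Updater that the test suite above already uses. The SparkContext setup, the toy data, and the parameter values are illustrative assumptions rather than code from this change.

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}

    object LBFGSExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "LBFGSExample")

        // (label, features) pairs; the intercept is folded in as a leading 1.0,
        // mirroring what LBFGSSuite does.
        val data = sc.parallelize(Seq(
          (1.0, Vectors.dense(1.0, 1.2)),
          (0.0, Vectors.dense(1.0, -0.5)),
          (1.0, Vectors.dense(1.0, 2.3)),
          (0.0, Vectors.dense(1.0, -1.1)))).cache()

        // Chain the setters added in this patch, then optimize from a zero starting point.
        val optimizer = new LBFGS(new LogisticGradient(), new SquaredL2Updater())
          .setNumCorrections(10)
          .setConvergenceTol(1e-4)
          .setMaxNumIterations(20)
          .setRegParam(0.1)

        val initialWeights = Vectors.dense(0.0, 0.0)
        val weights = optimizer.optimize(data, initialWeights)
        println("Learned weights: " + weights)

        sc.stop()
      }
    }

Equivalently, LBFGS.runMiniBatchLBFGS can be called directly when the per-iteration loss history is needed, as the tests above do.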