From 2946930ec3de0e0a34e07d065c954d7aabacd4ba Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 18 Jul 2014 19:15:37 -0700 Subject: [PATCH] initial work --- .../classification/LogisticRegression.scala | 12 +- .../spark/mllib/classification/SVM.scala | 5 +- .../mllib/optimization/GradientDescent.scala | 86 ++++++----- .../spark/mllib/optimization/LBFGS.scala | 72 +++------ .../mllib/optimization/Regularizer.scala | 140 ++++++++++++++++++ .../spark/mllib/optimization/Updater.scala | 32 ++-- .../apache/spark/mllib/regression/Lasso.scala | 5 +- .../mllib/regression/LinearRegression.scala | 4 +- .../mllib/regression/RidgeRegression.scala | 5 +- .../optimization/GradientDescentSuite.scala | 15 +- .../spark/mllib/optimization/LBFGSSuite.scala | 27 ++-- 11 files changed, 261 insertions(+), 142 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/optimization/Regularizer.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 90aa8ac998ba..291b896dbb9c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -77,23 +77,21 @@ class LogisticRegressionModel private[mllib] ( class LogisticRegressionWithSGD private ( private var stepSize: Double, private var numIterations: Int, - private var regParam: Double, private var miniBatchFraction: Double) extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { private val gradient = new LogisticGradient() - private val updater = new SimpleUpdater() - override val optimizer = new GradientDescent(gradient, updater) + private val regularizer = new SimpleRegularizer() + override val optimizer = new GradientDescent(gradient, regularizer) .setStepSize(stepSize) .setNumIterations(numIterations) - .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) override protected val validators = List(DataValidators.binaryLabelValidator) /** * Construct a LogisticRegression object with default parameters */ - def this() = this(1.0, 100, 0.0, 1.0) + def this() = this(1.0, 100, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) @@ -128,7 +126,7 @@ object LogisticRegressionWithSGD { stepSize: Double, miniBatchFraction: Double, initialWeights: Vector): LogisticRegressionModel = { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) + new LogisticRegressionWithSGD(stepSize, numIterations, miniBatchFraction) .run(input, initialWeights) } @@ -149,7 +147,7 @@ object LogisticRegressionWithSGD { numIterations: Int, stepSize: Double, miniBatchFraction: Double): LogisticRegressionModel = { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) + new LogisticRegressionWithSGD(stepSize, numIterations, miniBatchFraction) .run(input) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 316ecd713b71..4110aecadc57 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -83,11 +83,10 @@ class SVMWithSGD private ( extends GeneralizedLinearAlgorithm[SVMModel] with Serializable { private val gradient = new HingeGradient() - private val 
updater = new SquaredL2Updater() - override val optimizer = new GradientDescent(gradient, updater) + private val regularizer = new L2Regularizer(regParam) + override val optimizer = new GradientDescent(gradient, regularizer) .setStepSize(stepSize) .setNumIterations(numIterations) - .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) override protected val validators = List(DataValidators.binaryLabelValidator) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 7030eeabe400..a8cb0a344e65 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV} +import breeze.linalg.{axpy => brzAxpy, DenseVector => BDV} import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging @@ -29,10 +29,11 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * Class used to solve an optimization problem using Gradient Descent. * @param gradient Gradient function to be used. - * @param updater Updater to be used to update weights after every iteration. + * @param regularizer Regularizer to be used for regularization. */ -class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater) - extends Optimizer with Logging { +class GradientDescent private[mllib] ( + private var gradient: Gradient, + private var regularizer: Regularizer) extends Optimizer with Logging { private var stepSize: Double = 1.0 private var numIterations: Int = 100 @@ -69,12 +70,27 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va /** * Set the regularization parameter. Default 0.0. + * This is deprecated, and the strength of regularization + * will be controlled by regularizer. */ + @Deprecated def setRegParam(regParam: Double): this.type = { this.regParam = regParam this } + /** + * Set the updater function to actually perform a gradient step in a given direction. + * The updater is responsible to perform the update from the regularization term as well, + * and therefore determines what kind or regularization is used, if any. + * This is deprecated, please use regularizer instead. + */ + @Deprecated + def setUpdater(updater: Updater): this.type = { + // this.updater = updater + this + } + /** * Set the gradient function (of the loss function of one single data example) * to be used for SGD. @@ -84,14 +100,11 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va this } - /** - * Set the updater function to actually perform a gradient step in a given direction. - * The updater is responsible to perform the update from the regularization term as well, - * and therefore determines what kind or regularization is used, if any. + * Set the regularizer object to perform the regularization. 
*/ - def setUpdater(updater: Updater): this.type = { - this.updater = updater + def setRegularizer(regularizer: Regularizer): this.type = { + this.regularizer = regularizer this } @@ -107,10 +120,9 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va val (weights, _) = GradientDescent.runMiniBatchSGD( data, gradient, - updater, + regularizer, stepSize, numIterations, - regParam, miniBatchFraction, initialWeights) weights @@ -124,6 +136,18 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va */ @DeveloperApi object GradientDescent extends Logging { + +// def runMiniBatchSGD( +// data: RDD[(Double, Vector)], +// gradient: Gradient, +// regularizer: Regularizer, +// stepSize: Double, +// numIterations: Int, +// regParam: Double, +// miniBatchFraction: Double, +// initialWeights: Vector): (Vector, Array[Double]) = { +// + /** * Run stochastic gradient descent (SGD) in parallel using mini batches. * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data @@ -135,10 +159,9 @@ object GradientDescent extends Logging { * the form (label, [feature values]). * @param gradient - Gradient object (used to compute the gradient of the loss function of * one single data example) - * @param updater - Updater function to actually perform a gradient step in a given direction. + * @param regularizer - Updater function to actually perform a gradient step in a given direction. * @param stepSize - initial step size for the first step * @param numIterations - number of iterations that SGD should be run. - * @param regParam - regularization parameter * @param miniBatchFraction - fraction of the input data set that should be used for * one iteration of SGD. Default value 1.0. * @@ -149,10 +172,9 @@ object GradientDescent extends Logging { def runMiniBatchSGD( data: RDD[(Double, Vector)], gradient: Gradient, - updater: Updater, + regularizer: Regularizer, stepSize: Double, numIterations: Int, - regParam: Double, miniBatchFraction: Double, initialWeights: Vector): (Vector, Array[Double]) = { @@ -162,42 +184,40 @@ object GradientDescent extends Logging { val miniBatchSize = numExamples * miniBatchFraction // Initialize weights as a column vector - var weights = Vectors.dense(initialWeights.toArray) - - /** - * For the first iteration, the regVal will be initialized as sum of weight squares - * if it's L2 updater; for L1 updater, the same logic is followed. 
- */ - var regVal = updater.compute( - weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + val brzWeights = new BDV[Double](initialWeights.toArray.clone()) for (i <- 1 to numIterations) { // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) - .aggregate((BDV.zeros[Double](weights.size), 0.0))( + .aggregate((BDV.zeros[Double](brzWeights.length), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => - val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad)) + val l = gradient.compute(features, label, + Vectors.fromBreeze(brzWeights), Vectors.fromBreeze(grad)) (grad, loss + l) }, combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => (grad1 += grad2, loss1 + loss2) }) + gradientSum :*= (1.0 / miniBatchSize) + + val regVal = regularizer.compute(Vectors.fromBreeze(brzWeights), + Vectors.fromBreeze(gradientSum)) + /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration - * and regVal is the regularization value computed in the previous iteration as well. + * lossSum is computed using the weights from the previous iteration, and regVal is + * the regularization value also computed with the weights from previous iteration. */ stochasticLossHistory.append(lossSum / miniBatchSize + regVal) - val update = updater.compute( - weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam) - weights = update._1 - regVal = update._2 + + val thisIterStepSize = stepSize / math.sqrt(i) + brzAxpy(-thisIterStepSize, gradientSum, brzWeights) } logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format( stochasticLossHistory.takeRight(10).mkString(", "))) - (weights, stochasticLossHistory.toArray) + (Vectors.fromBreeze(brzWeights), stochasticLossHistory.toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 7bbed9c8fdbe..eb1ecf0697cd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV, axpy} +import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} import org.apache.spark.annotation.DeveloperApi @@ -32,16 +32,16 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * Class used to solve an optimization problem using Limited-memory BFGS. * Reference: [[http://en.wikipedia.org/wiki/Limited-memory_BFGS]] * @param gradient Gradient function to be used. - * @param updater Updater to be used to update weights after every iteration. + * @param regularizer Regularizer to be used for regularization. */ @DeveloperApi -class LBFGS(private var gradient: Gradient, private var updater: Updater) - extends Optimizer with Logging { +class LBFGS private[mllib] ( + private var gradient: Gradient, + private var regularizer: Regularizer) extends Optimizer with Logging { private var numCorrections = 10 private var convergenceTol = 1E-4 private var maxNumIterations = 100 - private var regParam = 0.0 /** * Set the number of corrections used in the LBFGS update. 
Default 10. @@ -76,8 +76,9 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) /** * Set the regularization parameter. Default 0.0. */ + @Deprecated def setRegParam(regParam: Double): this.type = { - this.regParam = regParam + // this.regParam = regParam this } @@ -95,8 +96,9 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * The updater is responsible to perform the update from the regularization term as well, * and therefore determines what kind or regularization is used, if any. */ + @Deprecated def setUpdater(updater: Updater): this.type = { - this.updater = updater + //this.updater = updater this } @@ -104,11 +106,10 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) val (weights, _) = LBFGS.runLBFGS( data, gradient, - updater, + regularizer, numCorrections, convergenceTol, maxNumIterations, - regParam, initialWeights) weights } @@ -130,11 +131,10 @@ object LBFGS extends Logging { * the form (label, [feature values]). * @param gradient - Gradient object (used to compute the gradient of the loss function of * one single data example) - * @param updater - Updater function to actually perform a gradient step in a given direction. + * @param regularizer - Updater function to actually perform a gradient step in a given direction. * @param numCorrections - The number of corrections used in the L-BFGS update. * @param convergenceTol - The convergence tolerance of iterations for L-BFGS * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run. - * @param regParam - Regularization parameter * * @return A tuple containing two elements. The first element is a column matrix containing * weights for every feature, and the second element is an array containing the loss @@ -143,19 +143,17 @@ object LBFGS extends Logging { def runLBFGS( data: RDD[(Double, Vector)], gradient: Gradient, - updater: Updater, + regularizer: Regularizer, numCorrections: Int, convergenceTol: Double, maxNumIterations: Int, - regParam: Double, initialWeights: Vector): (Vector, Array[Double]) = { val lossHistory = new ArrayBuffer[Double](maxNumIterations) val numExamples = data.count() - val costFun = - new CostFun(data, gradient, updater, regParam, numExamples) + val costFun = new CostFun(data, gradient, regularizer, numExamples) val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol) @@ -187,8 +185,7 @@ object LBFGS extends Logging { private class CostFun( data: RDD[(Double, Vector)], gradient: Gradient, - updater: Updater, - regParam: Double, + regularizer: Regularizer, numExamples: Long) extends DiffFunction[BDV[Double]] { private var i = 0 @@ -198,7 +195,7 @@ object LBFGS extends Logging { val localData = data val localGradient = gradient - val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))( + var (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad)) @@ -208,42 +205,15 @@ object LBFGS extends Logging { (grad1 += grad2, loss1 + loss2) }) - /** - * regVal is sum of weight squares if it's L2 updater; - * for other updater, the same logic is followed. 
- */ - val regVal = updater.compute( - Vectors.fromBreeze(weights), - Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 - - val loss = lossSum / numExamples + regVal - /** - * It will return the gradient part of regularization using updater. - * - * Given the input parameters, the updater basically does the following, - * - * w' = w - thisIterStepSize * (gradient + regGradient(w)) - * Note that regGradient is function of w - * - * If we set gradient = 0, thisIterStepSize = 1, then - * - * regGradient(w) = w - w' - * - * TODO: We need to clean it up by separating the logic of regularization out - * from updater to regularizer. - */ - // The following gradientTotal is actually the regularization part of gradient. - // Will add the gradientSum computed from the data with weights in the next step. - val gradientTotal = weights - updater.compute( - Vectors.fromBreeze(weights), - Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze - - // gradientTotal = gradientSum / numExamples + gradientTotal - axpy(1.0 / numExamples, gradientSum, gradientTotal) + gradientSum :*= (1.0 / numExamples) + lossSum *= (1.0 / numExamples) + + lossSum += regularizer.compute(Vectors.fromBreeze(weights), + Vectors.fromBreeze(gradientSum)) i += 1 - (loss, gradientTotal) + (lossSum, gradientSum) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Regularizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Regularizer.scala new file mode 100644 index 000000000000..996d570bc34f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Regularizer.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.optimization + +import scala.collection.mutable.ListBuffer +import scala.math._ + +import breeze.linalg.{DenseVector => BDV, Vector => BV} + +import org.apache.spark.mllib.linalg.{Vectors, Vector} + +abstract class Regularizer extends Serializable { + var isSmooth: Boolean = true + + def add(that: Regularizer): CompositeRegularizer = { + (new CompositeRegularizer).add(this).add(that) + } + + def compute(weights: Vector, cumGradient: Vector): Double +} + +class SimpleRegularizer extends Regularizer { + isSmooth = true + + override def compute(weights: Vector, cumGradient: Vector): Double = 0 +} + +class CompositeRegularizer extends Regularizer { + isSmooth = true + + protected val regularizers = ListBuffer[Regularizer]() + + override def add(that: Regularizer): this.type = { + if (this.isSmooth && !that.isSmooth) isSmooth = false + regularizers.append(that) + this + } + + override def compute(weights: Vector, cumGradient: Vector): Double = { + if (regularizers.isEmpty) { + 0.0 + } else { + regularizers.foldLeft(0.0)((loss: Double, x: Regularizer) => + loss + x.compute(weights, cumGradient) + ) + } + } +} + +class L1Regularizer(private val regParam: BV[Double]) extends Regularizer { + isSmooth = false + + def this(regParam: Double) = this(new BDV[Double](Array[Double](regParam))) + + def this(regParam: Vector) = this(regParam.toBreeze) + + def compute(weights: Vector, cumGradient: Vector): Double = { + val brzWeights = weights.toBreeze + val brzCumGradient = cumGradient.toBreeze + + if (regParam.length > 1) require(brzWeights.length == regParam.length) + + if (regParam.length == 1 && regParam(0) == 0.0) { + 0.0 + } + else { + var loss: Double = 0.0 + brzWeights.activeIterator.foreach { + case (_, 0.0) => // Skip explicit zero elements. + case (i, value) => { + val lambda = if (regParam.length > 1) regParam(i) else regParam(0) + loss += lambda * Math.abs(value) + brzCumGradient(i) += lambda * signum(value) + } + } + loss + } + } +} + +class L2Regularizer(private val regParam: BV[Double]) extends Regularizer { + isSmooth = true + + def this(regParam: Double) = this(new BDV[Double](Array[Double](regParam))) + + def this(regParam: Vector) = this(regParam.toBreeze) + + def compute(weights: Vector, cumGradient: Vector): Double = { + val brzWeights = weights.toBreeze + val brzCumGradient = cumGradient.toBreeze + + if (regParam.length > 1) require(brzWeights.length == regParam.length) + + if (regParam.length == 1 && regParam(0) == 0) { + 0.0 + } + else { + var loss: Double = 0.0 + brzWeights.activeIterator.foreach { + case (_, 0.0) => // Skip explicit zero elements. 
+ case (i, value) => { + val lambda = if (regParam.length > 1) regParam(i) else regParam(0) + loss += lambda * value * value / 2.0 + brzCumGradient(i) += lambda * value + } + } + loss + } + } +} + +class ElasticNetRegularizer(private val regParam: BV[Double], private val alpha: Double) + extends CompositeRegularizer { + + def this(regParam: Double, alpha: Double) = this(new BDV[Double](Array[Double](regParam)), alpha) + + def this(regParam: Vector, alpha: Double) = this(regParam.toBreeze, alpha) + + if (alpha != 0.0) { + this.add(new L2Regularizer(Vectors.fromBreeze(this.regParam * alpha))) + } + if (alpha != 1.0) { + this.add(new L1Regularizer(Vectors.fromBreeze(this.regParam * (1.0 - alpha)))) + } +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 3ed3a5b9b384..183720a1b996 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -37,7 +37,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * The updater is responsible to also perform the update coming from the * regularization term R(w) (if any regularization is used). */ -@DeveloperApi +@Deprecated abstract class Updater extends Serializable { /** * Compute an updated value for weights given the gradient, stepSize, iteration number and @@ -66,7 +66,7 @@ abstract class Updater extends Serializable { * A simple updater for gradient descent *without* any regularization. * Uses a step-size decreasing with the square root of the number of iterations. */ -@DeveloperApi +@Deprecated class SimpleUpdater extends Updater { override def compute( weightsOld: Vector, @@ -74,8 +74,8 @@ class SimpleUpdater extends Updater { stepSize: Double, iter: Int, regParam: Double): (Vector, Double) = { - val thisIterStepSize = stepSize / math.sqrt(iter) val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + val thisIterStepSize = stepSize / math.sqrt(iter) brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) (Vectors.fromBreeze(brzWeights), 0) @@ -101,7 +101,7 @@ class SimpleUpdater extends Updater { * * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal) */ -@DeveloperApi +@Deprecated class L1Updater extends Updater { override def compute( weightsOld: Vector, @@ -132,25 +132,31 @@ class L1Updater extends Updater { * R(w) = 1/2 ||w||^2 * Uses a step-size decreasing with the square root of the number of iterations. 
*/ -@DeveloperApi +@Deprecated class SquaredL2Updater extends Updater { + private var currRegParam = 0.0 + private var regularizer: Regularizer = new SimpleRegularizer + // w' = w - thisIterStepSize * (gradient + regParam * w) override def compute( weightsOld: Vector, gradient: Vector, stepSize: Double, iter: Int, regParam: Double): (Vector, Double) = { - // add up both updates from the gradient of the loss (= step) as well as - // the gradient of the regularizer (= regParam * weightsOld) - // w' = w - thisIterStepSize * (gradient + regParam * w) - // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient - val thisIterStepSize = stepSize / math.sqrt(iter) + + if(currRegParam != regParam) { + currRegParam = regParam + regularizer = new L2Regularizer(regParam) + } + // gradient = gradient + regParam * w + val lossR = regularizer.compute(weightsOld, gradient) + + // w' = w - thisIterStepSize * gradient val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector - brzWeights :*= (1.0 - thisIterStepSize * regParam) + val thisIterStepSize = stepSize / math.sqrt(iter) brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) - val norm = brzNorm(brzWeights, 2.0) - (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm) + (Vectors.fromBreeze(brzWeights), lossR) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index a05dfc045fb8..ef6cca3b9872 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -58,11 +58,10 @@ class LassoWithSGD private ( extends GeneralizedLinearAlgorithm[LassoModel] with Serializable { private val gradient = new LeastSquaresGradient() - private val updater = new L1Updater() - override val optimizer = new GradientDescent(gradient, updater) + private val regularizer = new L1Regularizer(regParam) + override val optimizer = new GradientDescent(gradient, regularizer) .setStepSize(stepSize) .setNumIterations(numIterations) - .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 0ebad4eb58d8..565c4697c843 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -56,8 +56,8 @@ class LinearRegressionWithSGD private ( extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable { private val gradient = new LeastSquaresGradient() - private val updater = new SimpleUpdater() - override val optimizer = new GradientDescent(gradient, updater) + private val regularizer = new SimpleRegularizer() + override val optimizer = new GradientDescent(gradient, regularizer) .setStepSize(stepSize) .setNumIterations(numIterations) .setMiniBatchFraction(miniBatchFraction) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index bd983bac001a..c8d198036d3c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -58,12 +58,11 @@ class RidgeRegressionWithSGD private ( extends GeneralizedLinearAlgorithm[RidgeRegressionModel] 
with Serializable { private val gradient = new LeastSquaresGradient() - private val updater = new SquaredL2Updater() + private val regularizer = new L2Regularizer(regParam) - override val optimizer = new GradientDescent(gradient, updater) + override val optimizer = new GradientDescent(gradient, regularizer) .setStepSize(stepSize) .setNumIterations(numIterations) - .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 951b4f7c6e6f..6bc8fdf87f11 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -72,10 +72,9 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers val initialWeights = Array(initialB) val gradient = new LogisticGradient() - val updater = new SimpleUpdater() + val regularizer = new SimpleRegularizer() val stepSize = 1.0 val numIterations = 10 - val regParam = 0 val miniBatchFrac = 1.0 // Add a extra variable consisting of all 1.0's for the intercept. @@ -90,10 +89,9 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers val (_, loss) = GradientDescent.runMiniBatchSGD( dataRDD, gradient, - updater, + regularizer, stepSize, numIterations, - regParam, miniBatchFrac, initialWeightsWithIntercept) @@ -106,7 +104,6 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers test("Test the loss and gradient of first iteration with regularization.") { val gradient = new LogisticGradient() - val updater = new SquaredL2Updater() // Add a extra variable consisting of all 1.0's for the intercept. 
val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42) @@ -119,13 +116,13 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers // Prepare non-zero weights val initialWeightsWithIntercept = Vectors.dense(1.0, 0.5) - val regParam0 = 0 + val regularizer0 = new L2Regularizer(0.0) val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD( - dataRDD, gradient, updater, 1, 1, regParam0, 1.0, initialWeightsWithIntercept) + dataRDD, gradient, regularizer0, 1, 1, 1.0, initialWeightsWithIntercept) - val regParam1 = 1 + val regularizer1 = new L2Regularizer(1.0) val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD( - dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept) + dataRDD, gradient, regularizer1, 1, 1, 1.0, initialWeightsWithIntercept) def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = { math.abs(x - y) / (math.abs(y) + 1e-15) < tol diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index fe7a9033cd5f..377ac5a3fc90 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -62,11 +62,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val (_, loss) = LBFGS.runLBFGS( dataRDD, gradient, - simpleUpdater, + new SimpleRegularizer(), numCorrections, convergenceTol, maxNumIterations, - regParam, initialWeightsWithIntercept) // Since the cost function is convex, the loss is guaranteed to be monotonically decreasing @@ -80,10 +79,9 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val (_, lossGD) = GradientDescent.runMiniBatchSGD( dataRDD, gradient, - simpleUpdater, + new SimpleRegularizer(), stepSize, numGDIterations, - regParam, miniBatchFrac, initialWeightsWithIntercept) @@ -106,11 +104,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS( dataRDD, gradient, - squaredL2Updater, + new L2Regularizer(regParam), numCorrections, convergenceTol, maxNumIterations, - regParam, initialWeightsWithIntercept) val numGDIterations = 50 @@ -118,10 +115,9 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val (weightGD, lossGD) = GradientDescent.runMiniBatchSGD( dataRDD, gradient, - squaredL2Updater, + new L2Regularizer(regParam), stepSize, numGDIterations, - regParam, miniBatchFrac, initialWeightsWithIntercept) @@ -151,11 +147,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val (_, lossLBFGS1) = LBFGS.runLBFGS( dataRDD, gradient, - squaredL2Updater, + new L2Regularizer(regParam), numCorrections, convergenceTol, maxNumIterations, - regParam, initialWeightsWithIntercept) // Note that the first loss is computed with initial weights, @@ -166,11 +161,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val (_, lossLBFGS2) = LBFGS.runLBFGS( dataRDD, gradient, - squaredL2Updater, + new L2Regularizer(regParam), numCorrections, convergenceTol, maxNumIterations, - regParam, initialWeightsWithIntercept) // Based on observation, lossLBFGS2 runs 3 iterations, no theoretically guaranteed. 
@@ -181,11 +175,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val (_, lossLBFGS3) = LBFGS.runLBFGS( dataRDD, gradient, - squaredL2Updater, + new L2Regularizer(regParam), numCorrections, convergenceTol, maxNumIterations, - regParam, initialWeightsWithIntercept) // With smaller convergenceTol, it takes more steps. @@ -204,11 +197,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val convergenceTol = 1e-12 val maxNumIterations = 10 - val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater) + val lbfgsOptimizer = new LBFGS(gradient, new L2Regularizer(regParam)) .setNumCorrections(numCorrections) .setConvergenceTol(convergenceTol) .setMaxNumIterations(maxNumIterations) - .setRegParam(regParam) val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept) @@ -217,10 +209,9 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val (weightGD, _) = GradientDescent.runMiniBatchSGD( dataRDD, gradient, - squaredL2Updater, + new L2Regularizer(regParam), stepSize, numGDIterations, - regParam, miniBatchFrac, initialWeightsWithIntercept)
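The patch above introduces the Regularizer hierarchy and rewires GradientDescent and LBFGS to use it, but it contains no standalone usage example. Below is a minimal sketch (not part of the patch) of how the new API is exercised by the optimizers: compute() returns the regularization loss for the current weights and accumulates the regularization gradient into cumGradient, which GradientDescent and the L-BFGS CostFun then combine with the data gradient. The class names and the compute(weights, cumGradient) signature come from Regularizer.scala as added above; the driver object, the weight values, and the lambda/alpha values are made up purely for illustration.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.{ElasticNetRegularizer, L2Regularizer}

// Hypothetical driver, not part of the patch; assumes the patch has been applied.
object RegularizerSketch {
  def main(args: Array[String]): Unit = {
    val weights = Vectors.dense(0.5, -1.0, 0.0)

    // In GradientDescent/LBFGS this vector already holds the averaged data-loss gradient;
    // compute() adds the regularization gradient to it and returns the penalty term.
    val cumGradient1 = Vectors.dense(0.1, 0.2, 0.3)
    val l2 = new L2Regularizer(0.1)                 // single lambda applied to all coordinates
    val l2Loss = l2.compute(weights, cumGradient1)  // returns 0.5 * 0.1 * ||w||^2, adds 0.1 * w to cumGradient1

    // ElasticNetRegularizer is a CompositeRegularizer: per the patch, the L2 term is scaled by
    // alpha and the L1 term by (1 - alpha).
    val cumGradient2 = Vectors.dense(0.1, 0.2, 0.3)
    val enet = new ElasticNetRegularizer(0.1, 0.5)
    val enetLoss = enet.compute(weights, cumGradient2)

    println(s"L2 loss = $l2Loss, accumulated gradient = $cumGradient1")
    println(s"Elastic-net loss = $enetLoss, accumulated gradient = $cumGradient2")
  }
}

This mirrors how the rewritten GradientDescent.runMiniBatchSGD calls regularizer.compute on the averaged gradientSum before taking the step, and how the L-BFGS CostFun adds the returned penalty to the averaged data loss.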