[SPARK-20602][ML] Adding LBFGS optimizer and Squared_hinge loss for LinearSVC #17862
LinearSVC.scala:

```diff
@@ -17,18 +17,20 @@
 package org.apache.spark.ml.classification

 import java.util.Locale

 import scala.collection.mutable

 import breeze.linalg.{DenseVector => BDV}
-import breeze.optimize.{CachedDiffFunction, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
 import org.apache.hadoop.fs.Path

 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg._
-import org.apache.spark.ml.optim.aggregator.HingeAggregator
+import org.apache.spark.ml.optim.aggregator.{HingeAggregator, SquaredHingeAggregator}
 import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
```
```diff
@@ -42,7 +44,26 @@ import org.apache.spark.sql.functions.{col, lit}
 /** Params for linear SVM Classifier. */
 private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
   with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
-  with HasAggregationDepth with HasThreshold {
+  with HasAggregationDepth with HasThreshold with HasSolver {
+
+  /**
+   * Specifies the loss function. Currently "hinge" and "squared_hinge" are supported.
+   * "hinge" is the standard SVM loss (a.k.a. L1 loss) while "squared_hinge" is the square of
+   * the hinge loss (a.k.a. L2 loss).
+   *
+   * @see <a href="https://en.wikipedia.org/wiki/Hinge_loss">Hinge loss (Wikipedia)</a>
+   *
+   * @group param
+   */
+  @Since("2.3.0")
+  final val loss: Param[String] = new Param(this, "loss", "Specifies the loss " +
+    "function. hinge is the standard SVM loss while squared_hinge is the square of the hinge loss.",
+    (s: String) => LinearSVC.supportedLoss.contains(s.toLowerCase(Locale.ROOT)))
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getLoss: String = $(loss)

   /**
    * Param for threshold in binary classification prediction.
```
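For reference, these are the standard formulations the doc comment alludes to (consistent with the Wikipedia page it links): with labels mapped to $y \in \{-1, +1\}$ and margin $z = y \cdot f_w(x)$,

$$
L_{\text{hinge}}(z) = \max(0,\ 1 - z), \qquad
L_{\text{squared\_hinge}}(z) = \bigl(\max(0,\ 1 - z)\bigr)^2 .
$$

One practical distinction (not stated in the patch itself) is that the squared variant is differentiable even at the hinge point $z = 1$, which suits a smooth optimizer like L-BFGS.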
```diff
@@ -63,8 +84,11 @@ private[classification] trait LinearSVCParams extends ClassifierParams with HasR
  * <a href = "https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM">
  * Linear SVM Classifier</a>
  *
- * This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
- * Only supports L2 regularization currently.
+ * This binary classifier implements a linear SVM. Currently "hinge" and
+ * "squared_hinge" loss functions are supported. "hinge" is the standard SVM loss (a.k.a. L1 loss)
+ * while "squared_hinge" is the square of the hinge loss (a.k.a. L2 loss). Both LBFGS and OWL-QN
+ * optimizers are supported and can be specified by setting the solver param.
+ * By default, L2 SVM (squared hinge loss) and the L-BFGS optimizer are used.
  *
  */
 @Since("2.2.0")
```
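The new class doc implies usage along these lines. A minimal sketch, assuming an existing training DataFrame `training` with `label` and `features` columns (the dataset name is a placeholder, not from the patch):

```scala
import org.apache.spark.ml.classification.LinearSVC

// Defaults after this patch: loss = "squared_hinge", solver = "l-bfgs".
val svc = new LinearSVC()
  .setMaxIter(100)
  .setRegParam(0.1)
  .setLoss("hinge")    // switch back to the standard hinge (L1) loss
  .setSolver("owlqn")  // solver names are matched case-insensitively

val model = svc.fit(training)
println(s"coefficients: ${model.coefficients}, intercept: ${model.intercept}")
```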
```diff
@@ -74,6 +98,8 @@ class LinearSVC @Since("2.2.0") (
   extends Classifier[Vector, LinearSVC, LinearSVCModel]
   with LinearSVCParams with DefaultParamsWritable {

+  import LinearSVC._
+
   @Since("2.2.0")
   def this() = this(Identifiable.randomUID("linearsvc"))
```
```diff
@@ -159,6 +185,31 @@ class LinearSVC @Since("2.2.0") (
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)

+  /**
+   * Set the loss function. Default is "squared_hinge".
+   *
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setLoss(value: String): this.type = set(loss, value)
+  setDefault(loss -> SQUARED_HINGE)
+
+  /**
+   * Set solver for LinearSVC. Supported options: "l-bfgs" and "owlqn" (case insensitive).
+   *  - "l-bfgs" denotes Limited-memory BFGS, which is a limited-memory quasi-Newton
+   *    optimization method.
+   *  - "owlqn" denotes the Orthant-Wise Limited-memory Quasi-Newton algorithm.
+   * (default: "l-bfgs")
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setSolver(value: String): this.type = {
+    require(supportedSolvers.contains(value.toLowerCase(Locale.ROOT)), s"Solver $value was" +
+      s" not supported. Supported options: ${supportedSolvers.mkString(", ")}")
+    set(solver, value)
+  }
+  setDefault(solver -> LBFGS)

   @Since("2.2.0")
   override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
```
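Given the `require` guard, invalid solver names fail fast at set time rather than at fit time. A small illustrative sketch (`"newton"` is a made-up invalid input):

```scala
val svc = new LinearSVC()
svc.setSolver("L-BFGS")  // accepted: the check lower-cases the value with Locale.ROOT
svc.setSolver("newton")  // throws IllegalArgumentException from require(...)
```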
```diff
@@ -225,12 +276,27 @@ class LinearSVC @Since("2.2.0") (
       None
     }

-    val getAggregatorFunc = new HingeAggregator(bcFeaturesStd, $(fitIntercept))(_)
-    val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization,
-      $(aggregationDepth))
+    val costFun = $(loss) match {
+      case HINGE =>
+        val getAggregatorFunc = new HingeAggregator(bcFeaturesStd, $(fitIntercept))(_)
+        new RDDLossFunction(instances, getAggregatorFunc, regularization,
+          $(aggregationDepth))
+      case SQUARED_HINGE =>
+        val getAggregatorFunc = new SquaredHingeAggregator(bcFeaturesStd, $(fitIntercept))(_)
+        new RDDLossFunction(instances, getAggregatorFunc, regularization,
+          $(aggregationDepth))
+      case unexpected => throw new SparkException(
+        s"unexpected loss function in LinearSVC: $unexpected")
+    }
+
+    val optimizer = $(solver).toLowerCase(Locale.ROOT) match {
+      case LBFGS => new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
+      case OWLQN =>
+        def regParamL1Fun = (index: Int) => 0D
+        new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
+      case _ => throw new SparkException("unexpected solver: " + $(solver))
+    }

-    def regParamL1Fun = (index: Int) => 0D
-    val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
     val initialCoefWithIntercept = Vectors.zeros(numFeaturesPlusIntercept)

     val states = optimizer.iterations(new CachedDiffFunction(costFun),
```
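The Breeze calls above follow the usual `iterations` pattern: construct the optimizer with (max iterations, history size m = 10, tolerance), then drive an iterator of states over a cached cost function (`CachedDiffFunction` memoizes the last value/gradient so line-search re-evaluations at the same point are free). A standalone sketch of that pattern, minimizing a toy quadratic rather than the real SVM objective (the objective below is illustrative, not from the patch):

```scala
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}

// Toy objective: f(x) = ||x - target||^2, with gradient 2 * (x - target).
val target = BDV(3.0, 3.0)
val f = new DiffFunction[BDV[Double]] {
  override def calculate(x: BDV[Double]): (Double, BDV[Double]) = {
    val diff = x - target
    (diff.dot(diff), diff * 2.0)
  }
}

// Same constructor shape as the patch: (maxIter, m = 10, tol).
val lbfgs = new BreezeLBFGS[BDV[Double]](100, 10, 1e-6)
val finalState = lbfgs
  .iterations(new CachedDiffFunction(f), BDV.zeros[Double](2))
  .reduceLeft((_, s) => s)  // drain the iterator, keep the last state
println(finalState.x)       // ~ DenseVector(3.0, 3.0)
```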
```diff
@@ -282,8 +348,27 @@ class LinearSVC @Since("2.2.0") (
 @Since("2.2.0")
 object LinearSVC extends DefaultParamsReadable[LinearSVC] {

+  /** String name for Limited-memory BFGS. */
+  private[classification] val LBFGS: String = "l-bfgs".toLowerCase(Locale.ROOT)
+
+  /** String name for Orthant-Wise Limited-memory Quasi-Newton. */
+  private[classification] val OWLQN: String = "owlqn".toLowerCase(Locale.ROOT)
+
+  /* Set of optimizers that LinearSVC supports */
+  private[classification] val supportedSolvers = Array(LBFGS, OWLQN)
+
+  /** String name for Hinge Loss. */
+  private[classification] val HINGE: String = "hinge".toLowerCase(Locale.ROOT)
+
+  /** String name for Squared Hinge Loss. */
+  private[classification] val SQUARED_HINGE: String = "squared_hinge".toLowerCase(Locale.ROOT)
```
Contributor: Why is `toLowerCase(Locale.ROOT)` needed on these string constants?

Contributor (Author): To ensure consistency with param validation across all Locales.

Contributor: Ditto. IMO these characters are all ASCII, so I think they won't run into locale issues. (But have you hit such an issue in some environment?)

Contributor (Author): Personally, never, but I cannot guarantee it for all Locales.
```diff
+
+  /* Set of loss functions that LinearSVC supports */
+  private[classification] val supportedLoss = Array(HINGE, SQUARED_HINGE)
```
Contributor: supportedLoss ==> supportedLosses

Contributor (Author): Sure, I can update it.
```diff

   @Since("2.2.0")
   override def load(path: String): LinearSVC = super.load(path)

 }

 /**
```
SquaredHingeAggregator.scala (new file, @@ -0,0 +1,107 @@):

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.optim.aggregator

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._

/**
 * SquaredHingeAggregator computes the gradient and loss for the squared hinge loss function,
 * as used in binary classification for instances in sparse or dense vectors, in an online
 * fashion.
 *
 * Two SquaredHingeAggregators can be merged together to have a summary of loss and gradient of
 * the corresponding joint dataset.
 *
 * This class standardizes feature values during computation using bcFeaturesStd.
 *
 * @param bcCoefficients The coefficients corresponding to the features.
 * @param fitIntercept Whether to fit an intercept term.
 * @param bcFeaturesStd The standard deviation values of the features.
 */
private[ml] class SquaredHingeAggregator(
    bcFeaturesStd: Broadcast[Array[Double]],
    fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector])
  extends DifferentiableLossAggregator[Instance, SquaredHingeAggregator] {

  private val numFeatures: Int = bcFeaturesStd.value.length
  private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures
  @transient private lazy val coefficientsArray = bcCoefficients.value match {
    case DenseVector(values) => values
    case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" +
      s" but got type ${bcCoefficients.value.getClass}.")
  }
  protected override val dim: Int = numFeaturesPlusIntercept

  /**
   * Add a new training instance to this SquaredHingeAggregator, and update the loss and gradient
   * of the objective function.
   *
   * @param instance The instance of data point to be added.
   * @return This SquaredHingeAggregator object.
   */
  def add(instance: Instance): this.type = {
    instance match { case Instance(label, weight, features) =>
      require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
        s" Expecting $numFeatures but got ${features.size}.")
      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")

      if (weight == 0.0) return this
      val localFeaturesStd = bcFeaturesStd.value
      val localCoefficients = coefficientsArray
      val localGradientSumArray = gradientSumArray

      val dotProduct = {
        var sum = 0.0
        features.foreachActive { (index, value) =>
          if (localFeaturesStd(index) != 0.0 && value != 0.0) {
            sum += localCoefficients(index) * value / localFeaturesStd(index)
          }
        }
        if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
        sum
      }
      // Our loss function with {0, 1} labels is (max(0, 1 - (2y - 1) (f_w(x))))^2
      // Therefore the gradient is 2 * ((2y - 1) f_w(x) - 1) * (2y - 1) * x
      val labelScaled = 2 * label - 1.0
      val scaledDotProduct = labelScaled * dotProduct
      val loss = if (1.0 > scaledDotProduct) {
        val hingeLoss = 1.0 - scaledDotProduct
        hingeLoss * hingeLoss * weight
      } else {
        0.0
      }

      if (1.0 > scaledDotProduct) {
        val gradientScale = (scaledDotProduct - 1) * labelScaled * 2 * weight
        features.foreachActive { (index, value) =>
          if (localFeaturesStd(index) != 0.0 && value != 0.0) {
            localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index)
          }
        }
        if (fitIntercept) {
          localGradientSumArray(localGradientSumArray.length - 1) += gradientScale
        }
      } // else gradient will not be updated.

      lossSum += loss
      weightSum += weight
      this
    }
  }
}
```
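The loss/gradient comment in `add` can be verified directly. With labels $y \in \{0, 1\}$, let $z = 2y - 1 \in \{-1, +1\}$; whenever the margin $z\, f_w(x) < 1$,

$$
L(w) = \bigl(1 - z\, f_w(x)\bigr)^2, \qquad
\nabla_w L = 2\,\bigl(1 - z\, f_w(x)\bigr)(-z)\,x = 2\,\bigl(z\, f_w(x) - 1\bigr)\,z\,x,
$$

which is exactly `gradientScale = (scaledDotProduct - 1) * labelScaled * 2 * weight` applied to each (standardized) feature value. Both loss and gradient are zero when the margin is at least 1, and both are scaled by the instance weight.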
Review thread on the `loss` param validator:

- For the `isValid` function you can use `ParamValidators.inArray[String](supportedLosses)`.
- Correct me if I'm wrong, but IMO we need `toLowerCase` here.
- Yeah, I thought about this, but the `solver` param in `LinearRegression` also ignores case handling. I tend to keep them consistent -- what do you think?
- I tend to support case-insensitive params in `LinearRegression`, or to change the default behavior of `ParamValidators.inArray`. We should improve consistency in supporting case-insensitive String params anyway.
- Created a JIRA to address that issue: https://issues.apache.org/jira/browse/SPARK-22331
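A sketch of the two options discussed (a hypothetical side-by-side, not code from the patch): `ParamValidators.inArray` checks exact membership, so case-insensitivity has to be added either around it or via a custom closure as the patch does:

```scala
import java.util.Locale
import org.apache.spark.ml.param.ParamValidators

val supportedLosses = Array("hinge", "squared_hinge")

// Option A: exact, case-sensitive membership check.
val strictValidator: String => Boolean =
  ParamValidators.inArray[String](supportedLosses)

// Option B: normalize case first, as the patch's closure does.
val caseInsensitiveValidator: String => Boolean =
  (s: String) => supportedLosses.contains(s.toLowerCase(Locale.ROOT))

strictValidator("Hinge")           // false
caseInsensitiveValidator("Hinge")  // true
```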