diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index f76b14eeeb542..3889c4e8eb9aa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.classification
+import java.util.Locale
+
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
@@ -42,7 +44,26 @@ import org.apache.spark.sql.functions.{col, lit}
/** Params for linear SVM Classifier. */
private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
- with HasThreshold with HasAggregationDepth
+ with HasThreshold with HasAggregationDepth {
+
+ /**
+ * Specifies the loss function. Currently "hinge" and "squared_hinge" are supported.
+ * "hinge" is the standard SVM loss (a.k.a. L1 loss) while "squared_hinge" is the square of
+ * the hinge loss (a.k.a. L2 loss).
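+ * For a label y in {-1, 1} and margin m = y * (w^T x + b), the hinge loss is
+ * max(0, 1 - m) and the squared hinge loss is max(0, 1 - m)^2.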
+ *
+ * @see <a href="https://en.wikipedia.org/wiki/Hinge_loss">Hinge loss (Wikipedia)</a>
+ *
+ * @group param
+ */
+ @Since("2.3.0")
+ final val lossFunction: Param[String] = new Param(this, "lossFunction", "Specifies the loss " +
+ "function. hinge is the standard SVM loss while squared_hinge is the square of the hinge loss.",
+ (s: String) => LinearSVC.supportedLoss.contains(s.toLowerCase(Locale.ROOT)))
+
+ /** @group getParam */
+ @Since("2.3.0")
+ def getLossFunction: String = $(lossFunction)
+}
/**
* :: Experimental ::
@@ -50,7 +71,8 @@ private[classification] trait LinearSVCParams extends ClassifierParams with HasR
*
* Linear SVM Classifier
*
- * This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
+ * This binary classifier optimizes the Hinge Loss (or Squared Hinge Loss) using the
+ * OWLQN optimizer.
*
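+ * A minimal usage sketch (assumes a user-supplied DataFrame `training` with "label" and
+ * "features" columns):
+ * {{{
+ *   val model = new LinearSVC().setLossFunction("squared_hinge").fit(training)
+ * }}}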
*/
@Since("2.2.0")
@@ -63,6 +85,15 @@ class LinearSVC @Since("2.2.0") (
@Since("2.2.0")
def this() = this(Identifiable.randomUID("linearsvc"))
+ /**
+ * Set the loss function. Default is "hinge".
+ *
+ * @group setParam
+ */
+ @Since("2.3.0")
+ def setLossFunction(value: String): this.type = set(lossFunction, value)
+ setDefault(lossFunction -> "hinge")
+
/**
* Set the regularization parameter.
* Default is 0.0.
@@ -202,8 +233,8 @@ class LinearSVC @Since("2.2.0") (
val featuresStd = summarizer.variance.toArray.map(math.sqrt)
val regParamL2 = $(regParam)
val bcFeaturesStd = instances.context.broadcast(featuresStd)
- val costFun = new LinearSVCCostFun(instances, $(fitIntercept),
- $(standardization), bcFeaturesStd, regParamL2, $(aggregationDepth))
+ val costFun = new LinearSVCCostFun(instances, $(fitIntercept), $(standardization),
+ bcFeaturesStd, regParamL2, $(aggregationDepth), $(lossFunction).toLowerCase(Locale.ROOT))
def regParamL1Fun = (index: Int) => 0D
val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
@@ -260,6 +291,8 @@ object LinearSVC extends DefaultParamsReadable[LinearSVC] {
@Since("2.2.0")
override def load(path: String): LinearSVC = super.load(path)
+
+ private[classification] val supportedLoss = Array("hinge", "squared_hinge")
}
/**
@@ -355,7 +388,8 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
}
/**
- * LinearSVCCostFun implements Breeze's DiffFunction[T] for hinge loss function
+ * LinearSVCCostFun implements Breeze's DiffFunction[T] for the hinge or squared hinge
+ * loss function.
*/
private class LinearSVCCostFun(
instances: RDD[Instance],
@@ -363,7 +397,8 @@ private class LinearSVCCostFun(
standardization: Boolean,
bcFeaturesStd: Broadcast[Array[Double]],
regParamL2: Double,
- aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
+ aggregationDepth: Int,
+ lossFunction: String) extends DiffFunction[BDV[Double]] {
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val coeffs = Vectors.fromBreeze(coefficients)
@@ -376,7 +411,7 @@ private class LinearSVCCostFun(
val combOp = (c1: LinearSVCAggregator, c2: LinearSVCAggregator) => c1.merge(c2)
instances.treeAggregate(
- new LinearSVCAggregator(bcCoeffs, bcFeaturesStd, fitIntercept)
+ new LinearSVCAggregator(bcCoeffs, bcFeaturesStd, fitIntercept, lossFunction)
)(seqOp, combOp, aggregationDepth)
}
@@ -421,8 +456,9 @@ private class LinearSVCCostFun(
}
/**
- * LinearSVCAggregator computes the gradient and loss for hinge loss function, as used
- * in binary classification for instances in sparse or dense vector in an online fashion.
+ * LinearSVCAggregator computes the gradient and loss for the hinge or squared hinge
+ * loss function, as used in binary classification for instances in sparse or dense
+ * vectors, in an online fashion.
*
* Two LinearSVCAggregator can be merged together to have a summary of loss and gradient of
* the corresponding joint dataset.
@@ -436,7 +472,8 @@ private class LinearSVCCostFun(
private class LinearSVCAggregator(
bcCoefficients: Broadcast[Vector],
bcFeaturesStd: Broadcast[Array[Double]],
- fitIntercept: Boolean) extends Serializable {
+ fitIntercept: Boolean,
+ lossFunction: String) extends Serializable {
private val numFeatures: Int = bcFeaturesStd.value.length
private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures
@@ -477,16 +514,26 @@ private class LinearSVCAggregator(
sum
}
- // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
- // Therefore the gradient is -(2y - 1)*x
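+ // With {0, 1} labels, the hinge loss is max(0, 1 - (2y - 1) * f_w(x)) and the squared
+ // hinge loss is its square. Both are zero when the margin (2y - 1) * f_w(x) is at
+ // least 1; otherwise the gradients w.r.t. f_w(x) are -(2y - 1) for hinge and
+ // 2 * ((2y - 1) * f_w(x) - 1) * (2y - 1) for squared_hinge, each scaled by the
+ // instance weight.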
val labelScaled = 2 * label - 1.0
val loss = if (1.0 > labelScaled * dotProduct) {
- weight * (1.0 - labelScaled * dotProduct)
+ val hingeLoss = 1.0 - labelScaled * dotProduct
+ lossFunction match {
+ case "hinge" => hingeLoss * weight
+ case "squared_hinge" => hingeLoss * hingeLoss * weight
+ case unexpected => throw new SparkException(
+ s"unexpected lossFunction in LinearSVCAggregator: $unexpected")
+ }
} else {
0.0
}
if (1.0 > labelScaled * dotProduct) {
- val gradientScale = -labelScaled * weight
+ val gradientScale = lossFunction match {
+ case "hinge" => -labelScaled * weight
+ case "squared_hinge" => (labelScaled * dotProduct - 1) * labelScaled * 2
+ case unexpected => throw new SparkException(
+ s"unexpected lossFunction in LinearSVCAggregator: $unexpected")
+ }
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index 2f87afc23fe7e..0e13d8a0ef210 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -75,12 +75,14 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
test("Linear SVC binary classification") {
- val svm = new LinearSVC()
- val model = svm.fit(smallBinaryDataset)
- assert(model.transform(smallValidationDataset)
- .where("prediction=label").count() > nPoints * 0.8)
- val sparseModel = svm.fit(smallSparseBinaryDataset)
- checkModels(model, sparseModel)
+ Array("hinge", "squared_hinge").foreach { loss =>
+ val svm = new LinearSVC().setLossFunction(loss)
+ val model = svm.fit(smallBinaryDataset)
+ assert(model.transform(smallValidationDataset)
+ .where("prediction=label").count() > nPoints * 0.8)
+ val sparseModel = svm.fit(smallSparseBinaryDataset)
+ checkModels(model, sparseModel)
+ }
}
test("Linear SVC binary classification with regularization") {
@@ -100,6 +102,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("linear svc: default params") {
val lsvc = new LinearSVC()
+ assert(lsvc.getLossFunction === "hinge")
assert(lsvc.getRegParam === 0.0)
assert(lsvc.getMaxIter === 100)
assert(lsvc.getFitIntercept)
@@ -116,6 +119,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
model.transform(smallBinaryDataset)
.select("label", "prediction", "rawPrediction")
.collect()
+ assert(model.getLossFunction === "hinge")
assert(model.getThreshold === 0.0)
assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
@@ -125,6 +129,14 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(model.numFeatures === 2)
MLTestingUtils.checkCopyAndUids(lsvc, model)
+
+ withClue("lossFunction should be case-insensitive") {
+ lsvc.setLossFunction("HINGE")
+ lsvc.setLossFunction("Squared_hinge")
+ intercept[IllegalArgumentException] {
+ val model = lsvc.setLossFunction("hing")
+ }
+ }
}
test("linear svc doesn't fit intercept when fitIntercept is off") {
@@ -140,7 +152,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("sparse coefficients in SVCAggregator") {
val bcCoefficients = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0)))
val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0))
- val agg = new LinearSVCAggregator(bcCoefficients, bcFeaturesStd, true)
+ val agg = new LinearSVCAggregator(bcCoefficients, bcFeaturesStd, true, "hinge")
val thrown = withClue("LinearSVCAggregator cannot handle sparse coefficients") {
intercept[IllegalArgumentException] {
agg.add(Instance(1.0, 1.0, Vectors.dense(1.0)))
@@ -168,7 +180,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
dataset.as[LabeledPoint], estimator, modelEquals, 42L)
}
- test("linearSVC comparison with R e1071 and scikit-learn") {
+ test("linearSVC with hinge loss comparison with R e1071 and scikit-learn (liblinear)") {
val trainer1 = new LinearSVC()
.setRegParam(0.00002) // set regParam = 2.0 / datasize / c
.setMaxIter(200)
@@ -223,6 +235,38 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(model1.coefficients ~== coefficientsSK relTol 4E-3)
}
+ test("linearSVC with squared_hinge loss comparison with scikit-learn (liblinear)") {
+ val linearSVC = new LinearSVC()
+ .setLossFunction("squared_hinge")
+ .setRegParam(2.0 / 10 / 1000) // set regParam = 2.0 / datasize / c
+ .setMaxIter(80)
+ .setTol(1e-4)
+ val model = linearSVC.fit(binaryDataset.limit(1000))
+
+ /*
+ Use the following python code to load the data and train the model using scikit-learn package.
+
+ import numpy as np
+ from sklearn import svm
+ f = open("path/spark/assembly/target/tmp/LinearSVC/binaryDataset/part-00000")
+ data = np.loadtxt(f, delimiter=",")[:1000]
+ X = data[:, 1:] # select columns 1 through end
+ y = data[:, 0] # select column 0 as label
+ clf = svm.LinearSVC(fit_intercept=True, C=10, loss='squared_hinge', tol=1e-4, random_state=42)
+ m = clf.fit(X, y)
+ print m.coef_
+ print m.intercept_
+
+ [[ 2.85136074 6.25310456 9.00668415 12.17750981]]
+ [ 2.93419973]
+ */
+
+ val coefficientsSK = Vectors.dense(2.85136074, 6.25310456, 9.00668415, 12.17750981)
+ val interceptSK = 2.93419973
+ assert(model.intercept ~== interceptSK relTol 2E-2)
+ assert(model.coefficients ~== coefficientsSK relTol 2E-2)
+ }
+
test("read/write: SVM") {
def checkModelData(model: LinearSVCModel, model2: LinearSVCModel): Unit = {
assert(model.intercept === model2.intercept)
@@ -238,6 +282,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
object LinearSVCSuite {
val allParamSettings: Map[String, Any] = Map(
+ "lossFunction" -> "squared_hinge",
"regParam" -> 0.01,
"maxIter" -> 2, // intentionally small
"fitIntercept" -> true,