Closed

Changes from all commits
26 commits
d46e5ed
add lbfgs as default optimizer of LinearSVC
YY-OnCall May 4, 2017
f7d5559
set owlqn as default
YY-OnCall May 6, 2017
8a7c10f
set check
YY-OnCall May 9, 2017
4ce0787
Merge remote-tracking branch 'upstream/master' into svclbfgs
YY-OnCall Jun 12, 2017
c8afc63
Merge remote-tracking branch 'upstream/master' into svclbfgs
YY-OnCall Jun 13, 2017
3707580
merge loss change
YY-OnCall Jun 13, 2017
2ffd0eb
Merge remote-tracking branch 'upstream/master' into svclbfgs
YY-OnCall Jun 13, 2017
2ca5a74
fix r and python
YY-OnCall Jun 14, 2017
5f7f456
switch between Hinge and Square
YY-OnCall Jun 14, 2017
d19f619
Merge remote-tracking branch 'upstream/master' into svclbfgs
YY-OnCall Jun 15, 2017
0297057
use RDDLossFunction
YY-OnCall Jun 15, 2017
15d611e
merge conflict
YY-OnCall Jun 26, 2017
a545267
Merge remote-tracking branch 'upstream/master' into svclbfgs
YY-OnCall Jun 27, 2017
7be6bac
r and new ut
YY-OnCall Jun 27, 2017
aaf35ec
ut update
YY-OnCall Jun 30, 2017
ea82f35
merge conflict
YY-OnCall Sep 2, 2017
93f7b68
merge conflict and add unit tests
YY-OnCall Sep 3, 2017
cec628b
style
YY-OnCall Sep 3, 2017
55ce6b9
resolve conflict
YY-OnCall Sep 15, 2017
0f5cad5
fix python ut
YY-OnCall Sep 16, 2017
bf4d955
resolve conflict
YY-OnCall Oct 2, 2017
1f8e984
style
YY-OnCall Oct 2, 2017
a6b4cda
Merge remote-tracking branch 'upstream/master' into svclbfgs
YY-OnCall Oct 3, 2017
f778f97
Merge remote-tracking branch 'upstream/master' into svclbfgs
YY-OnCall Oct 14, 2017
0bb5afe
minor updates
YY-OnCall Oct 15, 2017
64bc339
Merge remote-tracking branch 'upstream/master' into svclbfgs
YY-OnCall Oct 23, 2017
11 changes: 9 additions & 2 deletions R/pkg/R/mllib_classification.R
@@ -58,6 +58,9 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
#' @param regParam The regularization parameter. Only supports L2 regularization currently.
#' @param maxIter Maximum iteration number.
#' @param tol Convergence tolerance of iterations.
#' @param solver Optimization solver, supported options: "owlqn" or "l-bfgs". Default is "l-bfgs".
#' @param loss Loss function, supported options: "hinge" and "squared_hinge". Default is
#'             "squared_hinge".
#' @param standardization Whether to standardize the training features before fitting the model.
#' The coefficients of models will be always returned on the original scale,
#' so it will be transparent for users. Note that with/without
@@ -107,7 +110,10 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE,
threshold = 0.0, weightCol = NULL, aggregationDepth = 2,
-                   handleInvalid = c("error", "keep", "skip")) {
+                   handleInvalid = c("error", "keep", "skip"), solver = c("l-bfgs", "owlqn"),
+                   loss = c("squared_hinge", "hinge")) {
+            solver <- match.arg(solver)
+            loss <- match.arg(loss)
formula <- paste(deparse(formula), collapse = "")

if (!is.null(weightCol) && weightCol == "") {
@@ -121,7 +127,8 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"),
jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit",
data@sdf, formula, as.numeric(regParam), as.integer(maxIter),
as.numeric(tol), as.logical(standardization), as.numeric(threshold),
-                                weightCol, as.integer(aggregationDepth), handleInvalid)
+                                weightCol, as.integer(aggregationDepth), handleInvalid, solver,
+                                loss)
new("LinearSVCModel", jobj = jobj)
})

3 changes: 2 additions & 1 deletion R/pkg/tests/fulltests/test_mllib_classification.R
@@ -30,7 +30,8 @@ absoluteSparkPath <- function(x) {
test_that("spark.svmLinear", {
df <- suppressWarnings(createDataFrame(iris))
training <- df[df$Species %in% c("versicolor", "virginica"), ]
-  model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10)
+  model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10,
+                           solver = "owlqn", loss = "hinge")
summary <- summary(model)

# test summary coefficients return matrix type
mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -17,18 +17,20 @@

package org.apache.spark.ml.classification

import java.util.Locale

import scala.collection.mutable

import breeze.linalg.{DenseVector => BDV}
-import breeze.optimize.{CachedDiffFunction, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
-import org.apache.spark.ml.optim.aggregator.HingeAggregator
+import org.apache.spark.ml.optim.aggregator.{HingeAggregator, SquaredHingeAggregator}
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -42,7 +44,26 @@ import org.apache.spark.sql.functions.{col, lit}
/** Params for linear SVM Classifier. */
private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
-    with HasAggregationDepth with HasThreshold {
+    with HasAggregationDepth with HasThreshold with HasSolver {

/**
* Specifies the loss function. Currently "hinge" and "squared_hinge" are supported.
* "hinge" is the standard SVM loss (a.k.a. L1 loss) while "squared_hinge" is the square of
* the hinge loss (a.k.a. L2 loss).
*
* @see <a href="https://en.wikipedia.org/wiki/Hinge_loss">Hinge loss (Wikipedia)</a>
*
* @group param
*/
@Since("2.3.0")
final val loss: Param[String] = new Param(this, "loss", "Specifies the loss " +
"function. hinge is the standard SVM loss while squared_hinge is the square of the hinge loss.",
(s: String) => LinearSVC.supportedLoss.contains(s.toLowerCase(Locale.ROOT)))
Contributor: For the isValid function you can use ParamValidators.inArray[String](supportedLosses).

Contributor (author): Correct me if I'm wrong, but IMO we need toLowerCase here.

Contributor: Yeah, I thought about this, but the solver param in LinearRegression also ignores this. I tend to keep them consistent; what do you think?

Contributor (author): I tend to support case-insensitive params in LinearRegression, or to change the default behavior of ParamValidators.inArray. We should improve the consistency of case-insensitive String params anyway.

Contributor (author): Created a JIRA to address that issue: https://issues.apache.org/jira/browse/SPARK-22331
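A sketch of the case-insensitive check discussed above (a hypothetical helper, not part of this PR: ParamValidators.inArray tests exact membership, so case-insensitivity requires lower-casing both sides):

import java.util.Locale

// Hypothetical helper: a case-insensitive analogue of
// ParamValidators.inArray[String]. Lower-casing with Locale.ROOT keeps the
// check independent of the JVM's default locale.
def inArrayIgnoreCase(allowed: Array[String]): String => Boolean = {
  val lowered = allowed.map(_.toLowerCase(Locale.ROOT))
  value => lowered.contains(value.toLowerCase(Locale.ROOT))
}

// e.g. inArrayIgnoreCase(Array("hinge", "squared_hinge"))("Squared_Hinge") returns true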


/** @group getParam */
@Since("2.3.0")
def getLoss: String = $(loss)


/**
* Param for threshold in binary classification prediction.
@@ -63,8 +84,11 @@ private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
* <a href = "https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM">
* Linear SVM Classifier</a>
*
-* This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
-* Only supports L2 regularization currently.
+* This binary classifier supports both the "hinge" and "squared_hinge" loss functions.
+* "hinge" is the standard SVM loss (a.k.a. L1 loss) while "squared_hinge" is the square of
+* the hinge loss (a.k.a. L2 loss). Both the L-BFGS and OWL-QN optimizers are supported,
+* and can be chosen via the solver param. By default, the squared hinge loss (L2 SVM)
+* and the L-BFGS optimizer are used.
*
*/
@Since("2.2.0")
@@ -74,6 +98,8 @@ class LinearSVC @Since("2.2.0") (
extends Classifier[Vector, LinearSVC, LinearSVCModel]
with LinearSVCParams with DefaultParamsWritable {

import LinearSVC._

@Since("2.2.0")
def this() = this(Identifiable.randomUID("linearsvc"))

@@ -159,6 +185,31 @@ class LinearSVC @Since("2.2.0") (
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
setDefault(aggregationDepth -> 2)

/**
* Set the loss function. Default is "squared_hinge".
*
* @group setParam
*/
@Since("2.3.0")
def setLoss(value: String): this.type = set(loss, value)
setDefault(loss -> SQUARED_HINGE)

/**
* Set the solver for LinearSVC. Supported options: "l-bfgs" and "owlqn" (case insensitive).
*  - "l-bfgs" denotes Limited-memory BFGS, a limited-memory quasi-Newton
*    optimization method.
*  - "owlqn" denotes the Orthant-Wise Limited-memory Quasi-Newton algorithm.
* (default: "l-bfgs")
* @group setParam
*/
@Since("2.3.0")
def setSolver(value: String): this.type = {
require(supportedSolvers.contains(value.toLowerCase(Locale.ROOT)), s"Solver $value is" +
  s" not supported. Supported options: ${supportedSolvers.mkString(", ")}")
set(solver, value)
}
setDefault(solver -> LBFGS)

@Since("2.2.0")
override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
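For context, a minimal usage sketch of the new setters (assuming a SparkSession and a DataFrame named training with the usual "label" and "features" columns; the data names are illustrative):

import org.apache.spark.ml.classification.LinearSVC

val svc = new LinearSVC()
  .setRegParam(0.01)
  .setMaxIter(100)
  .setSolver("l-bfgs")       // or "owlqn"; matched case-insensitively
  .setLoss("squared_hinge")  // or "hinge"
val model = svc.fit(training)  // `training` is the assumed input DataFrame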

@@ -225,12 +276,27 @@ class LinearSVC @Since("2.2.0") (
None
}

-      val getAggregatorFunc = new HingeAggregator(bcFeaturesStd, $(fitIntercept))(_)
-      val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization,
-        $(aggregationDepth))
+      val costFun = $(loss) match {
+        case HINGE =>
+          val getAggregatorFunc = new HingeAggregator(bcFeaturesStd, $(fitIntercept))(_)
+          new RDDLossFunction(instances, getAggregatorFunc, regularization,
+            $(aggregationDepth))
+        case SQUARED_HINGE =>
+          val getAggregatorFunc = new SquaredHingeAggregator(bcFeaturesStd, $(fitIntercept))(_)
+          new RDDLossFunction(instances, getAggregatorFunc, regularization,
+            $(aggregationDepth))
+        case unexpected => throw new SparkException(
+          s"Unexpected loss function in LinearSVC: $unexpected")
+      }
+
+      val optimizer = $(solver).toLowerCase(Locale.ROOT) match {
+        case LBFGS => new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
+        case OWLQN =>
+          def regParamL1Fun = (index: Int) => 0D
+          new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
+        case _ => throw new SparkException("Unexpected solver: " + $(solver))
+      }

-      def regParamL1Fun = (index: Int) => 0D
-      val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
val initialCoefWithIntercept = Vectors.zeros(numFeaturesPlusIntercept)

val states = optimizer.iterations(new CachedDiffFunction(costFun),
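For reference, a self-contained sketch (not part of this diff) of the two Breeze optimizers selected above, minimizing a toy quadratic; the constructor arguments mirror the calls in the diff:

import breeze.linalg.DenseVector
import breeze.optimize.{DiffFunction, LBFGS, OWLQN}

// Toy objective f(x) = ||x||^2 / 2, whose gradient is x; both optimizers
// should converge to the zero vector.
val f = new DiffFunction[DenseVector[Double]] {
  def calculate(x: DenseVector[Double]): (Double, DenseVector[Double]) =
    ((x dot x) / 2.0, x)
}
val x0 = DenseVector.fill(3)(1.0)
val lbfgs = new LBFGS[DenseVector[Double]](100, 10, 1e-6)  // maxIter, m, tol
val owlqn = new OWLQN[Int, DenseVector[Double]](100, 10, (_: Int) => 0.0, 1e-6)
println(lbfgs.minimize(f, x0))  // ~ DenseVector(0.0, 0.0, 0.0)
println(owlqn.minimize(f, x0))  // same result here, since the L1 weight is zero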
@@ -282,8 +348,27 @@ class LinearSVC @Since("2.2.0") (
@Since("2.2.0")
object LinearSVC extends DefaultParamsReadable[LinearSVC] {

/** String name for Limited-memory BFGS. */
private[classification] val LBFGS: String = "l-bfgs".toLowerCase(Locale.ROOT)

/** String name for Orthant-Wise Limited-memory Quasi-Newton. */
private[classification] val OWLQN: String = "owlqn".toLowerCase(Locale.ROOT)

/* Set of optimizers that LinearSVC supports */
private[classification] val supportedSolvers = Array(LBFGS, OWLQN)

/** String name for Hinge Loss. */
private[classification] val HINGE: String = "hinge".toLowerCase(Locale.ROOT)

/** String name for Squared Hinge Loss. */
private[classification] val SQUARED_HINGE: String = "squared_hinge".toLowerCase(Locale.ROOT)
Contributor: Why do we need .toLowerCase(Locale.ROOT) here?

Contributor (author, @hhbyyh, Oct 23, 2017): To ensure consistent param validation across all locales.

Contributor: Ditto. IMO these characters are all ASCII; I don't think they will hit a locale issue. (But have you encountered such an issue in some environment?)

Contributor (author): Personally, never, but I cannot guarantee it for all locales.
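To make the locale concern concrete, this is the classic pitfall that the Locale.ROOT argument guards against (a quick illustration, not taken from the PR):

import java.util.Locale

// In the Turkish locale, uppercase ASCII "I" lower-cases to a dotless "ı",
// so a default-locale toLowerCase can break exact string matching.
val tr = new Locale("tr")
println("SQUARED_HINGE".toLowerCase(tr))           // "squared_hınge" (dotless ı)
println("SQUARED_HINGE".toLowerCase(Locale.ROOT))  // "squared_hinge"
println("SQUARED_HINGE".toLowerCase(tr) == "squared_hinge")           // false
println("SQUARED_HINGE".toLowerCase(Locale.ROOT) == "squared_hinge")  // true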


/* Set of loss functions that LinearSVC supports */
private[classification] val supportedLoss = Array(HINGE, SQUARED_HINGE)
Contributor: supportedLoss ==> supportedLosses

Contributor (author): Sure, I can update it.


@Since("2.2.0")
override def load(path: String): LinearSVC = super.load(path)

}

mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/SquaredHingeAggregator.scala
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.optim.aggregator

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._

/**
* SquaredHingeAggregator computes the gradient and loss for the squared hinge loss function,
* as used in binary classification for instances in sparse or dense vectors, in an online
* fashion.
*
* Two SquaredHingeAggregators can be merged together to produce a summary of the loss and
* gradient of the corresponding joint dataset.
*
* This class standardizes feature values during computation using bcFeaturesStd.
*
* @param bcCoefficients The coefficients corresponding to the features.
* @param fitIntercept Whether to fit an intercept term.
* @param bcFeaturesStd The standard deviation values of the features.
*/
private[ml] class SquaredHingeAggregator(
bcFeaturesStd: Broadcast[Array[Double]],
fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector])
extends DifferentiableLossAggregator[Instance, SquaredHingeAggregator] {

private val numFeatures: Int = bcFeaturesStd.value.length
private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures
@transient private lazy val coefficientsArray = bcCoefficients.value match {
case DenseVector(values) => values
case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" +
s" but got type ${bcCoefficients.value.getClass}.")
}
protected override val dim: Int = numFeaturesPlusIntercept

/**
* Add a new training instance to this SquaredHingeAggregator, and update the loss and gradient
* of the objective function.
*
* @param instance The instance of data point to be added.
* @return This SquaredHingeAggregator object.
*/
def add(instance: Instance): this.type = {
instance match { case Instance(label, weight, features) =>
require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
s" Expecting $numFeatures but got ${features.size}.")
require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")

if (weight == 0.0) return this
val localFeaturesStd = bcFeaturesStd.value
val localCoefficients = coefficientsArray
val localGradientSumArray = gradientSumArray

val dotProduct = {
var sum = 0.0
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
sum += localCoefficients(index) * value / localFeaturesStd(index)
}
}
if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
sum
}
// Our loss function with {0, 1} labels is (max(0, 1 - (2y - 1) (f_w(x))))^2
// Therefore the gradient is 2 * ((2y - 1) f_w(x) - 1) * (2y - 1) * x
val labelScaled = 2 * label - 1.0
val scaledDotProduct = labelScaled * dotProduct
val loss = if (1.0 > scaledDotProduct) {
val hingeLoss = 1.0 - scaledDotProduct
hingeLoss * hingeLoss * weight
} else {
0.0
}

if (1.0 > scaledDotProduct) {
val gradientScale = (scaledDotProduct - 1) * labelScaled * 2 * weight
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index)
}
}
if (fitIntercept) {
localGradientSumArray(localGradientSumArray.length - 1) += gradientScale
}
} // else gradient will not be updated.

lossSum += loss
weightSum += weight
this
}
}
}
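In equation form, the loss and gradient that the add method implements (labels y in {0, 1}, z = 2y - 1; here f_w(x) is the dotProduct the code computes, including standardization and the optional intercept, and the per-instance weight is folded in by the code):

\[
  L(w; x, y) = \max\bigl(0,\; 1 - z\, f_w(x)\bigr)^{2}, \qquad z = 2y - 1
\]
\[
  \nabla_w L =
  \begin{cases}
    2\,\bigl(z\, f_w(x) - 1\bigr)\, z\, x & \text{if } z\, f_w(x) < 1, \\
    0 & \text{otherwise,}
  \end{cases}
\]

which matches gradientScale = (scaledDotProduct - 1) * labelScaled * 2 * weight above. A quick finite-difference sanity check of that gradient (a hypothetical standalone snippet, no Spark dependencies):

object SquaredHingeGradientCheck extends App {
  val x = Array(0.5, -1.2)  // one feature vector
  val z = 1.0               // z = 2y - 1 for label y = 1
  def loss(w: Array[Double]): Double = {
    val margin = z * w.zip(x).map { case (wi, xi) => wi * xi }.sum
    val h = math.max(0.0, 1.0 - margin)
    h * h
  }
  val w = Array(0.3, 0.8)
  val margin = z * w.zip(x).map { case (wi, xi) => wi * xi }.sum
  val analytic = x.map(xi => 2.0 * (margin - 1.0) * z * xi)  // 2 (z f - 1) z x
  val eps = 1e-6
  val numeric = w.indices.map { i =>
    (loss(w.updated(i, w(i) + eps)) - loss(w.updated(i, w(i) - eps))) / (2 * eps)
  }
  println(s"analytic: ${analytic.mkString(", ")}")  // agrees with numeric to ~1e-6
  println(s"numeric:  ${numeric.mkString(", ")}")
}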
mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
@@ -70,7 +70,7 @@ private[r] object LinearSVCWrapper
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
val PREDICTED_LABEL_COL = "prediction"

-  def fit(
+  def fit( // scalastyle:ignore
data: DataFrame,
formula: String,
regParam: Double,
@@ -80,7 +80,9 @@ private[r] object LinearSVCWrapper
threshold: Double,
weightCol: String,
aggregationDepth: Int,
-      handleInvalid: String
+      handleInvalid: String,
+      solver: String,
+      loss: String
): LinearSVCWrapper = {

val rFormula = new RFormula()
@@ -107,6 +109,8 @@
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
.setThreshold(threshold)
.setAggregationDepth(aggregationDepth)
.setSolver(solver)
.setLoss(loss)

if (weightCol != null) svc.setWeightCol(weightCol)
