Skip to content

Commit ad09e4c

Browse files
committed
[MINOR][SPARKR][ML] Joint coefficients with intercept for SparkR linear SVM summary.
## What changes were proposed in this pull request? Joint coefficients with intercept for SparkR linear SVM summary. ## How was this patch tested? Existing tests. Author: Yanbo Liang <[email protected]> Closes #18035 from yanboliang/svm-r.
1 parent 442287a commit ad09e4c

File tree

3 files changed

+26
-27
lines changed

3 files changed

+26
-27
lines changed

R/pkg/R/mllib_classification.R

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,16 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"
4646
#' @note NaiveBayesModel since 2.0.0
4747
setClass("NaiveBayesModel", representation(jobj = "jobj"))
4848

49-
#' linear SVM Model
49+
#' Linear SVM Model
5050
#'
51-
#' Fits an linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in glmnet package
51+
#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package.
52+
#' Currently only supports binary classification model with linear kernel.
5253
#' Users can print, make predictions on the produced model and save the model to the input path.
5354
#'
5455
#' @param data SparkDataFrame for training.
5556
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
5657
#' operators are supported, including '~', '.', ':', '+', and '-'.
57-
#' @param regParam The regularization parameter.
58+
#' @param regParam The regularization parameter. Only supports L2 regularization currently.
5859
#' @param maxIter Maximum iteration number.
5960
#' @param tol Convergence tolerance of iterations.
6061
#' @param standardization Whether to standardize the training features before fitting the model. The coefficients
@@ -111,10 +112,10 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
111112
new("LinearSVCModel", jobj = jobj)
112113
})
113114

114-
# Predicted values based on an LinearSVCModel model
115+
# Predicted values based on a LinearSVCModel model
115116

116117
#' @param newData a SparkDataFrame for testing.
117-
#' @return \code{predict} returns the predicted values based on an LinearSVCModel.
118+
#' @return \code{predict} returns the predicted values based on a LinearSVCModel.
118119
#' @rdname spark.svmLinear
119120
#' @aliases predict,LinearSVCModel,SparkDataFrame-method
120121
#' @export
@@ -124,36 +125,27 @@ setMethod("predict", signature(object = "LinearSVCModel"),
124125
predict_internal(object, newData)
125126
})
126127

127-
# Get the summary of an LinearSVCModel
128+
# Get the summary of a LinearSVCModel
128129

129-
#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}.
130+
#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}.
130131
#' @return \code{summary} returns summary information of the fitted model, which is a list.
131132
#' The list includes \code{coefficients} (coefficients of the fitted model),
132-
#' \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes),
133-
#' \code{numFeatures} (number of features).
133+
#' \code{numClasses} (number of classes), \code{numFeatures} (number of features).
134134
#' @rdname spark.svmLinear
135135
#' @aliases summary,LinearSVCModel-method
136136
#' @export
137137
#' @note summary(LinearSVCModel) since 2.2.0
138138
setMethod("summary", signature(object = "LinearSVCModel"),
139139
function(object) {
140140
jobj <- object@jobj
141-
features <- callJMethod(jobj, "features")
142-
labels <- callJMethod(jobj, "labels")
143-
coefficients <- callJMethod(jobj, "coefficients")
144-
nCol <- length(coefficients) / length(features)
145-
coefficients <- matrix(unlist(coefficients), ncol = nCol)
146-
intercept <- callJMethod(jobj, "intercept")
141+
features <- callJMethod(jobj, "rFeatures")
142+
coefficients <- callJMethod(jobj, "rCoefficients")
143+
coefficients <- as.matrix(unlist(coefficients))
144+
colnames(coefficients) <- c("Estimate")
145+
rownames(coefficients) <- unlist(features)
147146
numClasses <- callJMethod(jobj, "numClasses")
148147
numFeatures <- callJMethod(jobj, "numFeatures")
149-
if (nCol == 1) {
150-
colnames(coefficients) <- c("Estimate")
151-
} else {
152-
colnames(coefficients) <- unlist(labels)
153-
}
154-
rownames(coefficients) <- unlist(features)
155-
list(coefficients = coefficients, intercept = intercept,
156-
numClasses = numClasses, numFeatures = numFeatures)
148+
list(coefficients = coefficients, numClasses = numClasses, numFeatures = numFeatures)
157149
})
158150

159151
# Save fitted LinearSVCModel to the input path

R/pkg/inst/tests/testthat/test_mllib_classification.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ test_that("spark.svmLinear", {
3838
expect_true(class(summary$coefficients[, 1]) == "numeric")
3939

4040
coefs <- summary$coefficients[, "Estimate"]
41-
expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085)
41+
expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085)
4242
expect_true(all(abs(coefs - expected_coefs) < 0.1))
43-
expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2)
4443

4544
# Test prediction with string label
4645
prediction <- predict(model, training)

mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,17 @@ private[r] class LinearSVCWrapper private (
3838
private val svcModel: LinearSVCModel =
3939
pipeline.stages(1).asInstanceOf[LinearSVCModel]
4040

41-
lazy val coefficients: Array[Double] = svcModel.coefficients.toArray
41+
lazy val rFeatures: Array[String] = if (svcModel.getFitIntercept) {
42+
Array("(Intercept)") ++ features
43+
} else {
44+
features
45+
}
4246

43-
lazy val intercept: Double = svcModel.intercept
47+
lazy val rCoefficients: Array[Double] = if (svcModel.getFitIntercept) {
48+
Array(svcModel.intercept) ++ svcModel.coefficients.toArray
49+
} else {
50+
svcModel.coefficients.toArray
51+
}
4452

4553
lazy val numClasses: Int = svcModel.numClasses
4654

0 commit comments

Comments
 (0)