Commit 8d29001

yinxusen authored and mengxr committed
[SPARK-13011] K-means wrapper in SparkR
https://issues.apache.org/jira/browse/SPARK-13011

Author: Xusen Yin <[email protected]>

Closes #11124 from yinxusen/SPARK-13011.
1 parent 15e3015 commit 8d29001

File tree: 6 files changed, +203 −8 lines
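As a quick orientation before the per-file diffs: the commit lets SparkR users fit and inspect k-means models through the familiar base-R verbs. A minimal sketch of the resulting workflow (assumes a SparkR shell with a sqlContext; the data and variable names are illustrative, not from this patch):

    df <- suppressWarnings(createDataFrame(sqlContext, iris[, 1:4]))
    model <- kmeans(x = df, centers = 3, iter.max = 20, algorithm = "random")
    summary(model)      # cluster centers, sizes, and per-row assignments
    fitted(model)       # DataFrame of cluster assignments for the training data
    predict(model, df)  # assignments for (possibly new) data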

R/pkg/NAMESPACE

Lines changed: 3 additions & 1 deletion

@@ -13,7 +13,9 @@ export("print.jobj")
 # MLlib integration
 exportMethods("glm",
               "predict",
-              "summary")
+              "summary",
+              "kmeans",
+              "fitted")

 # Job group lifecycle management methods
 export("setJobGroup",

R/pkg/R/generics.R

Lines changed: 8 additions & 0 deletions

@@ -1160,3 +1160,11 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
 #' @rdname rbind
 #' @export
 setGeneric("rbind", signature = "...")
+
+#' @rdname kmeans
+#' @export
+setGeneric("kmeans")
+
+#' @rdname fitted
+#' @export
+setGeneric("fitted")

R/pkg/R/mllib.R

Lines changed: 70 additions & 4 deletions

@@ -104,11 +104,11 @@ setMethod("predict", signature(object = "PipelineModel"),
 setMethod("summary", signature(object = "PipelineModel"),
           function(object, ...) {
             modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                      "getModelName", object@model)
+                                     "getModelName", object@model)
             features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                      "getModelFeatures", object@model)
+                                    "getModelFeatures", object@model)
             coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                      "getModelCoefficients", object@model)
+                                        "getModelCoefficients", object@model)
             if (modelName == "LinearRegressionModel") {
               devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
                                                "getModelDevianceResiduals", object@model)
@@ -119,10 +119,76 @@ setMethod("summary", signature(object = "PipelineModel"),
               colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
               rownames(coefficients) <- unlist(features)
               return(list(devianceResiduals = devianceResiduals, coefficients = coefficients))
-            } else {
+            } else if (modelName == "LogisticRegressionModel") {
               coefficients <- as.matrix(unlist(coefficients))
               colnames(coefficients) <- c("Estimate")
               rownames(coefficients) <- unlist(features)
               return(list(coefficients = coefficients))
+            } else if (modelName == "KMeansModel") {
+              modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                       "getKMeansModelSize", object@model)
+              cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                     "getKMeansCluster", object@model, "classes")
+              k <- unlist(modelSize)[1]
+              size <- unlist(modelSize)[-1]
+              coefficients <- t(matrix(coefficients, ncol = k))
+              colnames(coefficients) <- unlist(features)
+              rownames(coefficients) <- 1:k
+              return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+            } else {
+              stop(paste("Unsupported model", modelName, sep = " "))
+            }
+          })
+
+#' Fit a k-means model
+#'
+#' Fit a k-means model, similarly to R's kmeans().
+#'
+#' @param x DataFrame for training
+#' @param centers Number of centers
+#' @param iter.max Maximum iteration number
+#' @param algorithm Algorithm chosen to fit the model
+#' @return A fitted k-means model
+#' @rdname kmeans
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- kmeans(x, centers = 2, algorithm = "random")
+#'}
+setMethod("kmeans", signature(x = "DataFrame"),
+          function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
+            columnNames <- as.array(colnames(x))
+            algorithm <- match.arg(algorithm)
+            model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
+                                 algorithm, iter.max, centers, columnNames)
+            return(new("PipelineModel", model = model))
+          })
+
+#' Get fitted result from a model
+#'
+#' Get fitted result from a model, similarly to R's fitted().
+#'
+#' @param object A fitted MLlib model
+#' @return DataFrame containing fitted values
+#' @rdname fitted
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- kmeans(trainingData, 2)
+#' fitted.model <- fitted(model)
+#' showDF(fitted.model)
+#'}
+setMethod("fitted", signature(object = "PipelineModel"),
+          function(object, method = c("centers", "classes"), ...) {
+            modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                     "getModelName", object@model)
+
+            if (modelName == "KMeansModel") {
+              method <- match.arg(method)
+              fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                          "getKMeansCluster", object@model, method)
+              return(dataFrame(fittedResult))
+            } else {
+              stop(paste("Unsupported model", modelName, sep = " "))
             }
           })
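For a KMeansModel, summary() thus returns a plain R list shaped like base R's kmeans output. A sketch of inspecting it (field names as defined above; model is assumed to come from the earlier kmeans example):

    s <- summary(model)
    s$coefficients             # k x d matrix of cluster centers, columns named by feature
    s$size                     # number of training points in each cluster
    head(collect(s$cluster))   # per-row assignments in a "prediction" column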

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 28 additions & 0 deletions

@@ -113,3 +113,31 @@ test_that("summary works on base GLM models", {
   baseSummary <- summary(baseModel)
   expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
 })
+
+test_that("kmeans", {
+  newIris <- iris
+  newIris$Species <- NULL
+  training <- suppressWarnings(createDataFrame(sqlContext, newIris))
+
+  # Cache the DataFrame here to work around the bug SPARK-13178.
+  cache(training)
+  take(training, 1)
+
+  model <- kmeans(x = training, centers = 2)
+  sample <- take(select(predict(model, training), "prediction"), 1)
+  expect_equal(typeof(sample$prediction), "integer")
+  expect_equal(sample$prediction, 1)
+
+  # Test stats::kmeans is working
+  statsModel <- kmeans(x = newIris, centers = 2)
+  expect_equal(unique(statsModel$cluster), c(1, 2))
+
+  # Test fitted works on KMeans
+  fitted.model <- fitted(model)
+  expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1))
+
+  # Test summary works on KMeans
+  summary.model <- summary(model)
+  cluster <- summary.model$cluster
+  expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+})

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 43 additions & 2 deletions

@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering

 import org.apache.hadoop.fs.Path

+import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
@@ -135,6 +136,26 @@ class KMeansModel private[ml] (

   @Since("1.6.0")
   override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+
+  private var trainingSummary: Option[KMeansSummary] = None
+
+  private[clustering] def setSummary(summary: KMeansSummary): this.type = {
+    this.trainingSummary = Some(summary)
+    this
+  }
+
+  /**
+   * Gets summary of model on training set. An exception is
+   * thrown if `trainingSummary == None`.
+   */
+  @Since("2.0.0")
+  def summary: KMeansSummary = trainingSummary match {
+    case Some(summ) => summ
+    case None =>
+      throw new SparkException(
+        s"No training summary available for the ${this.getClass.getSimpleName}",
+        new NullPointerException())
+  }
 }

 @Since("1.6.0")
@@ -249,8 +270,9 @@ class KMeans @Since("1.5.0") (
       .setSeed($(seed))
       .setEpsilon($(tol))
     val parentModel = algo.run(rdd)
-    val model = new KMeansModel(uid, parentModel)
-    copyValues(model.setParent(this))
+    val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
+    val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+    model.setSummary(summary)
   }

   @Since("1.5.0")
@@ -266,3 +288,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
   override def load(path: String): KMeans = super.load(path)
 }

+class KMeansSummary private[clustering] (
+    @Since("2.0.0") @transient val predictions: DataFrame,
+    @Since("2.0.0") val predictionCol: String,
+    @Since("2.0.0") val featuresCol: String) extends Serializable {
+
+  /**
+   * Cluster assignment of each data point in the transformed data.
+   */
+  @Since("2.0.0")
+  @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+  /**
+   * Size of each cluster.
+   */
+  @Since("2.0.0")
+  lazy val size: Array[Int] = cluster.map {
+    case Row(clusterIdx: Int) => (clusterIdx, 1)
+  }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+}
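The summary is built eagerly inside fit() by transforming the training data, so the R-side summary() and fitted() calls read precomputed results instead of re-running the model. As a rough cross-check from R (a sketch reusing model and training from the test above), the sizes surfaced through KMeansSummary.size should agree with a fresh aggregation over the predictions:

    s <- summary(model)
    s$size   # per-cluster counts from KMeansSummary.size
    collect(count(groupBy(predict(model, training), "prediction")))   # should match, per cluster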

mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala

Lines changed: 51 additions & 1 deletion

@@ -20,7 +20,8 @@ package org.apache.spark.ml.api.r
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
 import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
 import org.apache.spark.sql.DataFrame

@@ -51,6 +52,22 @@ private[r] object SparkRWrappers {
     pipeline.fit(df)
   }

+  def fitKMeans(
+      df: DataFrame,
+      initMode: String,
+      maxIter: Double,
+      k: Double,
+      columns: Array[String]): PipelineModel = {
+    val assembler = new VectorAssembler().setInputCols(columns)
+    val kMeans = new KMeans()
+      .setInitMode(initMode)
+      .setMaxIter(maxIter.toInt)
+      .setK(k.toInt)
+      .setFeaturesCol(assembler.getOutputCol)
+    val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
+    pipeline.fit(df)
+  }
+
   def getModelCoefficients(model: PipelineModel): Array[Double] = {
     model.stages.last match {
       case m: LinearRegressionModel => {
@@ -72,6 +89,8 @@ private[r] object SparkRWrappers {
           m.coefficients.toArray
         }
       }
+      case m: KMeansModel =>
+        m.clusterCenters.flatMap(_.toArray)
     }
   }

@@ -85,6 +104,31 @@ private[r] object SparkRWrappers {
     }
   }

+  def getKMeansModelSize(model: PipelineModel): Array[Int] = {
+    model.stages.last match {
+      case m: KMeansModel => Array(m.getK) ++ m.summary.size
+      case other => throw new UnsupportedOperationException(
+        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
+    }
+  }
+
+  def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
+    model.stages.last match {
+      case m: KMeansModel =>
+        if (method == "centers") {
+          // Drop the assembled vector for easy-print to R side.
+          m.summary.predictions.drop(m.summary.featuresCol)
+        } else if (method == "classes") {
+          m.summary.cluster
+        } else {
+          throw new UnsupportedOperationException(
+            s"Method (centers or classes) required but $method found.")
+        }
+      case other => throw new UnsupportedOperationException(
+        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
+    }
+  }
+
   def getModelFeatures(model: PipelineModel): Array[String] = {
     model.stages.last match {
       case m: LinearRegressionModel =>
@@ -103,6 +147,10 @@ private[r] object SparkRWrappers {
         } else {
           attrs.attributes.get.map(_.name.get)
         }
+      case m: KMeansModel =>
+        val attrs = AttributeGroup.fromStructField(
+          m.summary.predictions.schema(m.summary.featuresCol))
+        attrs.attributes.get.map(_.name.get)
     }
   }

@@ -112,6 +160,8 @@ private[r] object SparkRWrappers {
         "LinearRegressionModel"
       case m: LogisticRegressionModel =>
         "LogisticRegressionModel"
+      case m: KMeansModel =>
+        "KMeansModel"
     }
   }
 }
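On the R side, fitted()'s method argument selects between the two branches of getKMeansCluster. A sketch (reusing model from the examples above):

    showDF(fitted(model, method = "classes"))   # only the "prediction" column
    showDF(fitted(model, method = "centers"))   # training columns plus "prediction",
                                                # with the assembled feature vector dropped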
