@@ -32,6 +32,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
3232# ' @export
3333setClass ("AFTSurvivalRegressionModel ", representation(jobj = "jobj"))
3434
35+ # ' @title S4 class that represents a KMeansModel
36+ # ' @param jobj a Java object reference to the backing Scala KMeansModel
37+ # ' @export
38+ setClass ("KMeansModel ", representation(jobj = "jobj"))
39+
3540# ' Fits a generalized linear model
3641# '
3742# ' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -154,17 +159,6 @@ setMethod("summary", signature(object = "PipelineModel"),
154159 colnames(coefficients ) <- c(" Estimate" )
155160 rownames(coefficients ) <- unlist(features )
156161 return (list (coefficients = coefficients ))
157- } else if (modelName == " KMeansModel" ) {
158- modelSize <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
159- " getKMeansModelSize" , object @ model )
160- cluster <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
161- " getKMeansCluster" , object @ model , " classes" )
162- k <- unlist(modelSize )[1 ]
163- size <- unlist(modelSize )[- 1 ]
164- coefficients <- t(matrix (coefficients , ncol = k ))
165- colnames(coefficients ) <- unlist(features )
166- rownames(coefficients ) <- 1 : k
167- return (list (coefficients = coefficients , size = size , cluster = dataFrame(cluster )))
168162 } else {
169163 stop(paste(" Unsupported model" , modelName , sep = " " ))
170164 }
@@ -213,21 +207,21 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
213207# ' @examples
214208# ' \dontrun{
215209# ' model <- kmeans(x, centers = 2, algorithm="random")
216- # '}
210+ # ' }
217211setMethod ("kmeans ", signature(x = "DataFrame"),
218212 function (x , centers , iter.max = 10 , algorithm = c(" random" , " k-means||" )) {
219213 columnNames <- as.array(colnames(x ))
220214 algorithm <- match.arg(algorithm )
221- model <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers " , " fitKMeans " , x @ sdf ,
222- algorithm , iter.max , centers , columnNames )
223- return (new(" PipelineModel " , model = model ))
215+ jobj <- callJStatic(" org.apache.spark.ml.r.KMeansWrapper " , " fit " , x @ sdf ,
216+ centers , iter.max , algorithm , columnNames )
217+ return (new(" KMeansModel " , jobj = jobj ))
224218 })
225219
226- # ' Get fitted result from a model
220+ # ' Get fitted result from a k-means model
227221# '
228- # ' Get fitted result from a model, similarly to R's fitted().
222+ # ' Get fitted result from a k-means model, similarly to R's fitted().
229223# '
230- # ' @param object A fitted MLlib model
224+ # ' @param object A fitted k-means model
231225# ' @return DataFrame containing fitted values
232226# ' @rdname fitted
233227# ' @export
@@ -237,19 +231,58 @@ setMethod("kmeans", signature(x = "DataFrame"),
237231# ' fitted.model <- fitted(model)
238232# ' showDF(fitted.model)
239233# '}
240- setMethod ("fitted ", signature(object = "PipelineModel "),
234+ setMethod ("fitted ", signature(object = "KMeansModel "),
241235 function (object , method = c(" centers" , " classes" ), ... ) {
242- modelName <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
243- " getModelName" , object @ model )
236+ method <- match.arg(method )
237+ return (dataFrame(callJMethod(object @ jobj , " fitted" , method )))
238+ })
244239
245- if (modelName == " KMeansModel" ) {
246- method <- match.arg(method )
247- fittedResult <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
248- " getKMeansCluster" , object @ model , method )
249- return (dataFrame(fittedResult ))
250- } else {
251- stop(paste(" Unsupported model" , modelName , sep = " " ))
252- }
240+ # ' Get the summary of a k-means model
241+ # '
242+ # ' Returns the summary of a k-means model produced by kmeans(),
243+ # ' similarly to R's summary().
244+ # '
245+ # ' @param object a fitted k-means model
246+ # ' @return the model's coefficients, size and cluster
247+ # ' @rdname summary
248+ # ' @export
249+ # ' @examples
250+ # ' \dontrun{
251+ # ' model <- kmeans(trainingData, 2)
252+ # ' summary(model)
253+ # ' }
254+ setMethod ("summary ", signature(object = "KMeansModel"),
255+ function (object , ... ) {
256+ jobj <- object @ jobj
257+ features <- callJMethod(jobj , " features" )
258+ coefficients <- callJMethod(jobj , " coefficients" )
259+ cluster <- callJMethod(jobj , " cluster" )
260+ k <- callJMethod(jobj , " k" )
261+ size <- callJMethod(jobj , " size" )
262+ coefficients <- t(matrix (coefficients , ncol = k ))
263+ colnames(coefficients ) <- unlist(features )
264+ rownames(coefficients ) <- 1 : k
265+ return (list (coefficients = coefficients , size = size , cluster = dataFrame(cluster )))
266+ })
267+
268+ # ' Make predictions from a k-means model
269+ # '
270+ # ' Make predictions from a model produced by kmeans().
271+ # '
272+ # ' @param object A fitted k-means model
273+ # ' @param newData DataFrame for testing
274+ # ' @return DataFrame containing predicted labels in a column named "prediction"
275+ # ' @rdname predict
276+ # ' @export
277+ # ' @examples
278+ # ' \dontrun{
279+ # ' model <- kmeans(trainingData, 2)
280+ # ' predicted <- predict(model, testData)
281+ # ' showDF(predicted)
282+ # ' }
283+ setMethod ("predict ", signature(object = "KMeansModel"),
284+ function (object , newData ) {
285+ return (dataFrame(callJMethod(object @ jobj , " transform" , newData @ sdf )))
253286 })
254287
255288# ' Fit a Bernoulli naive Bayes model
0 commit comments