@@ -104,11 +104,11 @@ setMethod("predict", signature(object = "PipelineModel"),
104104setMethod ("summary ", signature(object = "PipelineModel"),
105105 function (object , ... ) {
106106 modelName <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
107- " getModelName" , object @ model )
107+ " getModelName" , object @ model )
108108 features <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
109- " getModelFeatures" , object @ model )
109+ " getModelFeatures" , object @ model )
110110 coefficients <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
111- " getModelCoefficients" , object @ model )
111+ " getModelCoefficients" , object @ model )
112112 if (modelName == " LinearRegressionModel" ) {
113113 devianceResiduals <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
114114 " getModelDevianceResiduals" , object @ model )
@@ -119,10 +119,76 @@ setMethod("summary", signature(object = "PipelineModel"),
119119 colnames(coefficients ) <- c(" Estimate" , " Std. Error" , " t value" , " Pr(>|t|)" )
120120 rownames(coefficients ) <- unlist(features )
121121 return (list (devianceResiduals = devianceResiduals , coefficients = coefficients ))
122- } else {
122+ } else if ( modelName == " LogisticRegressionModel " ) {
123123 coefficients <- as.matrix(unlist(coefficients ))
124124 colnames(coefficients ) <- c(" Estimate" )
125125 rownames(coefficients ) <- unlist(features )
126126 return (list (coefficients = coefficients ))
127+ } else if (modelName == " KMeansModel" ) {
128+ modelSize <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
129+ " getKMeansModelSize" , object @ model )
130+ cluster <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
131+ " getKMeansCluster" , object @ model , " classes" )
132+ k <- unlist(modelSize )[1 ]
133+ size <- unlist(modelSize )[- 1 ]
134+ coefficients <- t(matrix (coefficients , ncol = k ))
135+ colnames(coefficients ) <- unlist(features )
136+ rownames(coefficients ) <- 1 : k
137+ return (list (coefficients = coefficients , size = size , cluster = dataFrame(cluster )))
138+ } else {
139+ stop(paste(" Unsupported model" , modelName , sep = " " ))
140+ }
141+ })
142+
#' Fit a k-means model
#'
#' Fit a k-means model, similarly to R's kmeans().
#'
#' The column names of \code{x} are collected and forwarded, together with the
#' other arguments, to the JVM-side \code{SparkRWrappers.fitKMeans} helper.
#'
#' @param x DataFrame for training
#' @param centers Number of centers
#' @param iter.max Maximum iteration number (defaults to 10)
#' @param algorithm Algorithm chosen to fit the model; one of \code{"random"}
#'                  or \code{"k-means||"}. When omitted, the first choice
#'                  (\code{"random"}) is used.
#' @return A fitted k-means model, wrapped in a \code{PipelineModel}
#' @rdname kmeans
#' @export
#' @examples
#'\dontrun{
#' model <- kmeans(x, centers = 2, algorithm = "random")
#'}
setMethod("kmeans", signature(x = "DataFrame"),
          function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
            columnNames <- as.array(colnames(x))
            # match.arg() resolves the default vector to its first element and
            # rejects any value that is not one of the listed choices.
            algorithm <- match.arg(algorithm)
            # iter.max and centers are handed over as-is; no type coercion is
            # done on the R side.
            model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
                                 algorithm, iter.max, centers, columnNames)
            return(new("PipelineModel", model = model))
          })
166+
#' Get fitted result from a model
#'
#' Get fitted result from a model, similarly to R's fitted().
#'
#' Only models whose JVM-side name is \code{"KMeansModel"} are handled; any
#' other model raises an "Unsupported model" error.
#'
#' @param object A fitted MLlib model
#' @param method Kind of fitted result to request, either \code{"centers"} or
#'               \code{"classes"}; forwarded to the JVM-side
#'               \code{SparkRWrappers.getKMeansCluster} helper. When omitted,
#'               the first choice (\code{"centers"}) is used.
#' @param ... Currently unused
#' @return DataFrame containing fitted values
#' @rdname fitted
#' @export
#' @examples
#'\dontrun{
#' model <- kmeans(trainingData, 2)
#' fitted.model <- fitted(model)
#' showDF(fitted.model)
#'}
setMethod("fitted", signature(object = "PipelineModel"),
          function(object, method = c("centers", "classes"), ...) {
            modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
                                     "getModelName", object@model)

            if (modelName == "KMeansModel") {
              # Validate `method` only once the model type is known to be
              # supported, so unsupported models report the model error first.
              method <- match.arg(method)
              fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
                                          "getKMeansCluster", object@model, method)
              return(dataFrame(fittedResult))
            } else {
              stop(paste("Unsupported model", modelName, sep = " "))
            }
          })
0 commit comments