2222# ' @export
2323setClass ("PipelineModel ", representation(model = "jobj"))
2424
25+ # ' @title S4 class that represents a NaiveBayesModel
26+ # ' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
27+ # ' @export
28+ setClass ("NaiveBayesModel ", representation(jobj = "jobj"))
29+
2530# ' Fits a generalized linear model
2631# '
2732# ' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -61,7 +66,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
6166 return (new(" PipelineModel" , model = model ))
6267 })
6368
64- # ' Make predictions from a model
69+ # ' Make predictions from a model
6570# '
6671# ' Makes predictions from a model produced by glm(), similarly to R's predict().
6772# '
@@ -81,6 +86,26 @@ setMethod("predict", signature(object = "PipelineModel"),
8186 return (dataFrame(callJMethod(object @ model , " transform" , newData @ sdf )))
8287 })
8388
89+ # ' Make predictions from a naive Bayes model
90+ # '
91+ # ' Makes predictions from a model produced by naiveBayes(), similarly to predict() in the R package e1071.
92+ # '
93+ # ' @param object A fitted naive Bayes model (class NaiveBayesModel), as returned by naiveBayes()
94+ # ' @param newData DataFrame of observations to predict on; it is passed to the model's transform method
95+ # ' @return DataFrame containing predicted labels in a column named "prediction"
96+ # ' @rdname predict
97+ # ' @export
98+ # ' @examples
99+ # ' \dontrun{
100+ # ' model <- naiveBayes(y ~ x, trainingData)
101+ # ' predicted <- predict(model, testData)
102+ # ' showDF(predicted)
103+ # '}
104+ setMethod ("predict ", signature(object = "NaiveBayesModel"),
105+ function (object , newData ) {
106+ return (dataFrame(callJMethod(object @ jobj , " transform" , newData @ sdf )))
107+ })
108+
84109# ' Get the summary of a model
85110# '
86111# ' Returns the summary of a model produced by glm(), similarly to R's summary().
@@ -135,24 +160,40 @@ setMethod("summary", signature(object = "PipelineModel"),
135160 colnames(coefficients ) <- unlist(features )
136161 rownames(coefficients ) <- 1 : k
137162 return (list (coefficients = coefficients , size = size , cluster = dataFrame(cluster )))
138- } else if (modelName == " NaiveBayesModel" ) {
139- labels <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
140- " getNaiveBayesLabels" , object @ model )
141- pi <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
142- " getNaiveBayesPi" , object @ model )
143- theta <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
144- " getNaiveBayesTheta" , object @ model )
145- pi <- t(as.matrix(unlist(pi )))
146- colnames(pi ) <- unlist(labels )
147- theta <- matrix (theta , nrow = length(labels ))
148- rownames(theta ) <- unlist(labels )
149- colnames(theta ) <- unlist(features )
150- return (list (pi = pi , theta = theta ))
151163 } else {
152164 stop(paste(" Unsupported model" , modelName , sep = " " ))
153165 }
154166 })
155167
168+ # ' Get the summary of a naive Bayes model
169+ # '
170+ # ' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
171+ # '
172+ # ' @param object A fitted naive Bayes model (class NaiveBayesModel)
173+ # ' @return a list containing 'apriori', the label distribution, and 'tables', conditional
174+ # ' probabilities given the target label
175+ # ' @rdname summary
176+ # ' @export
177+ # ' @examples
178+ # ' \dontrun{
179+ # ' model <- naiveBayes(y ~ x, trainingData)
180+ # ' summary(model)
181+ # '}
182+ setMethod ("summary ", signature(object = "NaiveBayesModel"),
183+ function (object , ... ) {
184+ jobj <- object @ jobj
185+ features <- callJMethod(jobj , " features" )
186+ labels <- callJMethod(jobj , " labels" )
187+ apriori <- callJMethod(jobj , " apriori" )
188+ apriori <- t(as.matrix(unlist(apriori )))
189+ colnames(apriori ) <- unlist(labels )
190+ tables <- callJMethod(jobj , " tables" )
191+ tables <- matrix (tables , nrow = length(labels ))
192+ rownames(tables ) <- unlist(labels )
193+ colnames(tables ) <- unlist(features )
194+ return (list (apriori = apriori , tables = tables ))
195+ })
196+
156197# ' Fit a k-means model
157198# '
158199# ' Fit a k-means model, similarly to R's kmeans().
@@ -206,34 +247,30 @@ setMethod("fitted", signature(object = "PipelineModel"),
206247 }
207248 })
208249
209- # ' Fit a naive Bayes model
250+ # ' Fit a Bernoulli naive Bayes model
210251# '
211- # ' Fit a naive Bayes model, similarly to R's naiveBayes() except for omitting two arguments 'subset'
212- # ' and 'na.action'. Users can use 'subset' function and 'fillna' or 'na.omit' function of DataFrame,
213- # ' respectively, to preprocess their DataFrame. We use na.omit in this interface to remove rows with
214- # ' NA values.
252+ # ' Fit a Bernoulli naive Bayes model, similarly to naiveBayes() in the R package e1071, except that only
253+ # ' categorical features are supported. The input should be a DataFrame of observations instead of a
254+ # ' contingency table. Rows containing NA values are removed (via na.omit) before fitting.
215255# '
216256# ' @param formula A symbolic description of the model to be fitted. Currently only a few formula
217- # ' operators are supported, including '~', '.', ':', '+', and '-'.
257+ # ' operators are supported, including '~', '.', ':', '+', and '-'.
218258# ' @param data DataFrame for training
219- # ' @param lambda Smoothing parameter
220- # ' @param modelType Either 'multinomial' or 'bernoulli'. Default "multinomial".
221- # ' @return A fitted naive Bayes model.
259+ # ' @param laplace Smoothing parameter
260+ # ' @return a fitted naive Bayes model
222261# ' @rdname naiveBayes
262+ # ' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
223263# ' @export
224264# ' @examples
225265# ' \dontrun{
226- # ' sc <- sparkR.init()
227- # ' sqlContext <- sparkRSQL.init(sc)
228266# ' df <- createDataFrame(sqlContext, infert)
229- # ' model <- naiveBayes(education ~ ., df, lambda = 1, modelType = "multinomial" )
267+ # ' model <- naiveBayes(education ~ ., df, laplace = 0 )
230268# '}
231269setMethod ("naiveBayes ", signature(formula = "formula", data = "DataFrame"),
232- function (formula , data , lambda = 1 , modelType = c( " multinomial " , " bernoulli " ) , ... ) {
270+ function (formula , data , laplace = 0 , ... ) {
233271 data <- na.omit(data )
234272 formula <- paste(deparse(formula ), collapse = " " )
235- modelType <- match.arg(modelType )
236- model <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" , " fitNaiveBayes" ,
237- formula , data @ sdf , lambda , modelType )
238- return (new(" PipelineModel" , model = model ))
273+ jobj <- callJStatic(" org.apache.spark.ml.r.NaiveBayesWrapper" , " fit" ,
274+ formula , data @ sdf , laplace )
275+ return (new(" NaiveBayesModel" , jobj = jobj ))
239276 })
0 commit comments