Skip to content

Commit 49f36f3

Browse files
committed
refactor with NaiveBayesWrapper
1 parent 3d291de commit 49f36f3

File tree

6 files changed

+211
-155
lines changed

6 files changed

+211
-155
lines changed

R/pkg/R/mllib.R

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
#' @export
2323
setClass("PipelineModel", representation(model = "jobj"))
2424

25+
#' @tile S4 class that represents a NaiveBayesModel
26+
#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
27+
#' @export
28+
setClass("NaiveBayesModel", representation(jobj = "jobj"))
29+
2530
#' Fits a generalized linear model
2631
#'
2732
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -61,7 +66,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
6166
return(new("PipelineModel", model = model))
6267
})
6368

64-
#' Make predictions from a model
69+
#' Make predictions from a amodel
6570
#'
6671
#' Makes predictions from a model produced by glm(), similarly to R's predict().
6772
#'
@@ -81,6 +86,26 @@ setMethod("predict", signature(object = "PipelineModel"),
8186
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
8287
})
8388

89+
#' Make predictions from a naive Bayes model
90+
#'
91+
#' Makes predictions from a model produced by naiveBayes(), similarly to R package e1071's predict.
92+
#'
93+
#' @param object A fitted naive Bayes model
94+
#' @param newData DataFrame for testing
95+
#' @return DataFrame containing predicted labels in a column named "prediction"
96+
#' @rdname predict
97+
#' @export
98+
#' @examples
99+
#' \dontrun{
100+
#' model <- naiveBayes(y ~ x, trainingData)
101+
#' predicted <- predict(model, testData)
102+
#' showDF(predicted)
103+
#'}
104+
setMethod("predict", signature(object = "NaiveBayesModel"),
105+
function(object, newData) {
106+
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
107+
})
108+
84109
#' Get the summary of a model
85110
#'
86111
#' Returns the summary of a model produced by glm(), similarly to R's summary().
@@ -135,24 +160,40 @@ setMethod("summary", signature(object = "PipelineModel"),
135160
colnames(coefficients) <- unlist(features)
136161
rownames(coefficients) <- 1:k
137162
return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
138-
} else if (modelName == "NaiveBayesModel") {
139-
labels <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
140-
"getNaiveBayesLabels", object@model)
141-
pi <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
142-
"getNaiveBayesPi", object@model)
143-
theta <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
144-
"getNaiveBayesTheta", object@model)
145-
pi <- t(as.matrix(unlist(pi)))
146-
colnames(pi) <- unlist(labels)
147-
theta <- matrix(theta, nrow = length(labels))
148-
rownames(theta) <- unlist(labels)
149-
colnames(theta) <- unlist(features)
150-
return(list(pi = pi, theta = theta))
151163
} else {
152164
stop(paste("Unsupported model", modelName, sep = " "))
153165
}
154166
})
155167

168+
#' Get the summary of a naive Bayes model
169+
#'
170+
#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
171+
#'
172+
#' @param object A fitted MLlib model
173+
#' @return a list containing 'apriori', the label distribution, and 'tables', conditional
174+
# probabilities given the target label
175+
#' @rdname summary
176+
#' @export
177+
#' @examples
178+
#' \dontrun{
179+
#' model <- naiveBayes(y ~ x, trainingData)
180+
#' summary(model)
181+
#'}
182+
setMethod("summary", signature(object = "NaiveBayesModel"),
183+
function(object, ...) {
184+
jobj <- object@jobj
185+
features <- callJMethod(jobj, "features")
186+
labels <- callJMethod(jobj, "labels")
187+
apriori <- callJMethod(jobj, "apriori")
188+
apriori <- t(as.matrix(unlist(apriori)))
189+
colnames(apriori) <- unlist(labels)
190+
tables <- callJMethod(jobj, "tables")
191+
tables <- matrix(tables, nrow = length(labels))
192+
rownames(tables) <- unlist(labels)
193+
colnames(tables) <- unlist(features)
194+
return(list(apriori = apriori, tables = tables))
195+
})
196+
156197
#' Fit a k-means model
157198
#'
158199
#' Fit a k-means model, similarly to R's kmeans().
@@ -206,34 +247,30 @@ setMethod("fitted", signature(object = "PipelineModel"),
206247
}
207248
})
208249

209-
#' Fit a naive Bayes model
250+
#' Fit a Bernoulli naive Bayes model
210251
#'
211-
#' Fit a naive Bayes model, similarly to R's naiveBayes() except for omitting two arguments 'subset'
212-
#' and 'na.action'. Users can use 'subset' function and 'fillna' or 'na.omit' function of DataFrame,
213-
#' respectively, to preprocess their DataFrame. We use na.omit in this interface to remove rows with
214-
#' NA values.
252+
#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes() while only
253+
#' categorical features are supported. The input should be a DataFrame of observations instead of a
254+
#' contingency table.
215255
#'
216256
#' @param object A symbolic description of the model to be fitted. Currently only a few formula
217-
#' operators are supported, including '~', '.', ':', '+', and '-'.
257+
#' operators are supported, including '~', '.', ':', '+', and '-'.
218258
#' @param data DataFrame for training
219-
#' @param lambda Smoothing parameter
220-
#' @param modelType Either 'multinomial' or 'bernoulli'. Default "multinomial".
221-
#' @return A fitted naive Bayes model.
259+
#' @param laplace Smoothing parameter
260+
#' @return a fitted naive Bayes model
222261
#' @rdname naiveBayes
262+
#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
223263
#' @export
224264
#' @examples
225265
#' \dontrun{
226-
#' sc <- sparkR.init()
227-
#' sqlContext <- sparkRSQL.init(sc)
228266
#' df <- createDataFrame(sqlContext, infert)
229-
#' model <- naiveBayes(education ~ ., df, lambda = 1, modelType = "multinomial")
267+
#' model <- naiveBayes(education ~ ., df, laplace = 0)
230268
#'}
231269
setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
232-
function(formula, data, lambda = 1, modelType = c("multinomial", "bernoulli"), ...) {
270+
function(formula, data, laplace = 0, ...) {
233271
data <- na.omit(data)
234272
formula <- paste(deparse(formula), collapse = "")
235-
modelType <- match.arg(modelType)
236-
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitNaiveBayes",
237-
formula, data@sdf, lambda, modelType)
238-
return(new("PipelineModel", model = model))
273+
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
274+
formula, data@sdf, laplace)
275+
return(new("NaiveBayesModel", jobj = jobj))
239276
})

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -143,25 +143,60 @@ test_that("kmeans", {
143143
})
144144

145145
test_that("naiveBayes", {
146-
training <- suppressWarnings(createDataFrame(sqlContext, infert))
147-
148-
model <- naiveBayes(education ~ ., data = training, lambda = 1, modelType = "multinomial")
149-
sample <- take(select(predict(model, training), "rawLabelsPrediction"), 1)
150-
expect_equal(typeof(sample$rawLabelsPrediction), "character")
151-
expect_equal(sample$rawLabelsPrediction, "0-5yrs")
152-
153-
# Test summary works on naiveBayes
154-
summary.model <- summary(model)
155-
expect_equal(length(summary.model$pi), 3)
156-
expect_equal(sum(summary.model$pi), 1)
157-
l1 <- summary.model$theta[1, ]
158-
l2 <- summary.model$theta[2, ]
159-
expect_equal(sum(unlist(l1)), 1)
160-
expect_equal(sum(unlist(l2)), 1)
146+
# R code to reproduce the result.
147+
# We do not support instance weights yet. So we ignore the frequencies.
148+
#
149+
# library(e1071)
150+
# t <- as.data.frame(Titanic)
151+
# t1 <- t[t$Freq > 0, -5]
152+
# m <- naiveBayes(Survived ~ ., data = t1)
153+
# m
154+
# predict(m, t1)
155+
#
156+
# -- output of 'm'
157+
#
158+
# A-priori probabilities:
159+
# Y
160+
# No Yes
161+
# 0.4166667 0.5833333
162+
#
163+
# Conditional probabilities:
164+
# Class
165+
# Y 1st 2nd 3rd Crew
166+
# No 0.2000000 0.2000000 0.4000000 0.2000000
167+
# Yes 0.2857143 0.2857143 0.2857143 0.1428571
168+
#
169+
# Sex
170+
# Y Male Female
171+
# No 0.5 0.5
172+
# Yes 0.5 0.5
173+
#
174+
# Age
175+
# Y Child Adult
176+
# No 0.2000000 0.8000000
177+
# Yes 0.4285714 0.5714286
178+
#
179+
# -- output of 'predict(m, t1)'
180+
#
181+
# Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No
182+
#
183+
184+
t <- as.data.frame(Titanic)
185+
t1 <- t[t$Freq > 0, -5]
186+
df <- suppressWarnings(createDataFrame(sqlContext, t1))
187+
m <- naiveBayes(Survived ~ ., data = df)
188+
s <- summary(m)
189+
expect_equal(s$apriori[1, "Yes"], 0.5833333, tolerance = 1e-6)
190+
expect_equal(sum(s$apriori), 1)
191+
expect_equal(s$tables["Yes", "Age_Adult"], 0.5714286, tolerance = 1e-6)
192+
p <- collect(select(predict(m, df), "prediction"))
193+
expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No",
194+
"Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
195+
"Yes", "Yes", "No", "No"))
161196

162197
# Test e1071::naiveBayes
163198
if (requireNamespace("e1071", quietly = TRUE)) {
164-
expect_that(m <- e1071::naiveBayes(education ~ ., data = infert), not(throws_error()))
165-
expect_equal(as.character(predict(m, infert[1, ])), "0-5yrs")
199+
expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
200+
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
166201
}
167202
})

mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import org.apache.hadoop.fs.Path
2222
import org.apache.spark.SparkException
2323
import org.apache.spark.annotation.{Experimental, Since}
2424
import org.apache.spark.ml.PredictorParams
25-
import org.apache.spark.ml.attribute.AttributeGroup
2625
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
2726
import org.apache.spark.ml.util._
2827
import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
@@ -105,12 +104,7 @@ class NaiveBayes @Since("1.5.0") (
105104
override protected def train(dataset: DataFrame): NaiveBayesModel = {
106105
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
107106
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
108-
val nbModel = copyValues(NaiveBayesModel.fromOld(oldModel, this))
109-
val attr = AttributeGroup.fromStructField(dataset.schema($(featuresCol))).attributes
110-
if (attr.isDefined) {
111-
nbModel.setFeatureNames(attr.get.map(_.name.getOrElse("NA")))
112-
}
113-
nbModel
107+
NaiveBayesModel.fromOld(oldModel, this)
114108
}
115109

116110
@Since("1.5.0")
@@ -233,21 +227,6 @@ class NaiveBayesModel private[ml] (
233227

234228
@Since("1.6.0")
235229
override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this)
236-
237-
private var featureNames: Option[Array[String]] = None
238-
239-
private[classification] def setFeatureNames(names: Array[String]): this.type = {
240-
this.featureNames = Some(names)
241-
this
242-
}
243-
244-
private[ml] def getFeatureNames: Array[String] = featureNames match {
245-
case Some(names) => names
246-
case None =>
247-
throw new SparkException(
248-
s"No training result available for the ${this.getClass.getSimpleName}",
249-
new NullPointerException())
250-
}
251230
}
252231

253232
@Since("1.6.0")
@@ -258,6 +237,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
258237
oldModel: OldNaiveBayesModel,
259238
parent: NaiveBayes): NaiveBayesModel = {
260239
val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
240+
val labels = Vectors.dense(oldModel.labels)
261241
val pi = Vectors.dense(oldModel.pi)
262242
val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length,
263243
oldModel.theta.flatten, true)

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -257,18 +257,6 @@ class RFormulaModel private[feature](
257257
"Label column already exists and is not of type DoubleType.")
258258
}
259259

260-
/**
261-
* Get the original array of labels if exists.
262-
*/
263-
private[ml] def getOriginalLabels: Option[Array[String]] = {
264-
// According to the sequences of transformers in RFormula, if the last stage is a
265-
// StringIndexerModel, then we can extract the original labels from it.
266-
pipelineModel.stages.last match {
267-
case m: StringIndexerModel => Some(m.labels)
268-
case _ => None
269-
}
270-
}
271-
272260
@Since("2.0.0")
273261
override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
274262
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.r
19+
20+
import org.apache.spark.ml.{Pipeline, PipelineModel}
21+
import org.apache.spark.ml.attribute.{AttributeGroup, Attribute, NominalAttribute}
22+
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
23+
import org.apache.spark.ml.feature.{IndexToString, RFormula}
24+
import org.apache.spark.sql.DataFrame
25+
26+
private[r] class NaiveBayesWrapper private (
27+
pipeline: PipelineModel,
28+
val labels: Array[String],
29+
val features: Array[String]) {
30+
31+
import NaiveBayesWrapper._
32+
33+
private val naiveBayesModel: NaiveBayesModel = pipeline.stages(1).asInstanceOf[NaiveBayesModel]
34+
35+
lazy val apriori: Array[Double] = naiveBayesModel.pi.toArray.map(math.exp)
36+
37+
lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp)
38+
39+
def transform(dataset: DataFrame): DataFrame = {
40+
pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL)
41+
}
42+
}
43+
44+
private[r] object NaiveBayesWrapper {
45+
46+
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
47+
val PREDICTED_LABEL_COL = "prediction"
48+
49+
def fit(formula: String, data: DataFrame, laplace: Double): NaiveBayesWrapper = {
50+
val rFormula = new RFormula()
51+
.setFormula(formula)
52+
.fit(data)
53+
// get labels and feature names from output schema
54+
val schema = rFormula.transform(data).schema
55+
val labelAttr = Attribute.fromStructField(schema(rFormula.getLabelCol))
56+
.asInstanceOf[NominalAttribute]
57+
val labels = labelAttr.values.get
58+
val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
59+
.attributes.get
60+
val features = featureAttrs.map(_.name.get)
61+
// assemble and fit the pipeline
62+
val naiveBayes = new NaiveBayes()
63+
.setSmoothing(laplace)
64+
.setModelType("bernoulli")
65+
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
66+
val idxToStr = new IndexToString()
67+
.setInputCol(PREDICTED_LABEL_INDEX_COL)
68+
.setOutputCol(PREDICTED_LABEL_COL)
69+
.setLabels(labels)
70+
val pipeline = new Pipeline()
71+
.setStages(Array(rFormula, naiveBayes, idxToStr))
72+
.fit(data)
73+
new NaiveBayesWrapper(pipeline, labels, features)
74+
}
75+
}

0 commit comments

Comments
 (0)