@@ -25,16 +25,17 @@ import org.json4s.jackson.JsonMethods._
2525import org .json4s .{DefaultFormats , JValue }
2626
2727import org .apache .spark .{Logging , SparkContext , SparkException }
28+ import org .apache .spark .mllib .classification .NaiveBayesModels .NaiveBayesModels
2829import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vector }
2930import org .apache .spark .mllib .regression .LabeledPoint
30- import org .apache .spark .mllib .classification .NaiveBayesModels .NaiveBayesModels
3131import org .apache .spark .mllib .util .{Loader , Saveable }
3232import org .apache .spark .rdd .RDD
3333import org .apache .spark .sql .{DataFrame , SQLContext }
3434
3535
3636/**
37- *
37+ * Model types supported in Naive Bayes:
38+ * multinomial and Bernoulli currently supported
3839 */
3940object NaiveBayesModels extends Enumeration {
4041 type NaiveBayesModels = Value
@@ -45,6 +46,8 @@ object NaiveBayesModels extends Enumeration {
4546 }
4647}
4748
49+
50+
4851/**
4952 * Model for Naive Bayes Classifiers.
5053 *
@@ -55,7 +58,6 @@ object NaiveBayesModels extends Enumeration {
5558 * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
5659 * Multinomial or Bernoulli
5760 */
58-
5961class NaiveBayesModel private [mllib] (
6062 val labels : Array [Double ],
6163 val pi : Array [Double ],
@@ -68,11 +70,14 @@ class NaiveBayesModel private[mllib] (
6870 private val brzPi = new BDV [Double ](pi)
6971 private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
7072
71- private val brzNegTheta : Option [BDM [Double ]] = modelType match {
72- case NaiveBayesModels .Multinomial => None
73+ // Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
74+ // precomputing log(1.0 - exp(theta)) and its sum for linear algebra application
75+ // of this condition in predict function
76+ private val (brzNegTheta, brzNegThetaSum) = modelType match {
77+ case NaiveBayesModels .Multinomial => (None , None )
7378 case NaiveBayesModels .Bernoulli =>
7479 val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
75- Option (negTheta)
80+ ( Option (negTheta), Option (brzSum(brzNegTheta, Axis ._1)) )
7681 }
7782
7883 override def predict (testData : RDD [Vector ]): RDD [Double ] = {
@@ -89,8 +94,7 @@ class NaiveBayesModel private[mllib] (
8994 labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
9095 case NaiveBayesModels .Bernoulli =>
9196 labels (brzArgmax (brzPi +
92- (brzTheta - brzNegTheta.get) * testData.toBreeze +
93- brzSum(brzNegTheta.get, Axis ._1)))
97+ (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
9498 }
9599 }
96100
@@ -114,10 +118,11 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
114118 def thisClassName = " org.apache.spark.mllib.classification.NaiveBayesModel"
115119
116120 /** Model data for model import/export */
117- case class Data (labels : Array [Double ],
118- pi : Array [Double ],
119- theta : Array [Array [Double ]],
120- modelType : String )
121+ case class Data (
122+ labels : Array [Double ],
123+ pi : Array [Double ],
124+ theta : Array [Array [Double ]],
125+ modelType : String )
121126
122127 def save (sc : SparkContext , path : String , data : Data ): Unit = {
123128 val sqlContext = new SQLContext (sc)
@@ -192,7 +197,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
192197 * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]). The input feature values must be nonnegative.
193198 */
194199class NaiveBayes private (private var lambda : Double ,
195- var modelType : NaiveBayesModels ) extends Serializable with Logging {
200+ private var modelType : NaiveBayesModels ) extends Serializable with Logging {
196201
197202 def this (lambda : Double ) = this (lambda, NaiveBayesModels .Multinomial )
198203
@@ -284,7 +289,7 @@ object NaiveBayes {
284289 /**
285290 * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
286291 *
287- * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
292+ * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
288293 * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
289294 * document classification.
290295 *
@@ -300,7 +305,7 @@ object NaiveBayes {
300305 /**
301306 * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
302307 *
303- * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
308+ * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
304309 * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
305310 * document classification.
306311 *
@@ -316,11 +321,13 @@ object NaiveBayes {
316321 /**
317322 * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
318323 *
319- * This is by default the Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle
320- * all kinds of discrete data. For example, by converting documents into TF-IDF vectors,
321- * it can be used for document classification. By making every vector a 0-1 vector and
322- * setting the model type to NaiveBayesModels.Bernoulli, it fits and predicts as
323- * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]).
324+ * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p ]])
325+ * or Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]). The Multinomial NB can handle
326+ * discrete count data and can be called by setting the model type to "Multinomial".
327+ * For example, it can be used with word counts or TF_IDF vectors of documents.
328+ * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a
329+ * 0-1 vector and setting the model type to "Bernoulli", the fits and predicts as
330+ * Bernoulli NB.
324331 *
325332 * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
326333 * vector or a count vector.
0 commit comments