@@ -25,29 +25,13 @@ import org.json4s.jackson.JsonMethods._
2525import org .json4s .{DefaultFormats , JValue }
2626
2727import org .apache .spark .{Logging , SparkContext , SparkException }
28- import org .apache .spark .mllib .classification .NaiveBayesModels .NaiveBayesModels
2928import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vector }
3029import org .apache .spark .mllib .regression .LabeledPoint
3130import org .apache .spark .mllib .util .{Loader , Saveable }
3231import org .apache .spark .rdd .RDD
3332import org .apache .spark .sql .{DataFrame , SQLContext }
3433
3534
36- /**
37- * Model types supported in Naive Bayes:
38- * multinomial and Bernoulli currently supported
39- */
40- object NaiveBayesModels extends Enumeration {
41- type NaiveBayesModels = Value
42- val Multinomial, Bernoulli = Value
43-
44- implicit def toString (model : NaiveBayesModels ): String = {
45- model.toString
46- }
47- }
48-
49-
50-
5135/**
5236 * Model for Naive Bayes Classifiers.
5337 *
@@ -62,20 +46,21 @@ class NaiveBayesModel private[mllib] (
6246 val labels : Array [Double ],
6347 val pi : Array [Double ],
6448 val theta : Array [Array [Double ]],
65- val modelType : NaiveBayesModels ) extends ClassificationModel with Serializable with Saveable {
49+ val modelType : NaiveBayes .ModelType )
50+ extends ClassificationModel with Serializable with Saveable {
6651
6752 def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
68- this (labels, pi, theta, NaiveBayesModels .Multinomial )
53+ this (labels, pi, theta, NaiveBayes .Multinomial )
6954
7055 private val brzPi = new BDV [Double ](pi)
7156 private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
7257
7358 // Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
74- // precomputing log(1.0 - exp(theta)) and its sum for linear algebra application
59+ // this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application
7560 // of this condition in predict function
7661 private val (brzNegTheta, brzNegThetaSum) = modelType match {
77- case NaiveBayesModels .Multinomial => (None , None )
78- case NaiveBayesModels .Bernoulli =>
62+ case NaiveBayes .Multinomial => (None , None )
63+ case NaiveBayes .Bernoulli =>
7964 val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
8065 (Option (negTheta), Option (brzSum(brzNegTheta, Axis ._1)))
8166 }
@@ -90,16 +75,16 @@ class NaiveBayesModel private[mllib] (
9075
9176 override def predict (testData : Vector ): Double = {
9277 modelType match {
93- case NaiveBayesModels .Multinomial =>
78+ case NaiveBayes .Multinomial =>
9479 labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
95- case NaiveBayesModels .Bernoulli =>
80+ case NaiveBayes .Bernoulli =>
9681 labels (brzArgmax (brzPi +
9782 (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
9883 }
9984 }
10085
10186 override def save (sc : SparkContext , path : String ): Unit = {
102- val data = NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType)
87+ val data = NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType.toString )
10388 NaiveBayesModel .SaveLoadV1_0 .save(sc, path, data)
10489 }
10590
@@ -152,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
152137 val labels = data.getAs[Seq [Double ]](0 ).toArray
153138 val pi = data.getAs[Seq [Double ]](1 ).toArray
154139 val theta = data.getAs[Seq [Seq [Double ]]](2 ).map(_.toArray).toArray
155- val modelType : NaiveBayesModels = NaiveBayesModels .withName (data.getAs[ String ] (3 ))
140+ val modelType = NaiveBayes . ModelType .fromString (data.getString (3 ))
156141 new NaiveBayesModel (labels, pi, theta, modelType)
157142 }
158143 }
159144
160145 override def load (sc : SparkContext , path : String ): NaiveBayesModel = {
161- def getModelType (metadata : JValue ): NaiveBayesModels = {
146+ def getModelType (metadata : JValue ): NaiveBayes . ModelType = {
162147 implicit val formats = DefaultFormats
163- NaiveBayesModels .withName ((metadata \ " modelType" ).extract[String ])
148+ NaiveBayes . ModelType .fromString ((metadata \ " modelType" ).extract[String ])
164149 }
165150 val (loadedClassName, version, metadata) = loadMetadata(sc, path)
166151 val classNameV1_0 = SaveLoadV1_0 .thisClassName
@@ -196,12 +181,14 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
196181 * document classification. By making every vector a 0-1 vector, it can also be used as
197182 * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]). The input feature values must be nonnegative.
198183 */
199- class NaiveBayes private (private var lambda : Double ,
200- private var modelType : NaiveBayesModels ) extends Serializable with Logging {
201184
202- def this (lambda : Double ) = this (lambda, NaiveBayesModels .Multinomial )
185+ class NaiveBayes private (
186+ private var lambda : Double ,
187+ private var modelType : NaiveBayes .ModelType ) extends Serializable with Logging {
188+
189+ def this (lambda : Double ) = this (lambda, NaiveBayes .Multinomial )
203190
204- def this () = this (1.0 , NaiveBayesModels .Multinomial )
191+ def this () = this (1.0 , NaiveBayes .Multinomial )
205192
206193 /** Set the smoothing parameter. Default: 1.0. */
207194 def setLambda (lambda : Double ): NaiveBayes = {
@@ -210,7 +197,7 @@ class NaiveBayes private (private var lambda: Double,
210197 }
211198
212199 /** Set the model type. Default: Multinomial. */
213- def setModelType (model : NaiveBayesModels ): NaiveBayes = {
200+ def setModelType (model : NaiveBayes . ModelType ): NaiveBayes = {
214201 this .modelType = model
215202 this
216203 }
@@ -267,8 +254,8 @@ class NaiveBayes private (private var lambda: Double,
267254 labels(i) = label
268255 pi(i) = math.log(n + lambda) - piLogDenom
269256 val thetaLogDenom = modelType match {
270- case NaiveBayesModels .Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
271- case NaiveBayesModels .Bernoulli => math.log(n + 2.0 * lambda)
257+ case NaiveBayes .Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
258+ case NaiveBayes .Bernoulli => math.log(n + 2.0 * lambda)
272259 }
273260 var j = 0
274261 while (j < numFeatures) {
@@ -337,6 +324,37 @@ object NaiveBayes {
337324 * Multinomial or Bernoulli
338325 */
339326 def train (input : RDD [LabeledPoint ], lambda : Double , modelType : String ): NaiveBayesModel = {
340- new NaiveBayes (lambda, NaiveBayesModels .withName(modelType)).run(input)
327+ new NaiveBayes (lambda, Multinomial ).run(input)
328+ }
329+
330+
331+ /**
332+ * Model types supported in Naive Bayes:
333+ * multinomial and Bernoulli currently supported
334+ */
335+ sealed abstract class ModelType
336+
337+ object MODELTYPE {
338+ final val MULTINOMIAL_STRING = " multinomial"
339+ final val BERNOULLI_STRING = " bernoulli"
340+
341+ def fromString (modelType : String ): ModelType = modelType match {
342+ case MULTINOMIAL_STRING => Multinomial
343+ case BERNOULLI_STRING => Bernoulli
344+ case _ =>
345+ throw new IllegalArgumentException (s " Cannot recognize NaiveBayes ModelType: $modelType" )
346+ }
347+ }
348+
349+ final val ModelType = MODELTYPE
350+
351+ final val Multinomial : ModelType = new ModelType {
352+ override def toString : String = ModelType .MULTINOMIAL_STRING
353+ }
354+
355+ final val Bernoulli : ModelType = new ModelType {
356+ override def toString : String = ModelType .BERNOULLI_STRING
341357 }
358+
342359}
360+
0 commit comments