@@ -27,24 +27,11 @@ import org.json4s.{DefaultFormats, JValue}
2727import org .apache .spark .{Logging , SparkContext , SparkException }
2828import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vector }
2929import org .apache .spark .mllib .regression .LabeledPoint
30- import org .apache .spark .mllib .classification .NaiveBayesModels .NaiveBayesModels
3130import org .apache .spark .mllib .util .{Loader , Saveable }
3231import org .apache .spark .rdd .RDD
3332import org .apache .spark .sql .{DataFrame , SQLContext }
3433
3534
36- /**
37- *
38- */
39- object NaiveBayesModels extends Enumeration {
40- type NaiveBayesModels = Value
41- val Multinomial, Bernoulli = Value
42-
43- implicit def toString (model : NaiveBayesModels ): String = {
44- model.toString
45- }
46- }
47-
4835/**
4936 * Model for Naive Bayes Classifiers.
5037 *
@@ -60,17 +47,18 @@ class NaiveBayesModel private[mllib] (
6047 val labels : Array [Double ],
6148 val pi : Array [Double ],
6249 val theta : Array [Array [Double ]],
63- val modelType : NaiveBayesModels ) extends ClassificationModel with Serializable with Saveable {
50+ val modelType : NaiveBayes .ModelType )
51+ extends ClassificationModel with Serializable with Saveable {
6452
6553 def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
66- this (labels, pi, theta, NaiveBayesModels .Multinomial )
54+ this (labels, pi, theta, NaiveBayes .Multinomial )
6755
6856 private val brzPi = new BDV [Double ](pi)
6957 private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
7058
// Element-wise log(1 - exp(theta)). Only the Bernoulli branch builds it;
// the multinomial branch yields None.
private val brzNegTheta: Option[BDM[Double]] = modelType match {
  case NaiveBayes.Bernoulli =>
    // log(1.0 - exp(x)), built with in-place ops on exp of a copy of brzTheta
    val oneMinusExpTheta = (brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0
    Option(brzLog(oneMinusExpTheta))
  case NaiveBayes.Multinomial =>
    None
}
@@ -85,17 +73,17 @@ class NaiveBayesModel private[mllib] (
8573
override def predict(testData: Vector): Double = {
  // Convert once; both model types score against the same Breeze vector.
  val features = testData.toBreeze
  // Index of the highest-scoring class under the configured model type.
  val bestClass = modelType match {
    case NaiveBayes.Bernoulli =>
      val negTheta = brzNegTheta.get
      brzArgmax(brzPi +
        (brzTheta - negTheta) * features +
        brzSum(negTheta, Axis._1))
    case NaiveBayes.Multinomial =>
      brzArgmax(brzPi + brzTheta * features)
  }
  labels(bestClass)
}
9684
override def save(sc: SparkContext, path: String): Unit = {
  import NaiveBayesModel.SaveLoadV1_0
  // The model type is written out by its string name.
  SaveLoadV1_0.save(sc, path, SaveLoadV1_0.Data(labels, pi, theta, modelType.toString))
}
10189
@@ -147,15 +135,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
147135 val labels = data.getAs[Seq [Double ]](0 ).toArray
148136 val pi = data.getAs[Seq [Double ]](1 ).toArray
149137 val theta = data.getAs[Seq [Seq [Double ]]](2 ).map(_.toArray).toArray
150- val modelType : NaiveBayesModels = NaiveBayesModels .withName (data.getAs[ String ] (3 ))
138+ val modelType = NaiveBayes . ModelType .fromString (data.getString (3 ))
151139 new NaiveBayesModel (labels, pi, theta, modelType)
152140 }
153141 }
154142
155143 override def load (sc : SparkContext , path : String ): NaiveBayesModel = {
// Reads the "modelType" field from the metadata JSON and maps it to a ModelType.
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
  implicit val formats = DefaultFormats
  val typeName = (metadata \ "modelType").extract[String]
  NaiveBayes.ModelType.fromString(typeName)
}
160148 val (loadedClassName, version, metadata) = loadMetadata(sc, path)
161149 val classNameV1_0 = SaveLoadV1_0 .thisClassName
@@ -191,12 +179,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
191179 * document classification. By making every vector a 0-1 vector, it can also be used as
192180 * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]). The input feature values must be nonnegative.
193181 */
194- class NaiveBayes private (private var lambda : Double ,
195- var modelType : NaiveBayesModels ) extends Serializable with Logging {
182+ class NaiveBayes private (
183+ private var lambda : Double ,
184+ var modelType : NaiveBayes .ModelType ) extends Serializable with Logging {
196185
197- def this (lambda : Double ) = this (lambda, NaiveBayesModels .Multinomial )
186+ def this (lambda : Double ) = this (lambda, NaiveBayes .Multinomial )
198187
199- def this () = this (1.0 , NaiveBayesModels .Multinomial )
188+ def this () = this (1.0 , NaiveBayes .Multinomial )
200189
201190 /** Set the smoothing parameter. Default: 1.0. */
202191 def setLambda (lambda : Double ): NaiveBayes = {
@@ -205,7 +194,7 @@ class NaiveBayes private (private var lambda: Double,
205194 }
206195
/** Selects the model type to train and returns this instance. Default: Multinomial. */
def setModelType(model: NaiveBayes.ModelType): NaiveBayes = {
  modelType = model
  this
}
@@ -262,8 +251,8 @@ class NaiveBayes private (private var lambda: Double,
262251 labels(i) = label
263252 pi(i) = math.log(n + lambda) - piLogDenom
264253 val thetaLogDenom = modelType match {
265- case NaiveBayesModels .Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
266- case NaiveBayesModels .Bernoulli => math.log(n + 2.0 * lambda)
254+ case NaiveBayes .Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
255+ case NaiveBayes .Bernoulli => math.log(n + 2.0 * lambda)
267256 }
268257 var j = 0
269258 while (j < numFeatures) {
@@ -330,6 +319,32 @@ object NaiveBayes {
330319 * Multinomial or Bernoulli
331320 */
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
  // Resolve the caller's choice instead of always training a multinomial model.
  // Lower-casing (locale-independent) lets the capitalised names "Multinomial" and
  // "Bernoulli" resolve as well as the lower-case constants; any other value makes
  // fromString throw IllegalArgumentException rather than being silently ignored.
  val requestedType = ModelType.fromString(modelType.toLowerCase(java.util.Locale.ROOT))
  new NaiveBayes(lambda, requestedType).run(input)
}
324+
/**
 * Kind of Naive Bayes model. The only instances are [[Multinomial]] and [[Bernoulli]].
 *
 * It is Serializable because it is stored as a field of serializable classes.
 * Equality is by name, and Java deserialization resolves back to the canonical
 * instance, so `case NaiveBayes.Multinomial =>` still matches after a round trip.
 */
sealed abstract class ModelType extends Serializable {

  override def equals(other: Any): Boolean = other match {
    case that: ModelType => that.toString == toString
    case _ => false
  }

  override def hashCode: Int = toString.hashCode

  /** Replaces a deserialized copy with the canonical instance of the same name. */
  protected def readResolve(): AnyRef = ModelType.fromString(toString)
}

/** String names of the model types and the parser that maps a name back to its instance. */
object MODELTYPE {
  final val MULTINOMIAL_STRING = "multinomial"
  final val BERNOULLI_STRING = "bernoulli"

  /**
   * Returns the model type with the given name.
   *
   * @param modelType either MULTINOMIAL_STRING or BERNOULLI_STRING (exact match)
   * @throws IllegalArgumentException if the name is not recognized
   */
  def fromString(modelType: String): ModelType = modelType match {
    case MULTINOMIAL_STRING => Multinomial
    case BERNOULLI_STRING => Bernoulli
    case _ =>
      throw new IllegalArgumentException(s"Cannot recognize NaiveBayes ModelType: $modelType")
  }
}

/** Term-level alias so callers can write `NaiveBayes.ModelType.fromString(...)`. */
final val ModelType = MODELTYPE

/** Multinomial model type; its string form is MULTINOMIAL_STRING. */
final val Multinomial: ModelType = new ModelType {
  override def toString: String = ModelType.MULTINOMIAL_STRING
}

/** Bernoulli model type; its string form is BERNOULLI_STRING. */
final val Bernoulli: ModelType = new ModelType {
  override def toString: String = ModelType.BERNOULLI_STRING
}
348+
335349}
350+
0 commit comments