Skip to content

Commit dc65374

Browse files
committed
integrated model type fix
2 parents 7622b0c + b93aaf6 commit dc65374

File tree

1 file changed

+53
-35
lines changed

1 file changed

+53
-35
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,13 @@ import org.json4s.jackson.JsonMethods._
2525
import org.json4s.{DefaultFormats, JValue}
2626

2727
import org.apache.spark.{Logging, SparkContext, SparkException}
28-
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
2928
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
3029
import org.apache.spark.mllib.regression.LabeledPoint
3130
import org.apache.spark.mllib.util.{Loader, Saveable}
3231
import org.apache.spark.rdd.RDD
3332
import org.apache.spark.sql.{DataFrame, SQLContext}
3433

3534

36-
/**
37-
* Model types supported in Naive Bayes:
38-
* multinomial and Bernoulli currently supported
39-
*/
40-
object NaiveBayesModels extends Enumeration {
41-
type NaiveBayesModels = Value
42-
val Multinomial, Bernoulli = Value
43-
44-
implicit def toString(model: NaiveBayesModels): String = {
45-
model.toString
46-
}
47-
}
48-
49-
50-
5135
/**
5236
* Model for Naive Bayes Classifiers.
5337
*
@@ -62,20 +46,21 @@ class NaiveBayesModel private[mllib] (
6246
val labels: Array[Double],
6347
val pi: Array[Double],
6448
val theta: Array[Array[Double]],
65-
val modelType: NaiveBayesModels) extends ClassificationModel with Serializable with Saveable {
49+
val modelType: NaiveBayes.ModelType)
50+
extends ClassificationModel with Serializable with Saveable {
6651

6752
def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
68-
this(labels, pi, theta, NaiveBayesModels.Multinomial)
53+
this(labels, pi, theta, NaiveBayes.Multinomial)
6954

7055
private val brzPi = new BDV[Double](pi)
7156
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
7257

7358
//Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
74-
//precomputing log(1.0 - exp(theta)) and its sum for linear algebra application
59+
//this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application
7560
//of this condition in predict function
7661
private val (brzNegTheta, brzNegThetaSum) = modelType match {
77-
case NaiveBayesModels.Multinomial => (None, None)
78-
case NaiveBayesModels.Bernoulli =>
62+
case NaiveBayes.Multinomial => (None, None)
63+
case NaiveBayes.Bernoulli =>
7964
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
8065
(Option(negTheta), Option(brzSum(brzNegTheta, Axis._1)))
8166
}
@@ -90,16 +75,16 @@ class NaiveBayesModel private[mllib] (
9075

9176
override def predict(testData: Vector): Double = {
9277
modelType match {
93-
case NaiveBayesModels.Multinomial =>
78+
case NaiveBayes.Multinomial =>
9479
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
95-
case NaiveBayesModels.Bernoulli =>
80+
case NaiveBayes.Bernoulli =>
9681
labels (brzArgmax (brzPi +
9782
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
9883
}
9984
}
10085

10186
override def save(sc: SparkContext, path: String): Unit = {
102-
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType)
87+
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
10388
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
10489
}
10590

@@ -152,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
152137
val labels = data.getAs[Seq[Double]](0).toArray
153138
val pi = data.getAs[Seq[Double]](1).toArray
154139
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
155-
val modelType: NaiveBayesModels = NaiveBayesModels.withName(data.getAs[String](3))
140+
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
156141
new NaiveBayesModel(labels, pi, theta, modelType)
157142
}
158143
}
159144

160145
override def load(sc: SparkContext, path: String): NaiveBayesModel = {
161-
def getModelType(metadata: JValue): NaiveBayesModels = {
146+
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
162147
implicit val formats = DefaultFormats
163-
NaiveBayesModels.withName((metadata \ "modelType").extract[String])
148+
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
164149
}
165150
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
166151
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -196,12 +181,14 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
196181
* document classification. By making every vector a 0-1 vector, it can also be used as
197182
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
198183
*/
199-
class NaiveBayes private (private var lambda: Double,
200-
private var modelType: NaiveBayesModels) extends Serializable with Logging {
201184

202-
def this(lambda: Double) = this(lambda, NaiveBayesModels.Multinomial)
185+
class NaiveBayes private (
186+
private var lambda: Double,
187+
private var modelType: NaiveBayes.ModelType) extends Serializable with Logging {
188+
189+
def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
203190

204-
def this() = this(1.0, NaiveBayesModels.Multinomial)
191+
def this() = this(1.0, NaiveBayes.Multinomial)
205192

206193
/** Set the smoothing parameter. Default: 1.0. */
207194
def setLambda(lambda: Double): NaiveBayes = {
@@ -210,7 +197,7 @@ class NaiveBayes private (private var lambda: Double,
210197
}
211198

212199
/** Set the model type. Default: Multinomial. */
213-
def setModelType(model: NaiveBayesModels): NaiveBayes = {
200+
def setModelType(model: NaiveBayes.ModelType): NaiveBayes = {
214201
this.modelType = model
215202
this
216203
}
@@ -267,8 +254,8 @@ class NaiveBayes private (private var lambda: Double,
267254
labels(i) = label
268255
pi(i) = math.log(n + lambda) - piLogDenom
269256
val thetaLogDenom = modelType match {
270-
case NaiveBayesModels.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
271-
case NaiveBayesModels.Bernoulli => math.log(n + 2.0 * lambda)
257+
case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
258+
case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda)
272259
}
273260
var j = 0
274261
while (j < numFeatures) {
@@ -337,6 +324,37 @@ object NaiveBayes {
337324
* Multinomial or Bernoulli
338325
*/
339326
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
340-
new NaiveBayes(lambda, NaiveBayesModels.withName(modelType)).run(input)
327+
new NaiveBayes(lambda, Multinomial).run(input)
328+
}
329+
330+
331+
/**
332+
* Model types supported in Naive Bayes:
333+
* multinomial and Bernoulli currently supported
334+
*/
335+
sealed abstract class ModelType
336+
337+
object MODELTYPE {
338+
final val MULTINOMIAL_STRING = "multinomial"
339+
final val BERNOULLI_STRING = "bernoulli"
340+
341+
def fromString(modelType: String): ModelType = modelType match {
342+
case MULTINOMIAL_STRING => Multinomial
343+
case BERNOULLI_STRING => Bernoulli
344+
case _ =>
345+
throw new IllegalArgumentException(s"Cannot recognize NaiveBayes ModelType: $modelType")
346+
}
347+
}
348+
349+
final val ModelType = MODELTYPE
350+
351+
final val Multinomial: ModelType = new ModelType {
352+
override def toString: String = ModelType.MULTINOMIAL_STRING
353+
}
354+
355+
final val Bernoulli: ModelType = new ModelType {
356+
override def toString: String = ModelType.BERNOULLI_STRING
341357
}
358+
342359
}
360+

0 commit comments

Comments
 (0)