Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,17 @@ class NaiveBayesModel private[mllib] (
}

override def predict(testData: Vector): Double = {
val brzData = testData.toBreeze
modelType match {
case "Multinomial" =>
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
labels (brzArgmax (brzPi + brzTheta * brzData) )
case "Bernoulli" =>
if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
throw new SparkException(
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
}
labels (brzArgmax (brzPi +
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
(brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
Expand Down Expand Up @@ -293,12 +298,29 @@ class NaiveBayes private (
}
}

// Validates that every feature value of `v` is exactly 0.0 or 1.0, as required by
// the Bernoulli event model; throws SparkException otherwise.
// For a SparseVector only the explicitly stored values need checking — the
// implicit entries are 0.0 and therefore always valid.
val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
  val values = v match {
    // Use wildcards for the unused size/indices binders to avoid unused-variable
    // warnings, and a distinct name so nothing shadows the outer `v`/`values`.
    case SparseVector(_, _, vs) =>
      vs
    case DenseVector(vs) =>
      vs
  }
  if (!values.forall(value => value == 0.0 || value == 1.0)) {
    throw new SparkException(
      s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
  }
}

// Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
createCombiner = (v: Vector) => {
requireNonnegativeValues(v)
if (modelType == "Bernoulli") {
requireZeroOneBernoulliValues(v)
} else {
requireNonnegativeValues(v)
}
(1L, v.toBreeze.toDenseVector)
},
mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,39 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
}
}

test("detect non zero or one values in Bernoulli") {
  // Training data with a non-binary feature (2.0) must be rejected at fit time.
  val invalidTraining = Seq(
    LabeledPoint(1.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(2.0)),
    LabeledPoint(1.0, Vectors.dense(1.0)),
    LabeledPoint(1.0, Vectors.dense(0.0)))

  intercept[SparkException] {
    NaiveBayes.train(sc.makeRDD(invalidTraining, 2), 1.0, "Bernoulli")
  }

  // Purely 0/1 features train successfully and yield a usable model.
  val validTraining = Seq(
    LabeledPoint(1.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(0.0)),
    LabeledPoint(1.0, Vectors.dense(1.0)),
    LabeledPoint(1.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(0.0)),
    LabeledPoint(1.0, Vectors.dense(1.0)),
    LabeledPoint(1.0, Vectors.dense(1.0))
  )
  val model = NaiveBayes.train(sc.makeRDD(validTraining, 2), 1.0, "Bernoulli")

  // Non-binary features must likewise be rejected at prediction time.
  val invalidPrediction = Seq(
    Vectors.dense(1.0),
    Vectors.dense(2.0),
    Vectors.dense(1.0),
    Vectors.dense(0.0))

  intercept[SparkException] {
    model.predict(sc.makeRDD(invalidPrediction, 2)).collect()
  }
}

test("model save/load: 2.0 to 2.0") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
Expand Down