
Commit 4594761

[SPARK-1655][MLLIB] Add option for distributed naive bayes model.

1 parent fd0b32c · commit 4594761

4 files changed: +173 −20 lines changed
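
In short, NaiveBayes.train gains a distMode argument: "local" reproduces the existing behavior (model parameters collected to the driver), while "dist" keeps the trained model partitioned across the cluster as an RDD of per-label-group blocks. A minimal usage sketch, modeled on the new tests below; `sc` is an assumed existing SparkContext and the two-point dataset is hypothetical:

import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Toy two-class training set; each label is signaled by a single feature.
val trainRDD = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 1.0))), 2)

// Train with the model kept distributed, then predict a single point.
val model = NaiveBayes.train(trainRDD, 1.0, "dist")
model.predict(Vectors.dense(0.0, 1.0)) // expected: 1.0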

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 4 additions & 4 deletions

@@ -232,11 +232,11 @@ class PythonMLLibAPI extends Serializable {
   def trainNaiveBayes(
       data: JavaRDD[LabeledPoint],
       lambda: Double): java.util.List[java.lang.Object] = {
-    val model = NaiveBayes.train(data.rdd, lambda)
+    // val model = NaiveBayes.train(data.rdd, lambda, "local")
     val ret = new java.util.LinkedList[java.lang.Object]()
-    ret.add(Vectors.dense(model.labels))
-    ret.add(Vectors.dense(model.pi))
-    ret.add(model.theta)
+    // ret.add(Vectors.dense(model.labels))
+    // ret.add(Vectors.dense(model.pi))
+    // ret.add(model.theta)
     ret
   }

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 119 additions & 11 deletions

@@ -24,19 +24,25 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel

 /**
- * Model for Naive Bayes Classifiers.
+ * Abstract model for a naive bayes classifier.
+ */
+abstract class NaiveBayesModel extends ClassificationModel with Serializable
+
+/**
+ * Local model for a naive bayes classifier.
  *
  * @param labels list of labels
  * @param pi log of class priors, whose dimension is C, number of labels
  * @param theta log of class conditional probabilities, whose dimension is C-by-D,
  *              where D is number of features
  */
-class NaiveBayesModel private[mllib] (
+private class LocalNaiveBayesModel(
     val labels: Array[Double],
     val pi: Array[Double],
-    val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
+    val theta: Array[Array[Double]]) extends NaiveBayesModel {

   private val brzPi = new BDV[Double](pi)
   private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -67,6 +73,54 @@ class NaiveBayesModel private[mllib] (
   }
 }

+/**
+ * One block from a distributed model for a naive bayes classifier. The model is divided into
+ * blocks, each containing the complete model state for a group of labels.
+ *
+ * @param labels array of labels
+ * @param pi log of class priors, with dimension C, the number of labels in this block
+ * @param theta log of class conditional probabilities, with dimensions C-by-D,
+ *              where D is the number of features
+ */
+private case class NBModelBlock(labels: Array[Double], pi: BDV[Double], theta: BDM[Double])
+
+/**
+ * Distributed model for a naive bayes classifier.
+ *
+ * @param modelBlocks RDD of NBModelBlock, comprising the model
+ */
+private class DistNaiveBayesModel(val modelBlocks: RDD[NBModelBlock]) extends NaiveBayesModel {
+
+  override def predict(testData: RDD[Vector]): RDD[Double] = {
+    // Pair each test data point with all model blocks.
+    val testXModel = testData.map(_.toBreeze).zipWithIndex().cartesian(modelBlocks)
+
+    // Find the maximum a posteriori label for each (test_data_point, model_block) pair.
+    val testXModelMaxes = testXModel.map { case ((test, i), model) =>
+      val posterior = model.pi + model.theta * test
+      val maxIdx = brzArgmax(posterior)
+      (i, (posterior(maxIdx), model.labels(maxIdx)))
+    }
+
+    // Find the maximum for each test data point, across all model blocks.
+    val testMaxes = testXModelMaxes.reduceByKey(Ordering[(Double, Double)].max)
+
+    // Reorder based on the original testData index, then project the labels.
+    testMaxes.sortByKey().map { case (_, (_, label)) => label }
+  }
+
+  override def predict(testData: Vector): Double = {
+    val testBreeze = testData.toBreeze
+
+    // Find the max a posteriori label for each model block, then the max of these block maxes.
+    modelBlocks.map { m =>
+      val posterior = m.pi + m.theta * testBreeze
+      val maxIdx = brzArgmax(posterior)
+      (posterior(maxIdx), m.labels(maxIdx))
+    }.max._2
+  }
+}
+
 /**
  * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
  *
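
The distributed predict above is a two-level argmax: each model block proposes its best (posterior, label) pair for a test point, and the natural tuple ordering, which compares posteriors first, selects the overall winner in reduceByKey. A standalone sketch of that reduction with made-up per-block results:

// Hypothetical per-block maxima for one test point: (bestLogPosterior, bestLabel).
val blockMaxes = Seq((-2.3, 0.0), (-1.1, 4.0), (-5.0, 7.0))
// Ordering[(Double, Double)] compares the posterior first, so max picks the MAP label.
val mapLabel = blockMaxes.max._2 // 4.0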
@@ -77,6 +131,8 @@ class NaiveBayesModel private[mllib] (
  */
 class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {

+  private var distMode = "local"
+
   def this() = this(1.0)

   /** Set the smoothing parameter. Default: 1.0. */
@@ -85,6 +141,12 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
     this
   }

+  /** Set the model distribution mode, either "local" or "dist" (for distributed). */
+  def setDistMode(distMode: String) = {
+    this.distMode = distMode
+    this
+  }
+
   /**
    * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
    *
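
Since setDistMode returns this, it chains with the existing setters in the usual MLlib builder style; a sketch, assuming the setLambda smoothing setter whose body is truncated in the hunk above:

val model = new NaiveBayes().setLambda(0.5).setDistMode("dist").run(trainRDD)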
@@ -103,10 +165,8 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
       }
     }

-    // Aggregates term frequencies per label.
-    // TODO: Calling combineByKey and collect creates two stages, we can implement something
-    // TODO: similar to reduceByKeyLocally to save one stage.
-    val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
+    // Sum the document counts and feature frequencies for each label.
+    val labelAggregates = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
       createCombiner = (v: Vector) => {
         requireNonnegativeValues(v)
         (1L, v.toBreeze.toDenseVector)
@@ -117,7 +177,20 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
       },
       mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
         (c1._1 + c2._1, c1._2 += c2._2)
-    ).collect()
+    )
+
+    distMode match {
+      case "local" => trainLocalModel(labelAggregates)
+      case "dist" => trainDistModel(labelAggregates)
+      case _ =>
+        throw new SparkException(s"Naive Bayes requires a valid distMode but found $distMode.")
+    }
+  }
+
+  private def trainLocalModel(labelAggregates: RDD[(Double, (Long, BDV[Double]))]) = {
+    // TODO: Calling combineByKey and collect creates two stages, we can implement something
+    // TODO: similar to reduceByKeyLocally to save one stage.
+    val aggregated = labelAggregates.collect()
     val numLabels = aggregated.length
     var numDocuments = 0L
     aggregated.foreach { case (_, (n, _)) =>
@@ -141,7 +214,41 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
       i += 1
     }

-    new NaiveBayesModel(labels, pi, theta)
+    new LocalNaiveBayesModel(labels, pi, theta)
+  }
+
+  private def trainDistModel(labelAggregates: RDD[(Double, (Long, BDV[Double]))]) = {
+    case class LabelAggregate(label: Double, numDocuments: Long, sumFeatures: BDV[Double])
+    val aggregated = labelAggregates.map(x => LabelAggregate(x._1, x._2._1, x._2._2))
+
+    // Compute the model's prior (pi) vector and conditional (theta) matrix for each batch
+    // of labels.
+    // NOTE: In contrast to the local trainer, the piLogDenom normalization term is omitted
+    // here. Computing this term requires an additional aggregation on 'aggregated', and
+    // because the term is an additive constant it does not affect maximum a posteriori
+    // model prediction.
+    val modelBlocks = aggregated.mapPartitions(p => p.grouped(100).map { batch =>
+      val numFeatures = batch.head.sumFeatures.length
+      val pi = batch.map(l => math.log(l.numDocuments + lambda))
+
+      // Assemble values of the theta matrix in row major order.
+      val theta = new Array[Double](batch.length * numFeatures)
+      batch.flatMap { l =>
+        val thetaLogDenom = math.log(brzSum(l.sumFeatures) + numFeatures * lambda)
+        l.sumFeatures.iterator.map(f => math.log(f._2 + lambda) - thetaLogDenom)
+      }.copyToArray(theta)
+
+      NBModelBlock(labels = batch.map(_.label).toArray,
+        pi = new BDV[Double](pi.toArray),
+        theta = new BDM[Double](batch.length, numFeatures, theta,
+          offset = 0, majorStride = numFeatures, isTranspose = true))
+    })
+
+    // Materialize and persist the model, and check that it is nonempty.
+    if (modelBlocks.persist(StorageLevel.MEMORY_AND_DISK).count() == 0) {
+      throw new SparkException("Naive Bayes requires a nonempty training RDD.")
+    }
+
+    new DistNaiveBayesModel(modelBlocks)
   }
 }
@@ -177,8 +284,9 @@ object NaiveBayes {
    * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
    *              vector or a count vector.
    * @param lambda The smoothing parameter
+   * @param distMode The model distribution mode, either "local" or "dist" (for distributed)
    */
-  def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
-    new NaiveBayes(lambda).run(input)
+  def train(input: RDD[LabeledPoint], lambda: Double, distMode: String): NaiveBayesModel = {
+    new NaiveBayes(lambda).setDistMode(distMode).run(input)
   }
 }
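
A note on the omitted piLogDenom term in trainDistModel: in the local trainer the class prior is normalized as pi(c) = log(N_c + lambda) − log(N + C·lambda), with N_c the document count for label c, N the total document count, and C the number of labels. The subtracted term is the same constant for every class, so for any test vector x

argmax_c [ log(N_c + lambda) + theta_c · x ] = argmax_c [ log(N_c + lambda) − log(N + C·lambda) + theta_c · x ]

and the distributed trainer can skip the extra pass that computing N would require. The trade-off is that the stored pi values are unnormalized log priors, valid for comparison but not as calibrated log probabilities.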

mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ public void runUsingStaticMethods() {
     int numAccurate1 = validatePrediction(POINTS, model1);
     Assert.assertEquals(POINTS.size(), numAccurate1);

-    NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5);
+    NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5, "local");
     int numAccurate2 = validatePrediction(POINTS, model2);
     Assert.assertEquals(POINTS.size(), numAccurate2);
   }

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 49 additions & 4 deletions

@@ -22,7 +22,7 @@ import scala.util.Random
 import org.scalatest.FunSuite

 import org.apache.spark.SparkException
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}

@@ -63,12 +63,12 @@ object NaiveBayesSuite {
 class NaiveBayesSuite extends FunSuite with LocalSparkContext {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
-    val numOfPredictions = predictions.zip(input).count {
+    val numOfWrongPredictions = predictions.zip(input).count {
       case (prediction, expected) =>
         prediction != expected.label
     }
-    // At least 80% of the predictions should be on.
-    assert(numOfPredictions < input.length / 5)
+    // At least 80% of the predictions should be correct.
+    assert(numOfWrongPredictions < input.length / 5)
   }

   test("Naive Bayes") {
@@ -97,6 +97,51 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext {
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }

+  test("distributed naive bayes") {
+    val nPoints = 10000
+    val nLabels = 150
+    val nFeatures = 300
+
+    def logNormalize(s: Seq[Int]) = {
+      s.map(_.toDouble / s.sum).map(math.log)
+    }
+
+    val pi = logNormalize(1 to nLabels).toArray
+    val theta = (for (l <- 1 to nLabels; f <- 1 to nFeatures)
+      yield if (f == l) 10000 else 1 // Each label is dominated by a different feature.
+    ).grouped(nFeatures).map(logNormalize).map(_.toArray).toArray
+
+    val trainData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42)
+    val trainRDD = sc.parallelize(trainData, 1)
+    trainRDD.cache()
+
+    val model = NaiveBayes.train(trainRDD, 1.0, "dist")
+
+    val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17)
+    val validationRDD = sc.parallelize(validationData, 2)
+
+    // Test prediction on RDD.
+    validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+    // Test prediction on Array.
+    val shortValData = validationData.take(nPoints / 10)
+    validatePrediction(shortValData.map(row => model.predict(row.features)), shortValData)
+  }
+
+  test("distributed naive bayes with empty train RDD") {
+    val emptyTrainRDD = sc.parallelize(new Array[LabeledPoint](0), 2)
+    intercept[SparkException] {
+      NaiveBayes.train(emptyTrainRDD, 1.0, "dist")
+    }
+  }
+
+  test("distributed naive bayes with empty test RDD") {
+    val trainRDD = sc.parallelize(LabeledPoint(1.0, Vectors.dense(2.0)) :: Nil, 2)
+    val model = NaiveBayes.train(trainRDD, 1.0, "dist")
+    val emptyTestRDD = sc.parallelize(new Array[Vector](0), 2)
+    assert(model.predict(emptyTestRDD).count == 0)
+  }
+
   test("detect negative values") {
     val dense = Seq(
       LabeledPoint(1.0, Vectors.dense(1.0)),
