@@ -24,19 +24,25 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 
 /**
- * Model for Naive Bayes Classifiers.
+ * Abstract model for a naive Bayes classifier.
+ */
+abstract class NaiveBayesModel extends ClassificationModel with Serializable
+
+/**
+ * Local model for a naive Bayes classifier.
  *
  * @param labels list of labels
  * @param pi log of class priors, whose dimension is C, number of labels
  * @param theta log of class conditional probabilities, whose dimension is C-by-D,
  *              where D is number of features
  */
-class NaiveBayesModel private[mllib] (
+private class LocalNaiveBayesModel(
     val labels: Array[Double],
     val pi: Array[Double],
-    val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
+    val theta: Array[Array[Double]]) extends NaiveBayesModel {
 
   private val brzPi = new BDV[Double](pi)
   private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -67,6 +73,54 @@ class NaiveBayesModel private[mllib] (
   }
 }
 
+/**
+ * One block from a distributed model for a naive Bayes classifier. The model is divided into
+ * blocks, each containing the complete model state for a group of labels.
+ *
+ * @param labels array of labels
+ * @param pi log of class priors, with dimension C, the number of labels in this block
+ * @param theta log of class conditional probabilities, with dimensions C-by-D,
+ *              where D is the number of features
+ */
+private case class NBModelBlock(labels: Array[Double], pi: BDV[Double], theta: BDM[Double])
+
+/**
+ * Distributed model for a naive Bayes classifier.
+ *
+ * @param modelBlocks RDD of NBModelBlock, comprising the model
+ */
+private class DistNaiveBayesModel(val modelBlocks: RDD[NBModelBlock]) extends NaiveBayesModel {
+
+  override def predict(testData: RDD[Vector]): RDD[Double] = {
+    // Pair each test data point with all model blocks.
+    val testXModel = testData.map(_.toBreeze).zipWithIndex().cartesian(modelBlocks)
+
+    // Find the maximum a posteriori label for each (test_data_point, model_block) pair.
+    val testXModelMaxes = testXModel.map { case ((test, i), model) =>
+      val posterior = model.pi + model.theta * test
+      val maxIdx = brzArgmax(posterior)
+      (i, (posterior(maxIdx), model.labels(maxIdx)))
+    }
+
+    // Find the maximum for each test data point, across all model blocks.
+    val testMaxes = testXModelMaxes.reduceByKey(Ordering[(Double, Double)].max)
+
+    // Reorder based on the original testData index, then project the labels.
+    testMaxes.sortByKey().map { case (_, (_, label)) => label }
+  }
+
+  override def predict(testData: Vector): Double = {
+    val testBreeze = testData.toBreeze
+
+    // Find the max a posteriori label for each model block, then the max of these block maxes.
+    modelBlocks.map { m =>
+      val posterior = m.pi + m.theta * testBreeze
+      val maxIdx = brzArgmax(posterior)
+      (posterior(maxIdx), m.labels(maxIdx))
+    }.max._2
+  }
+}
+
 /**
  * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
  *
@@ -77,6 +131,8 @@ class NaiveBayesModel private[mllib] (
  */
 class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {
 
+  private var distMode = "local"
+
   def this() = this(1.0)
 
   /** Set the smoothing parameter. Default: 1.0. */
@@ -85,6 +141,12 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
     this
   }
 
+  /** Set the model distribution mode, either "local" or "dist" (for distributed). */
+  def setDistMode(distMode: String) = {
+    this.distMode = distMode
+    this
+  }
+
   /**
    * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
    *
@@ -103,10 +165,8 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
       }
     }
 
-    // Aggregates term frequencies per label.
-    // TODO: Calling combineByKey and collect creates two stages, we can implement something
-    // TODO: similar to reduceByKeyLocally to save one stage.
-    val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
+    // Sum the document counts and feature frequencies for each label.
+    val labelAggregates = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
       createCombiner = (v: Vector) => {
         requireNonnegativeValues(v)
         (1L, v.toBreeze.toDenseVector)
@@ -117,7 +177,20 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
       },
       mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
         (c1._1 + c2._1, c1._2 += c2._2)
-    ).collect()
+    )
+
+    distMode match {
+      case "local" => trainLocalModel(labelAggregates)
+      case "dist" => trainDistModel(labelAggregates)
+      case _ =>
+        throw new SparkException(s"Naive Bayes requires a valid distMode but found $distMode.")
+    }
+  }
+
+  private def trainLocalModel(labelAggregates: RDD[(Double, (Long, BDV[Double]))]) = {
+    // TODO: Calling combineByKey and collect creates two stages, we can implement something
+    // TODO: similar to reduceByKeyLocally to save one stage.
+    val aggregated = labelAggregates.collect()
     val numLabels = aggregated.length
     var numDocuments = 0L
     aggregated.foreach { case (_, (n, _)) =>
@@ -141,7 +214,41 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
       i += 1
     }
 
-    new NaiveBayesModel(labels, pi, theta)
+    new LocalNaiveBayesModel(labels, pi, theta)
+  }
+
+  private def trainDistModel(labelAggregates: RDD[(Double, (Long, BDV[Double]))]) = {
+    case class LabelAggregate(label: Double, numDocuments: Long, sumFeatures: BDV[Double])
+    val aggregated = labelAggregates.map(x => LabelAggregate(x._1, x._2._1, x._2._2))
+
+    // Compute the model's prior (pi) vector and conditional (theta) matrix for each batch of
+    // labels.
+    // NOTE: In contrast to the local trainer, the piLogDenom normalization term is omitted here.
+    // Computing this term requires an additional aggregation on 'aggregated', and because the
+    // term is an additive constant it does not affect maximum a posteriori model prediction.
+    val modelBlocks = aggregated.mapPartitions(p => p.grouped(100).map { batch =>
+      val numFeatures = batch.head.sumFeatures.length
+      val pi = batch.map(l => math.log(l.numDocuments + lambda))
+
+      // Assemble values of the theta matrix in row major order.
+      val theta = new Array[Double](batch.length * numFeatures)
+      batch.flatMap { l =>
+        val thetaLogDenom = math.log(brzSum(l.sumFeatures) + numFeatures * lambda)
+        l.sumFeatures.iterator.map(f => math.log(f._2 + lambda) - thetaLogDenom)
+      }.copyToArray(theta)
+
+      NBModelBlock(labels = batch.map(_.label).toArray,
+                   pi = new BDV[Double](pi.toArray),
+                   theta = new BDM[Double](batch.length, numFeatures, theta,
+                                           offset = 0, majorStride = numFeatures, isTranspose = true))
+    })
+
+    // Materialize and persist the model, and check that it is nonempty.
+    if (modelBlocks.persist(StorageLevel.MEMORY_AND_DISK).count() == 0) {
+      throw new SparkException("Naive Bayes requires a nonempty training RDD.")
+    }
+
+    new DistNaiveBayesModel(modelBlocks)
   }
 }
 
@@ -177,8 +284,9 @@ object NaiveBayes {
    * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
    *              vector or a count vector.
    * @param lambda The smoothing parameter
+   * @param distMode The model distribution mode, either "local" or "dist" (for distributed)
    */
-  def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
-    new NaiveBayes(lambda).run(input)
+  def train(input: RDD[LabeledPoint], lambda: Double, distMode: String): NaiveBayesModel = {
+    new NaiveBayes(lambda).setDistMode(distMode).run(input)
   }
 }