@@ -249,32 +249,24 @@ class LDA private (
 
 
   /**
+   * TODO: add API to take documents paths once tokenizer is ready.
    * Learn an LDA model using the given dataset, using online variational Bayes (VB) algorithm.
    *
    * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
    *                   The term count vectors are "bags of words" with a fixed-size vocabulary
    *                   (where the vocabulary size is the length of the vector).
    *                   Document IDs must be unique and >= 0.
-   * @param batchNumber  Number of batches. For each batch, recommendation size is [4, 16384].
-   *                     -1 for automatic batchNumber.
+   * @param batchNumber  Number of batches to split the input corpus into. The recommended
+   *                     size for each batch is [4, 16384]. -1 for automatic batchNumber.
    * @return  Inferred LDA model
    */
   def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = {
-    val D = documents.count().toInt
-    val batchSize =
-      if (batchNumber == -1) { // auto mode
-        if (D / 100 > 16384) 16384
-        else if (D / 100 < 4) 4
-        else D / 100
-      } else {
-        require(batchNumber > 0, "batchNumber should be positive or -1")
-        D / batchNumber
-      }
+    require(batchNumber > 0 || batchNumber == -1,
+      s"batchNumber must be greater than 0 or -1, but was set to $batchNumber")
 
-    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchSize)
-    (0 until onlineLDA.actualBatchNumber).map(_ => onlineLDA.next())
-    new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchNumber)
+    val model = onlineLDA.optimize()
+    new LocalLDAModel(Matrices.fromBreeze(model).transpose)
   }
 
   /** Java-friendly version of [[run()]] */
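
For context, a minimal calling sketch for the new API (`Example`, `train`, `corpus`, and the topic count are illustrative; `setK` is the existing LDA setter):

import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

object Example {
  def train(corpus: RDD[(Long, Vector)]): Unit = {
    val lda = new LDA().setK(10)
    // Auto mode: batchNumber defaults to -1, so the batch size is derived from the corpus size.
    val auto = lda.runOnlineLDA(corpus)
    // Explicit mode: split the corpus into 50 mini-batches.
    val explicit = lda.runOnlineLDA(corpus, batchNumber = 50)
  }
}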
@@ -437,39 +429,54 @@ private[clustering] object LDA {
   private[clustering] class OnlineLDAOptimizer(
       private val documents: RDD[(Long, Vector)],
       private val k: Int,
-      private val batchSize: Int) extends Serializable {
+      private val batchNumber: Int) extends Serializable {
 
     private val vocabSize = documents.first._2.size
     private val D = documents.count().toInt
-    val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
+    private val batchSize =
+      if (batchNumber == -1) { // auto mode
+        if (D / 100 > 16384) 16384
+        else if (D / 100 < 4) 4
+        else D / 100
+      } else {
+        D / batchNumber
+      }
 
     // Initialize the variational distribution q(beta|lambda)
-    var lambda = getGammaMatrix(k, vocabSize)               // K * V
+    private var lambda = getGammaMatrix(k, vocabSize)       // K * V
     private var Elogbeta = dirichlet_expectation(lambda)    // K * V
     private var expElogbeta = exp(Elogbeta)                 // K * V
 
-    private var batchId = 0
-    def next(): Unit = {
-      require(batchId < actualBatchNumber)
-      // weight of the mini-batch. 1024 down-weights early iterations
-      val weight = math.pow(1024 + batchId, -0.5)
-      val batch = documents.sample(true, batchSize.toDouble / D)
-      batch.cache()
-      // Given a mini-batch of documents, estimates the parameters gamma controlling the
-      // variational distribution over the topic weights for each document in the mini-batch.
-      var stat = BDM.zeros[Double](k, vocabSize)
-      stat = batch.aggregate(stat)(seqOp, _ += _)
-      stat = stat :* expElogbeta
+    def optimize(): BDM[Double] = {
+      val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
+      for (i <- 1 to actualBatchNumber) {
+        val batch = documents.sample(true, batchSize.toDouble / D)
+
+        // Given a mini-batch of documents, estimates the parameters gamma controlling the
+        // variational distribution over the topic weights for each document in the mini-batch.
+        var stat = BDM.zeros[Double](k, vocabSize)
+        stat = batch.treeAggregate(stat)(gradient, _ += _)
+        update(stat, i)
+      }
+      lambda
+    }
+
+    private def update(raw: BDM[Double], iter: Int): Unit = {
+      // weight of the mini-batch; the 1024 offset down-weights early iterations
+      val weight = math.pow(1024 + iter, -0.5)
+
+      // This step finishes computing the sufficient statistics for the M step
+      val stat = raw :* expElogbeta
 
       // Update lambda based on documents.
       lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
       Elogbeta = dirichlet_expectation(lambda)
       expElogbeta = exp(Elogbeta)
-      batchId += 1
     }
 
     // for each document d update that document's gamma and phi
-    private def seqOp(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    private def gradient(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
       val termCounts = doc._2
       val (ids, cts) = termCounts match {
         case v: DenseVector => ((0 until v.size).toList, v.values)
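
The `update` above implements the decayed averaging of online variational Bayes (Hoffman et al., 2010): each mini-batch's sufficient statistics are blended into lambda with weight (1024 + t)^(-0.5), so the first batches move lambda only slightly (about 0.031 at t = 1) and the step size keeps shrinking. A standalone Breeze sketch of just that blending, with toy dimensions (`DecayedUpdateSketch` and its values are illustrative, not part of the patch):

import breeze.linalg.{DenseMatrix => BDM}

object DecayedUpdateSketch {
  val k = 2                                     // topics
  val vocabSize = 3                             // vocabulary size
  val D = 1000                                  // toy corpus size
  val batchSize = 10                            // toy mini-batch size
  var lambda = BDM.ones[Double](k, vocabSize)   // K * V topic-term parameters

  // Same shape as `update` above: scale the batch statistics up to corpus
  // size, add the 1/k prior term, and blend with the previous lambda.
  def update(stat: BDM[Double], iter: Int): Unit = {
    val weight = math.pow(1024 + iter, -0.5)    // ~0.031 at iter = 1
    lambda = lambda * (1 - weight) + (stat * (D.toDouble / batchSize) + 1.0 / k) * weight
  }

  def main(args: Array[String]): Unit = {
    update(BDM.ones[Double](k, vocabSize), iter = 1)
    println(lambda)
  }
}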
@@ -488,7 +495,7 @@ private[clustering] object LDA {
       val ctsVector = new BDV[Double](cts).t // 1 * ids
 
       // Iterate between gamma and phi until convergence
-      while (meanchange > 1e-6) {
+      while (meanchange > 1e-5) {
         val lastgamma = gammad
         //    1 * K                  1 * ids            ids * K
         gammad = (expElogthetad :* ((ctsVector / phinorm) * (expElogbetad.t))) + 1.0 / k
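
The last hunk loosens the per-document stopping rule of the gamma/phi iteration from 1e-6 to 1e-5: the E-step stops once the mean absolute change in gamma falls below the threshold, trading a little per-document accuracy for fewer inner iterations. An illustrative sketch of that test (`MeanChangeSketch` and the sample values are not from the patch):

import breeze.linalg.{sum, DenseVector => BDV}
import breeze.numerics.abs

object MeanChangeSketch {
  // Mean absolute change between successive gamma estimates for one document.
  def meanChange(gammad: BDV[Double], lastgamma: BDV[Double]): Double =
    sum(abs(gammad - lastgamma)) / gammad.length.toDouble

  def main(args: Array[String]): Unit = {
    val last = BDV(1.0, 2.0, 3.0)
    val curr = BDV(1.000005, 2.000005, 3.000005)
    // Mean change is 5e-6, below the new 1e-5 threshold, so the loop would stop.
    println(meanChange(curr, last) < 1e-5)  // true
  }
}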