@@ -227,8 +227,8 @@ class OnlineLDAOptimizer extends LDAOptimizer {
227227 private var k : Int = 0
228228 private var corpusSize : Long = 0
229229 private var vocabSize : Int = 0
230- private var alpha : Double = 0
231- private var eta : Double = 0
230+ private [clustering] var alpha : Double = 0
231+ private [clustering] var eta : Double = 0
232232 private var randomGenerator : java.util.Random = null
233233
234234 // Online LDA specific parameters
@@ -238,12 +238,11 @@ class OnlineLDAOptimizer extends LDAOptimizer {
238238
239239 // internal data structure
240240 private var docs : RDD [(Long , Vector )] = null
241- private var lambda : BDM [Double ] = null
242- private var Elogbeta : BDM [Double ] = null
243- private var expElogbeta : BDM [Double ] = null
241+ private [clustering] var lambda : BDM [Double ] = null
244242
245243 // count of invocations of next(), which helps decide the weight for each iteration
246244 private var iteration : Int = 0
245+ private var gammaShape : Double = 100
247246
248247 /**
249248 * A (positive) learning parameter that downweights early iterations. Larger values make early
@@ -295,7 +294,24 @@ class OnlineLDAOptimizer extends LDAOptimizer {
295294 this
296295 }
297296
298- override private [clustering] def initialize (docs : RDD [(Long , Vector )], lda : LDA ): LDAOptimizer = {
297+ /**
298+ * This method is for testing only for now. In the future, it can help support training stop/resume.
299+ */
300+ private [clustering] def setLambda (lambda : BDM [Double ]): this .type = {
301+ this .lambda = lambda
302+ this
303+ }
304+
305+ /**
306+ * Used to control the gamma distribution. Larger value produces values closer to 1.0.
307+ */
308+ private [clustering] def setGammaShape (shape : Double ): this .type = {
309+ this .gammaShape = shape
310+ this
311+ }
312+
313+ override private [clustering] def initialize (docs : RDD [(Long , Vector )], lda : LDA ):
314+ OnlineLDAOptimizer = {
299315 this .k = lda.getK
300316 this .corpusSize = docs.count()
301317 this .vocabSize = docs.first()._2.size
@@ -307,26 +323,30 @@ class OnlineLDAOptimizer extends LDAOptimizer {
307323
308324 // Initialize the variational distribution q(beta|lambda)
309325 this .lambda = getGammaMatrix(k, vocabSize)
310- this .Elogbeta = dirichletExpectation(lambda)
311- this .expElogbeta = exp(Elogbeta )
312326 this .iteration = 0
313327 this
314328 }
315329
330+ override private [clustering] def next (): OnlineLDAOptimizer = {
331+ val batch = docs.sample(withReplacement = true , miniBatchFraction, randomGenerator.nextLong())
332+ if (batch.isEmpty()) return this
333+ submitMiniBatch(batch)
334+ }
335+
336+
316337 /**
317338 * Submit a subset (like 1%, decided by the miniBatchFraction) of the corpus to the Online LDA
318339 * model, and it will update the topic distribution adaptively for the terms appearing in the
319340 * subset.
320341 */
321- override private [clustering] def next ( ): OnlineLDAOptimizer = {
342+ private [clustering] def submitMiniBatch ( batch : RDD [( Long , Vector )] ): OnlineLDAOptimizer = {
322343 iteration += 1
323- val batch = docs.sample(withReplacement = true , miniBatchFraction, randomGenerator.nextLong())
324- if (batch.isEmpty()) return this
325-
326344 val k = this .k
327345 val vocabSize = this .vocabSize
328- val expElogbeta = this .expElogbeta
346+ val Elogbeta = dirichletExpectation(lambda)
347+ val expElogbeta = exp(Elogbeta )
329348 val alpha = this .alpha
349+ val gammaShape = this .gammaShape
330350
331351 val stats : RDD [BDM [Double ]] = batch.mapPartitions { docs =>
332352 val stat = BDM .zeros[Double ](k, vocabSize)
@@ -340,7 +360,7 @@ class OnlineLDAOptimizer extends LDAOptimizer {
340360 }
341361
342362 // Initialize the variational distribution q(theta|gamma) for the mini-batch
343- var gammad = new Gamma (100 , 1.0 / 100.0 ).samplesVector(k).t // 1 * K
363+ var gammad = new Gamma (gammaShape , 1.0 / gammaShape ).samplesVector(k).t // 1 * K
344364 var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K
345365 var expElogthetad = exp(Elogthetad ) // 1 * K
346366 val expElogbetad = expElogbeta(:: , ids).toDenseMatrix // K * ids
@@ -350,7 +370,7 @@ class OnlineLDAOptimizer extends LDAOptimizer {
350370 val ctsVector = new BDV [Double ](cts).t // 1 * ids
351371
352372 // Iterate between gamma and phi until convergence
353- while (meanchange > 1e-5 ) {
373+ while (meanchange > 1e-3 ) {
354374 val lastgamma = gammad
355375 // 1*K 1 * ids ids * k
356376 gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
@@ -372,7 +392,10 @@ class OnlineLDAOptimizer extends LDAOptimizer {
372392 Iterator (stat)
373393 }
374394
375- val batchResult : BDM [Double ] = stats.reduce(_ += _)
395+ val statsSum : BDM [Double ] = stats.reduce(_ += _)
396+ val batchResult = statsSum :* expElogbeta
397+
398+ // Note that this is an optimization to avoid batch.count
376399 update(batchResult, iteration, (miniBatchFraction * corpusSize).toInt)
377400 this
378401 }
@@ -384,28 +407,23 @@ class OnlineLDAOptimizer extends LDAOptimizer {
384407 /**
385408 * Update lambda based on the batch submitted. batchSize can be different for each iteration.
386409 */
387- private def update (raw : BDM [Double ], iter : Int , batchSize : Int ): Unit = {
410+ private [clustering] def update (stat : BDM [Double ], iter : Int , batchSize : Int ): Unit = {
388411 val tau_0 = this .getTau_0
389412 val kappa = this .getKappa
390413
391414 // weight of the mini-batch.
392415 val weight = math.pow(tau_0 + iter, - kappa)
393416
394- // This step finishes computing the sufficient statistics for the M step
395- val stat = raw :* expElogbeta
396-
397417 // Update lambda based on documents.
398418 lambda = lambda * (1 - weight) +
399419 (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight
400- Elogbeta = dirichletExpectation(lambda)
401- expElogbeta = exp(Elogbeta )
402420 }
403421
404422 /**
405423 * Get a random matrix to initialize lambda
406424 */
407425 private def getGammaMatrix (row : Int , col : Int ): BDM [Double ] = {
408- val gammaRandomGenerator = new Gamma (100 , 1.0 / 100.0 )
426+ val gammaRandomGenerator = new Gamma (gammaShape , 1.0 / gammaShape )
409427 val temp = gammaRandomGenerator.sample(row * col).toArray
410428 new BDM [Double ](col, row, temp).t
411429 }
0 commit comments