Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
private var tau0: Double = 1024
private var kappa: Double = 0.51
private var miniBatchFraction: Double = 0.05
private var optimizeAlpha: Boolean = false
private var optimizeDocConcentration: Boolean = false

// internal data structure
private var docs: RDD[(Long, Vector)] = null
Expand Down Expand Up @@ -335,20 +335,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}

/**
* Optimize alpha, indicates whether alpha (Dirichlet parameter for document-topic distribution)
* will be optimized during training.
* Optimize docConcentration, indicates whether docConcentration (Dirichlet parameter for
* document-topic distribution) will be optimized during training.
*/
@Since("1.5.0")
def getOptimzeAlpha: Boolean = this.optimizeAlpha
def getOptimizeDocConcentration: Boolean = this.optimizeDocConcentration

/**
* Sets whether to optimize alpha parameter during training.
* Sets whether to optimize docConcentration parameter during training.
*
* Default: false
*/
@Since("1.5.0")
def setOptimzeAlpha(optimizeAlpha: Boolean): this.type = {
this.optimizeAlpha = optimizeAlpha
def setOptimizeDocConcentration(optimizeDocConcentration: Boolean): this.type = {
this.optimizeDocConcentration = optimizeDocConcentration
this
}

Expand Down Expand Up @@ -458,7 +458,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {

// Note that this is an optimization to avoid batch.count
updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
if (optimizeAlpha) updateAlpha(gammat)
if (optimizeDocConcentration) updateAlpha(gammat)
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
val k = 2
val docs = sc.parallelize(toyData)
val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
.setGammaShape(100).setOptimzeAlpha(true).setSampleWithReplacement(false)
.setGammaShape(100).setOptimizeDocConcentration(true).setSampleWithReplacement(false)
val lda = new LDA().setK(k)
.setDocConcentration(1D / k)
.setTopicConcentration(0.01)
Expand Down