Skip to content

Commit b1178cf

Browse files
committed
fit into the optimizer framework
1 parent dbe3cff commit b1178cf

File tree

3 files changed

+274
-298
lines changed

3 files changed

+274
-298
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -78,35 +78,28 @@ class LDA private (
7878
*
7979
* This is the parameter to a symmetric Dirichlet distribution.
8080
*/
81-
def getDocConcentration: Double = {
82-
if (this.docConcentration == -1) {
83-
(50.0 / k) + 1.0
84-
} else {
85-
this.docConcentration
86-
}
87-
}
81+
def getDocConcentration: Double = this.docConcentration
8882

8983
/**
9084
* Concentration parameter (commonly named "alpha") for the prior placed on documents'
9185
* distributions over topics ("theta").
9286
*
93-
* This is the parameter to a symmetric Dirichlet distribution.
87+
* This is the parameter to a symmetric Dirichlet distribution, where larger values
88+
* mean more smoothing (more regularization).
9489
*
95-
* This value should be > 1.0, where larger values mean more smoothing (more regularization).
9690
* If set to -1, then docConcentration is set automatically.
9791
* (default = -1 = automatic)
9892
*
9993
* Automatic setting of parameter:
10094
* - For EM: default = (50 / k) + 1.
10195
* - The 50/k is common in LDA libraries.
10296
* - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM.
97+
* - For Online: default = (1.0 / k).
98+
 * - Follows the implementation from: https://github.com/Blei-Lab/onlineldavb.
10399
*
104-
* Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
105-
* but values in (0,1) are not yet supported.
100+
 * Note: For the EM optimizer, this value should be > 1.0.
106101
*/
107102
def setDocConcentration(docConcentration: Double): this.type = {
108-
require(docConcentration > 1.0 || docConcentration == -1.0,
109-
s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration")
110103
this.docConcentration = docConcentration
111104
this
112105
}
@@ -126,13 +119,7 @@ class LDA private (
126119
* Note: The topics' distributions over terms are called "beta" in the original LDA paper
127120
* by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
128121
*/
129-
def getTopicConcentration: Double = {
130-
if (this.topicConcentration == -1) {
131-
1.1
132-
} else {
133-
this.topicConcentration
134-
}
135-
}
122+
def getTopicConcentration: Double = this.topicConcentration
136123

137124
/**
138125
* Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
@@ -143,21 +130,19 @@ class LDA private (
143130
* Note: The topics' distributions over terms are called "beta" in the original LDA paper
144131
* by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
145132
*
146-
* This value should be > 0.0.
147133
* If set to -1, then topicConcentration is set automatically.
148134
* (default = -1 = automatic)
149135
*
150136
* Automatic setting of parameter:
151137
* - For EM: default = 0.1 + 1.
152138
* - The 0.1 gives a small amount of smoothing.
153139
* - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM.
140+
* - For Online: default = (1.0 / k).
141+
 * - Follows the implementation from: https://github.com/Blei-Lab/onlineldavb.
154142
*
155-
* Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
156-
* but values in (0,1) are not yet supported.
143+
 * Note: For the EM optimizer, this value should be > 1.0.
157144
*/
158145
def setTopicConcentration(topicConcentration: Double): this.type = {
159-
require(topicConcentration > 1.0 || topicConcentration == -1.0,
160-
s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration")
161146
this.topicConcentration = topicConcentration
162147
this
163148
}
@@ -245,8 +230,7 @@ class LDA private (
245230
* @return Inferred LDA model
246231
*/
247232
def run(documents: RDD[(Long, Vector)]): LDAModel = {
248-
val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration,
249-
seed, checkpointInterval)
233+
val state = ldaOptimizer.initialize(documents, this)
250234
var iter = 0
251235
val iterationTimes = Array.fill[Double](maxIterations)(0)
252236
while (iter < maxIterations) {

0 commit comments

Comments
 (0)