
Commit b21b9c1

Collapsed Gibbs sampling based Latent Dirichlet Allocation

1 parent 725715c commit b21b9c1
4 files changed: +478 −1 lines changed


core/pom.xml

Lines changed: 0 additions & 1 deletion

@@ -85,7 +85,6 @@
     <dependency>
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-math3</artifactId>
-      <version>3.3</version>
       <scope>test</scope>
     </dependency>
     <dependency>

mllib/pom.xml

Lines changed: 5 additions & 0 deletions

@@ -71,6 +71,11 @@
       </exclusion>
     </exclusions>
   </dependency>
+  <dependency>
+    <groupId>org.apache.commons</groupId>
+    <artifactId>commons-math3</artifactId>
+    <scope>test</scope>
+  </dependency>
   <dependency>
     <groupId>org.scalatest</groupId>
     <artifactId>scalatest_${scala.binary.version}</artifactId>
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 343 additions & 0 deletions

@@ -0,0 +1,343 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.clustering

import java.util.Random

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}

import org.apache.spark.Logging
import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.rdd.RDD

case class Document(docId: Int, content: Iterable[Int], var topics: Iterable[Int] = null,
    var topicDist: BV[Double] = null)

class TopicModel(val topicCounts_ : BDV[Double],
    val topicTermCounts_ : Array[BSV[Double]],
    val alpha: Double,
    val beta: Double)
  extends Serializable {

  def this(topicCounts_ : SDV, topicTermCounts_ : Array[SSV], alpha: Double, beta: Double) =
    this(new BDV[Double](topicCounts_.toArray), topicTermCounts_.map(t =>
      new BSV(t.indices, t.values, t.size)), alpha, beta)

  def topicCounts = Vectors.dense(topicCounts_.toArray)

  def topicTermCounts = topicTermCounts_.map(t => Vectors.sparse(t.size, t.activeIterator.toSeq))

  /** Adds `inc` to the counters for (`topic`, `term`); use inc = -1 to remove an assignment. */
  def update(term: Int, topic: Int, inc: Int) = {
    topicCounts_(topic) += inc
    topicTermCounts_(topic)(term) += inc
    this
  }

  /** Merges another model's counters into this one, in place. */
  def merge(other: TopicModel) = {
    topicCounts_ += other.topicCounts_
    var i = 0
    while (i < topicTermCounts_.length) {
      topicTermCounts_(i) += other.topicTermCounts_(i)
      i += 1
    }
    this
  }

  /** Smoothed probability of `term` under `topic`: (n(topic, term) + beta) / (n(topic) + V * beta). */
  def phi(topic: Int, term: Int): Double = {
    val numTerms = topicTermCounts_.head.size
    (topicTermCounts_(topic)(term) + beta) / (topicCounts_(topic) + numTerms * beta)
  }

  /**
   * Infers a document's topic distribution by Gibbs sampling against this model,
   * averaging the assignments drawn after the burn-in iterations.
   */
  def infer(doc: Document, rand: Random, totalIter: Int = 10, burnInIter: Int = 5): BV[Double] = {
    require(totalIter > burnInIter, "totalIter must be greater than burnInIter")
    require(totalIter > 0, "totalIter must be positive")
    require(burnInIter > 0, "burnInIter must be positive")

    val Document(_, content, _, topicDist) = doc
    val numTopics = topicCounts_.size
    var lastTopicDist = BSV.zeros[Double](numTopics)
    var currentTopicDist = topicDist
    val probDist = BSV.zeros[Double](numTopics)

    for (i <- 1 to totalIter) {
      if (currentTopicDist != null) {
        content.foreach { term =>
          val dist = generateTopicDistForTerm(currentTopicDist, term)
          val lastTopic = CGS.multinomialDistSampler(rand, dist)
          lastTopicDist(lastTopic) += 1
        }
      } else {
        // No previous assignments for this document: start from uniform random topics.
        content.foreach { term =>
          val lastTopic = CGS.uniformDistSampler(rand, numTopics)
          lastTopicDist(lastTopic) += 1
        }
      }

      if (i > burnInIter) {
        probDist :+= lastTopicDist
      }
      currentTopicDist = lastTopicDist
      lastTopicDist = BSV.zeros[Double](numTopics)
    }

    probDist :/= (totalIter - burnInIter).toDouble
    probDist
  }

  /**
   * Samples a new topic for `term` after dropping its current assignment from the counters,
   * the essential step of collapsed Gibbs sampling for LDA. The conditional is
   * p(z = k) ∝ (n(k, term) + beta) / (n(k) + V * beta) * (n(doc, k) + alpha), with the
   * counts of the current topic decremented by one. See Heinrich,
   * "Parameter estimation for text analysis".
   */
  def dropOneDistSampler(docTopicCount: BV[Double], rand: Random, term: Int,
      currentTopic: Int): Int = {
    val topicThisTerm = generateTopicDistForTerm(docTopicCount, term,
      currentTopic, isTrainModel = true)
    CGS.multinomialDistSampler(rand, topicThisTerm)
  }

  def generateTopicDistForTerm(docTopicCount: BV[Double], term: Int,
      currentTopic: Int = -1, isTrainModel: Boolean = false): BDV[Double] = {
    val (numTopics, numTerms) = (topicCounts_.size, topicTermCounts_.head.size)
    val topicThisTerm = BDV.zeros[Double](numTopics)
    var i = 0
    while (i < numTopics) {
      // During training, remove the term's current assignment before computing its weight.
      val adjustment = if (isTrainModel && i == currentTopic) -1 else 0
      topicThisTerm(i) = (topicTermCounts_(i)(term) + adjustment + beta) /
        (topicCounts_(i) + adjustment + (numTerms * beta)) *
        (docTopicCount(i) + adjustment + alpha)
      i += 1
    }
    topicThisTerm
  }
}

object TopicModel {
  def apply(numTopics: Int, numTerms: Int, alpha: Double = 0.1, beta: Double = 0.01) =
    new TopicModel(
      BDV.zeros[Double](numTopics),
      Array.fill(numTopics)(BSV.zeros[Double](numTerms)),
      alpha, beta)
}

class LDA private (
    var numTopics: Int,
    var numTerms: Int,
    totalIter: Int,
    burnInIter: Int,
    var alpha: Double,
    var beta: Double)
  extends Serializable with Logging {

  def run(input: RDD[Document]): (TopicModel, RDD[Document]) = {
    val initModel = TopicModel(numTopics, numTerms, alpha, beta)
    CGS.runGibbsSampling(input, initModel, totalIter, burnInIter)
  }
}

object LDA extends Logging {
  def train(
      data: RDD[Document],
      numTerms: Int,
      numTopics: Int,
      totalIter: Int,
      burnInIter: Int,
      alpha: Double,
      beta: Double): (TopicModel, RDD[Document]) = {
    val lda = new LDA(numTopics, numTerms, totalIter, burnInIter, alpha, beta)
    lda.run(data)
  }

  /**
   * Perplexity is a standard evaluation metric for LDA, usually computed on held-out data;
   * here it is applied to the training documents, which is also valid. To use it on unseen
   * data, run an iteration of Gibbs sampling before calling this. Lower perplexity means a
   * better fit.
   */
  def perplexity(data: RDD[Document], computedModel: TopicModel): Double = {
    val broadcastModel = data.context.broadcast(computedModel)
    val (termProb, totalNum) = data.mapPartitions { docs =>
      val model = broadcastModel.value
      val numTopics = model.topicCounts_.size
      val numTerms = model.topicTermCounts_.head.size
      val rand = new Random
      val alpha = model.alpha
      docs.flatMap { case doc @ Document(_, content, _, _) =>
        val currentTheta = BSV.zeros[Double](numTerms)
        val theta = model.infer(doc, rand)
        content.foreach { term =>
          (0 until numTopics).foreach { topic =>
            currentTheta(term) += model.phi(topic, term) * ((theta(topic) + alpha) /
              (content.size + alpha * numTopics))
          }
        }
        content.map(x => (math.log(currentTheta(x)), 1))
      }
    }.reduce { (lhs, rhs) =>
      (lhs._1 + rhs._1, lhs._2 + rhs._2)
    }
    math.exp(-1 * termProb / totalNum)
  }
}

/**
 * Collapsed Gibbs sampling for Latent Dirichlet Allocation.
 */
object CGS extends Logging {

  /**
   * Runs the full Gibbs sampling procedure in two phases: initialization of the topic
   * assignments, then the sampling iterations proper.
   */
  def runGibbsSampling(data: RDD[Document], initModel: TopicModel,
      totalIter: Int, burnInIter: Int): (TopicModel, RDD[Document]) = {
    require(totalIter > burnInIter, "totalIter must be greater than burnInIter")
    require(totalIter > 0, "totalIter must be positive")
    require(burnInIter > 0, "burnInIter must be positive")

    val (numTopics, numTerms, alpha, beta) = (initModel.topicCounts_.size,
      initModel.topicTermCounts_.head.size, initModel.alpha, initModel.beta)
    val probModel = TopicModel(numTopics, numTerms, alpha, beta)

    // Construct the topic assignment RDD
    logInfo("Start initialization")
    var (params, docTopics) = sampleTermAssignment(data, initModel)

    for (iter <- 1 to totalIter) {
      logInfo("Start Gibbs sampling (Iteration %d/%d)".format(iter, totalIter))
      val broadcastParams = data.context.broadcast(params)
      val previousDocTopics = docTopics
      docTopics = docTopics.mapPartitions { docs =>
        val rand = new Random
        val currentParams = broadcastParams.value
        docs.map { case Document(docId, content, topics, topicDist) =>
          val chosenTopicCounts: BV[Double] = BSV.zeros[Double](numTopics)
          val chosenTopics = content.zip(topics).map { case (term, topic) =>
            val chosenTopic = currentParams.dropOneDistSampler(topicDist, rand, term, topic)
            if (topic != chosenTopic) {
              topicDist(topic) += -1
              currentParams.update(term, topic, -1)
              currentParams.update(term, chosenTopic, 1)
              topicDist(chosenTopic) += 1
            }
            chosenTopicCounts(chosenTopic) += 1
            chosenTopic
          }
          Document(docId, content, chosenTopics, chosenTopicCounts)
        }
      }.setName(s"LDA-$iter").cache()

      if (iter % 20 == 0 && data.context.getCheckpointDir.isDefined) {
        docTopics.checkpoint()
      }

      params = collectTopicCounters(docTopics, numTerms, numTopics)

      // Accumulate post-burn-in counters; the average is taken after the loop.
      if (iter > burnInIter) {
        probModel.merge(params)
      }
      previousDocTopics.unpersist()
    }
    val burnIn = (totalIter - burnInIter).toDouble
    probModel.topicCounts_ :/= burnIn
    probModel.topicTermCounts_.foreach(_ :/= burnIn)
    (probModel, docTopics)
  }

  private def collectTopicCounters(docTopics: RDD[Document], numTerms: Int, numTopics: Int):
      TopicModel = {
    docTopics.mapPartitions { iter =>
      val topicCounters = TopicModel(numTopics, numTerms)
      iter.foreach { doc =>
        doc.content.zip(doc.topics).foreach(t => topicCounters.update(t._1, t._2, 1))
      }
      Iterator(topicCounters)
    }.fold(TopicModel(numTopics, numTerms)) { (thisOne, otherOne) =>
      thisOne.merge(otherOne)
    }
  }

  /**
   * Initialization step of Gibbs sampling. Also supports incremental LDA: when the given
   * model already has counts, assignments are sampled from it instead of uniformly.
   */
  private def sampleTermAssignment(data: RDD[Document], topicModel: TopicModel):
      (TopicModel, RDD[Document]) = {
    val (numTopics, numTerms, alpha, beta) = (topicModel.topicCounts_.size,
      topicModel.topicTermCounts_.head.size, topicModel.alpha, topicModel.beta)
    val broadcastParams = data.context.broadcast(topicModel)

    val initialDocs = if (topicModel.topicCounts_.norm(2) == 0) {
      // Fresh model: assign each term a topic uniformly at random.
      data.mapPartitions { docs =>
        val rand = new Random
        docs.map { case Document(docId, content, _, _) =>
          val lastDocTopicCount = BSV.zeros[Double](numTopics)
          val lastTopics = content.map { term =>
            val topic = uniformDistSampler(rand, numTopics)
            lastDocTopicCount(topic) += 1
            topic
          }
          Document(docId, content, lastTopics, lastDocTopicCount)
        }
      }.cache()
    } else {
      // Previously trained model: sample assignments from its term-topic distributions.
      data.mapPartitions { docs =>
        val rand = new Random
        val currentParams = broadcastParams.value
        docs.map { case Document(docId, content, _, topicDist) =>
          val lastDocTopicCount = BSV.zeros[Double](numTopics)
          val lastTopics = content.map { term =>
            val dist = currentParams.generateTopicDistForTerm(topicDist, term)
            val lastTopic = multinomialDistSampler(rand, dist)
            lastDocTopicCount(lastTopic) += 1
            lastTopic
          }
          Document(docId, content, lastTopics, lastDocTopicCount)
        }
      }.cache()
    }

    val initialModel = collectTopicCounters(initialDocs, numTerms, numTopics)
    (initialModel, initialDocs)
  }

  /**
   * A uniform distribution sampler, used only for initialization.
   */
  def uniformDistSampler(rand: Random, dimension: Int): Int = rand.nextInt(dimension)

  /**
   * A multinomial distribution sampler, using the roulette-wheel (inverse CDF) method to
   * draw an index with probability proportional to its weight in `dist`.
   */
  def multinomialDistSampler(rand: Random, dist: BDV[Double]): Int = {
    val distSum = rand.nextDouble() * breeze.linalg.sum[BDV[Double], Double](dist)

    def loop(index: Int, accum: Double): Int = {
      if (index == dist.length) {
        dist.length - 1
      } else {
        val rest = accum - dist(index)
        if (rest <= 0) index else loop(index + 1, rest)
      }
    }

    loop(0, distSum)
  }
}
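Usage sketch (not part of the commit): a minimal, hypothetical driver showing how the new API fits together. It assumes an existing SparkContext `sc`, a toy three-document corpus whose term IDs index an assumed 6-word vocabulary, and illustrative iteration counts; only Document, LDA.train, and LDA.perplexity come from the code above.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.{Document, LDA}

object LDAExample {
  def run(sc: SparkContext): Unit = {
    // Toy corpus: each Document carries a docId and a bag of term IDs (vocabulary size 6).
    val corpus = sc.parallelize(Seq(
      Document(0, Seq(0, 1, 2, 0, 1)),
      Document(1, Seq(3, 4, 5, 3, 4)),
      Document(2, Seq(0, 2, 2, 1, 0))))

    // Train with the same alpha/beta defaults the commit uses in TopicModel.apply.
    val (model, docTopics) = LDA.train(corpus, numTerms = 6, numTopics = 2,
      totalIter = 50, burnInIter = 10, alpha = 0.1, beta = 0.01)

    // docTopics already carries sampled topic assignments, as LDA.perplexity expects;
    // lower perplexity indicates a better fit.
    println(s"perplexity = ${LDA.perplexity(docTopics, model)}")
  }
}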
