Skip to content

Commit 2a821a6

Browse files
author
Feynman Liang
committed
Add predict methods to LocalLDAModel
1 parent a200e64 commit 2a821a6

File tree

2 files changed

+84
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,28 @@ class LocalLDAModel private[clustering] (
297297
score
298298
}
299299

/**
 * Predicts the topic mixture distribution ("gamma") for each document.
 *
 * Runs one pass of online variational inference per document against the fitted
 * topic matrix. For documents seen during training this repeats work already done
 * by the optimizer, but it is the only way to infer mixtures for new documents.
 *
 * @param documents RDD of (document ID, word-count vector) pairs; each vector's
 *                  length must equal the training vocabulary size
 * @return RDD of (document ID, topic mixture) pairs; each mixture sums to 1,
 *         except for empty documents, which map to the zero vector
 */
def topicDistribution(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
  // Double transpose because dirichletExpectation normalizes by row and we need to normalize
  // by topic (columns of lambda)
  val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
  // Copy members into locals so the map closure below does not capture (and
  // serialize) `this`. Renamed from `topicConcentrationBrz`: the value is the
  // *document* concentration (alpha), not the topic concentration (eta).
  val docConcentrationBrz = this.docConcentration.toBreeze
  val gammaShape = this.gammaShape
  val k = this.k

  documents.map { case (id, termCounts) =>
    if (termCounts.numNonzeros == 0) {
      // An empty document carries no evidence; variational inference over it is
      // ill-defined, so return an all-zero distribution instead of a garbage one.
      (id, Vectors.zeros(k))
    } else {
      val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
        termCounts,
        expElogbeta,
        docConcentrationBrz,
        gammaShape,
        k)
      // Normalize gamma so the returned mixture is a probability distribution.
      (id, Vectors.dense((gamma / sum(gamma)).toArray))
    }
  }
}
321+
300322
}
301323

302324

mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
242242
val alpha = 0.01
243243
val eta = 0.01
244244
val gammaShape = 100
245+
// obtained from LDA model trained in gensim, see below
245246
val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
246247
1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
247248
0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
@@ -281,6 +282,67 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
281282
assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D)
282283
}
283284

285+
test("LocalLDAModel predict") {
  val k = 2
  val vocabSize = 6
  val alpha = 0.01
  val eta = 0.01
  val gammaShape = 100
  // Topic-word weights obtained from an LDA model trained in gensim; see the
  // verification script in the block comment below.
  val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
    1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
    0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))

  // Six two-word documents: the first three use only terms 0-2 (topic 0),
  // the last three only terms 3-5 (topic 1).
  def toydata: Array[(Long, Vector)] = Array(
    Vectors.sparse(6, Array(0, 1), Array(1, 1)),
    Vectors.sparse(6, Array(1, 2), Array(1, 1)),
    Vectors.sparse(6, Array(0, 2), Array(1, 1)),
    Vectors.sparse(6, Array(3, 4), Array(1, 1)),
    Vectors.sparse(6, Array(3, 5), Array(1, 1)),
    Vectors.sparse(6, Array(4, 5), Array(1, 1))
  ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
  val docs = sc.parallelize(toydata)

  val ldaModel: LocalLDAModel = new LocalLDAModel(
    topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape)

  /* Verify results using gensim:
  import numpy as np
  from gensim import models
  corpus = [
    [(0, 1.0), (1, 1.0)],
    [(1, 1.0), (2, 1.0)],
    [(0, 1.0), (2, 1.0)],
    [(3, 1.0), (4, 1.0)],
    [(3, 1.0), (5, 1.0)],
    [(4, 1.0), (5, 1.0)]]
  np.random.seed(2345)
  lda = models.ldamodel.LdaModel(
    corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100,
    decay=0.51, offset=1024)
  print(list(lda.get_document_topics(corpus)))
  > [[(0, 0.99504950495049516)], [(0, 0.99504950495049516)],
  > [(0, 0.99504950495049516)], [(1, 0.99504950495049516)],
  > [(1, 0.99504950495049516)], [(1, 0.99504950495049516)]]
  */

  // (most likely topic, its probability) per document, from gensim above.
  val expectedPredictions = List(
    (0, 0.99504), (0, 0.99504),
    (0, 0.99504), (1, 0.99504),
    (1, 0.99504), (1, 0.99504))

  // Convert results to expectedPredictions format, which only keeps the
  // highest-probability topic for each document.
  val actualPredictions = ldaModel.topicDistribution(docs).map { case (_, topics) =>
    val topicsBz = topics.toBreeze.toDenseVector
    (argmax(topicsBz), max(topicsBz))
  }.collect()

  // BUG FIX: the original chained `.forall { ... }` and discarded its Boolean
  // result, so the test passed unconditionally. Assert each comparison so a
  // wrong prediction actually fails the test.
  expectedPredictions.zip(actualPredictions).foreach { case (expected, actual) =>
    assert(expected._1 === actual._1)
    assert(expected._2 ~== actual._2 relTol 1E-3D)
  }
}
345+
284346
test("OnlineLDAOptimizer with asymmetric prior") {
285347
def toydata: Array[(Long, Vector)] = Array(
286348
Vectors.sparse(6, Array(0, 1), Array(1, 1)),

0 commit comments

Comments
 (0)