/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.clustering

import java.util.Random

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}

import org.apache.spark.Logging
import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.rdd.RDD

/**
 * A document as a bag of term ids. `topics` (one assignment per token) and
 * `topicDist` (per-topic counts) are mutable because they are filled in and
 * updated during Gibbs sampling.
 */
case class Document(docId: Int, content: Iterable[Int], var topics: Iterable[Int] = null,
  var topicDist: BV[Double] = null)
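// Illustrative only (not part of the original patch): a tiny corpus, assuming
// terms have already been mapped to Int ids by an external vocabulary and
// `sc` is an existing SparkContext:
//
//   val corpus: RDD[Document] = sc.parallelize(Seq(
//     Document(0, Seq(0, 1, 1, 2)),
//     Document(1, Seq(2, 3, 3, 4))))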

class TopicModel(val topicCounts_ : BDV[Double],
    val topicTermCounts_ : Array[BSV[Double]],
    val alpha: Double,
    val beta: Double)
  extends Serializable {

  def this(topicCounts_ : SDV, topicTermCounts_ : Array[SSV], alpha: Double, beta: Double) =
    this(new BDV[Double](topicCounts_.toArray), topicTermCounts_.map(t =>
      new BSV(t.indices, t.values, t.size)), alpha, beta)

  def topicCounts = Vectors.dense(topicCounts_.toArray)

  def topicTermCounts = topicTermCounts_.map(t => Vectors.sparse(t.size, t.activeIterator.toSeq))

  def update(term: Int, topic: Int, inc: Int) = {
    topicCounts_(topic) += inc
    topicTermCounts_(topic)(term) += inc
    this
  }

  def merge(other: TopicModel) = {
    topicCounts_ += other.topicCounts_
    var i = 0
    while (i < topicTermCounts_.length) {
      topicTermCounts_(i) += other.topicTermCounts_(i)
      i += 1
    }
    this
  }

  /** Smoothed probability of `term` under `topic`: p(w = term | z = topic). */
  def phi(topic: Int, term: Int): Double = {
    val numTerms = topicTermCounts_.head.size
    (topicTermCounts_(topic)(term) + beta) / (topicCounts_(topic) + numTerms * beta)
  }

  /**
   * Infers the topic distribution of a single document by Gibbs sampling with
   * the model's counters held fixed; samples after burn-in are averaged.
   */
  def infer(doc: Document, rand: Random, totalIter: Int = 10, burnInIter: Int = 5): BV[Double] = {
    require(totalIter > burnInIter, "totalIter must be greater than burnInIter")
    require(totalIter > 0, "totalIter must be positive")
    require(burnInIter > 0, "burnInIter must be positive")

    val Document(_, content, _, topicDist) = doc
    val numTopics = topicCounts_.size
    var lastTopicDist = BSV.zeros[Double](numTopics)
    var currentTopicDist = topicDist
    val probDist = BSV.zeros[Double](numTopics)

    for (i <- 1 to totalIter) {
      if (currentTopicDist != null) {
        content.foreach { term =>
          val dist = generateTopicDistForTerm(currentTopicDist, term)
          val lastTopic = CGS.multinomialDistSampler(rand, dist)
          lastTopicDist(lastTopic) += 1
        }
      } else {
        content.foreach { _ =>
          val lastTopic = CGS.uniformDistSampler(rand, numTopics)
          lastTopicDist(lastTopic) += 1
        }
      }

      if (i > burnInIter) {
        probDist :+= lastTopicDist
      }
      currentTopicDist = lastTopicDist
      lastTopicDist = BSV.zeros[Double](numTopics)
    }

    probDist :/= (totalIter - burnInIter).toDouble
    probDist
  }
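
  // Usage sketch (editorial): infer the topic mixture of an unseen document
  // with an already trained model, keeping the default totalIter/burnInIter.
  // `trainedModel` is assumed to exist:
  //
  //   val topicMixture = trainedModel.infer(Document(42, Seq(0, 2, 2)), new Random())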

  /**
   * Draws a new topic for `term`, after first removing the current assignment
   * from the counts ("drop one out"). This is the essential step of collapsed
   * Gibbs sampling for LDA; see Heinrich,
   * <i>Parameter estimation for text analysis</i>.
   */
  def dropOneDistSampler(docTopicCount: BV[Double], rand: Random, term: Int,
      currentTopic: Int): Int = {
    val topicThisTerm = generateTopicDistForTerm(docTopicCount, term,
      currentTopic, isTrainModel = true)
    CGS.multinomialDistSampler(rand, topicThisTerm)
  }

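  /**
   * Unnormalized full conditional for assigning a topic to `term`
   * (editorial note). With `isTrainModel = true` the current assignment is
   * excluded ("drop one"), which yields the standard collapsed Gibbs update:
   *
   *   p(z_i = k | z_{-i}, w) ∝ (n_{k,t} + beta) / (n_k + V * beta) * (n_{d,k} + alpha)
   *
   * where all counts n exclude position i, V is the vocabulary size, n_{k,t}
   * is the count of term t in topic k, n_k the total size of topic k, and
   * n_{d,k} the count of topic k in document d.
   */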
  def generateTopicDistForTerm(docTopicCount: BV[Double], term: Int,
      currentTopic: Int = -1, isTrainModel: Boolean = false): BDV[Double] = {
    val (numTopics, numTerms) = (topicCounts_.size, topicTermCounts_.head.size)
    val topicThisTerm = BDV.zeros[Double](numTopics)
    var i = 0
    while (i < numTopics) {
      val adjustment = if (isTrainModel && i == currentTopic) -1 else 0
      topicThisTerm(i) = (topicTermCounts_(i)(term) + adjustment + beta) /
        (topicCounts_(i) + adjustment + (numTerms * beta)) *
        (docTopicCount(i) + adjustment + alpha)
      i += 1
    }
    topicThisTerm
  }
}

object TopicModel {
  def apply(numTopics: Int, numTerms: Int, alpha: Double = 0.1,
      beta: Double = 0.01) = new TopicModel(
    BDV.zeros[Double](numTopics),
    Array.fill(numTopics)(BSV.zeros[Double](numTerms)),
    alpha, beta)
}

class LDA private (
    var numTopics: Int,
    var numTerms: Int,
    totalIter: Int,
    burnInIter: Int,
    var alpha: Double,
    var beta: Double)
  extends Serializable with Logging {

  def run(input: RDD[Document]): (TopicModel, RDD[Document]) = {
    val initModel = TopicModel(numTopics, numTerms, alpha, beta)
    CGS.runGibbsSampling(input, initModel, totalIter, burnInIter)
  }
}

object LDA extends Logging {
  def train(
      data: RDD[Document],
      numTerms: Int,
      numTopics: Int,
      totalIter: Int,
      burnInIter: Int,
      alpha: Double,
      beta: Double): (TopicModel, RDD[Document]) = {
    val lda = new LDA(numTopics, numTerms, totalIter, burnInIter, alpha, beta)
    lda.run(data)
  }

  /**
   * Perplexity is a common evaluation metric for LDA, usually computed on
   * held-out data, though it can also be applied to the training documents.
   * For unseen data, run an iteration of Gibbs sampling before calling this.
   * Lower perplexity indicates a better model.
   */
  def perplexity(data: RDD[Document], computedModel: TopicModel): Double = {
    val broadcastModel = data.context.broadcast(computedModel)
    val (termProb, totalNum) = data.mapPartitions { docs =>
      val model = broadcastModel.value
      val numTopics = model.topicCounts_.size
      val numTerms = model.topicTermCounts_.head.size
      val rand = new Random
      val alpha = model.alpha
      docs.flatMap { case doc @ Document(docId, content, topics, topicDist) =>
        val currentTheta = BSV.zeros[Double](numTerms)
        val theta = model.infer(doc, rand)
        // Accumulate each distinct term once; iterating over `content`
        // directly would add a repeated term's probability multiple times.
        content.toSeq.distinct.foreach { term =>
          (0 until numTopics).foreach { topic =>
            currentTheta(term) += model.phi(topic, term) * ((theta(topic) + alpha) /
              (content.size + alpha * numTopics))
          }
        }
        content.map(x => (math.log(currentTheta(x)), 1))
      }
    }.reduce { (lhs, rhs) =>
      (lhs._1 + rhs._1, lhs._2 + rhs._2)
    }
    math.exp(-1 * termProb / totalNum)
  }

}
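
// End-to-end usage sketch (editorial, not part of the patch), assuming an
// existing SparkContext `sc` and a pre-indexed corpus as above:
//
//   val (model, docs) = LDA.train(corpus, numTerms = 5, numTopics = 2,
//     totalIter = 100, burnInIter = 50, alpha = 0.1, beta = 0.01)
//   val score = LDA.perplexity(docs, model)   // lower is better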

/**
 * Collapsed Gibbs sampling for Latent Dirichlet Allocation.
 */
object CGS extends Logging {

  /**
   * Runs the full Gibbs sampling procedure in two phases: initialization of
   * the topic assignments, then the actual sampling iterations. Within each
   * iteration, counter updates are applied locally per partition and the
   * global counters are rebuilt from all assignments afterwards, so the
   * scheme is approximate across partitions.
   */
  def runGibbsSampling(data: RDD[Document], initModel: TopicModel,
      totalIter: Int, burnInIter: Int): (TopicModel, RDD[Document]) = {
    require(totalIter > burnInIter, "totalIter must be greater than burnInIter")
    require(totalIter > 0, "totalIter must be positive")
    require(burnInIter > 0, "burnInIter must be positive")

    val (numTopics, numTerms, alpha, beta) = (initModel.topicCounts_.size,
      initModel.topicTermCounts_.head.size,
      initModel.alpha, initModel.beta)
    val probModel = TopicModel(numTopics, numTerms, alpha, beta)

    // Construct the initial topic assignment RDD
    logInfo("Start initialization")
    var (params, docTopics) = sampleTermAssignment(data, initModel)

    for (iter <- 1 to totalIter) {
      logInfo("Start Gibbs sampling (Iteration %d/%d)".format(iter, totalIter))
      val broadcastParams = data.context.broadcast(params)
      val previousDocTopics = docTopics
      docTopics = docTopics.mapPartitions { docs =>
        val rand = new Random
        val currentParams = broadcastParams.value
        docs.map { case Document(docId, content, topics, topicDist) =>
          val chosenTopicCounts: BV[Double] = BSV.zeros[Double](numTopics)
          val chosenTopics = content.zip(topics).map { case (term, topic) =>
            val chosenTopic = currentParams.dropOneDistSampler(topicDist, rand, term, topic)
            if (topic != chosenTopic) {
              topicDist(topic) -= 1
              currentParams.update(term, topic, -1)
              currentParams.update(term, chosenTopic, 1)
              topicDist(chosenTopic) += 1
            }
            chosenTopicCounts(chosenTopic) += 1
            chosenTopic
          }
          Document(docId, content, chosenTopics, chosenTopicCounts)
        }
      }.setName(s"LDA-$iter").cache()

      if (iter % 20 == 0 && data.context.getCheckpointDir.isDefined) {
        docTopics.checkpoint()
      }

      params = collectTopicCounters(docTopics, numTerms, numTopics)

      if (iter > burnInIter) {
        probModel.merge(params)
      }
      previousDocTopics.unpersist()
    }

    // Average the counters accumulated over the post-burn-in iterations.
    val numSamples = (totalIter - burnInIter).toDouble
    probModel.topicCounts_ :/= numSamples
    probModel.topicTermCounts_.foreach(_ :/= numSamples)
    (probModel, docTopics)
  }

  private def collectTopicCounters(docTopics: RDD[Document], numTerms: Int, numTopics: Int):
      TopicModel = {
    docTopics.mapPartitions { iter =>
      val topicCounters = TopicModel(numTopics, numTerms)
      iter.foreach { doc =>
        doc.content.zip(doc.topics).foreach(t => topicCounters.update(t._1, t._2, 1))
      }
      Iterator(topicCounters)
    }.fold(TopicModel(numTopics, numTerms)) { (thatOne, otherOne) =>
      thatOne.merge(otherOne)
    }
  }

  /**
   * Initialization step of Gibbs sampling, which also supports incremental
   * LDA: if the given model already contains counts, assignments are sampled
   * from it rather than uniformly at random.
   */
  private def sampleTermAssignment(data: RDD[Document], topicModel: TopicModel):
      (TopicModel, RDD[Document]) = {
    val (numTopics, numTerms, alpha, beta) = (topicModel.topicCounts_.size,
      topicModel.topicTermCounts_.head.size,
      topicModel.alpha, topicModel.beta)
    val broadcastParams = data.context.broadcast(topicModel)

    val initialDocs = if (topicModel.topicCounts_.norm(2) == 0) {
      data.mapPartitions { docs =>
        val rand = new Random
        docs.map { case Document(docId, content, topics, topicDist) =>
          val lastDocTopicCount = BSV.zeros[Double](numTopics)
          val lastTopics = content.map { term =>
            val topic = uniformDistSampler(rand, numTopics)
            lastDocTopicCount(topic) += 1
            topic
          }
          Document(docId, content, lastTopics, lastDocTopicCount)
        }
      }.cache()
    } else {
      data.mapPartitions { docs =>
        val rand = new Random
        val currentParams = broadcastParams.value
        docs.map { case Document(docId, content, topics, topicDist) =>
          val lastDocTopicCount = BSV.zeros[Double](numTopics)
          val lastTopics = content.map { term =>
            val dist = currentParams.generateTopicDistForTerm(topicDist, term)
            val lastTopic = multinomialDistSampler(rand, dist)
            lastDocTopicCount(lastTopic) += 1
            lastTopic
          }
          Document(docId, content, lastTopics, lastDocTopicCount)
        }
      }.cache()
    }

    val initialModel = collectTopicCounters(initialDocs, numTerms, numTopics)
    (initialModel, initialDocs)
  }

  /**
   * A uniform distribution sampler, which is only used for initialization.
   */
  def uniformDistSampler(rand: Random, dimension: Int): Int = rand.nextInt(dimension)

  /**
   * A multinomial distribution sampler that uses the roulette-wheel method to
   * draw an index in proportion to the (unnormalized) weights in `dist`.
   */
  def multinomialDistSampler(rand: Random, dist: BDV[Double]): Int = {
    val distSum = rand.nextDouble() * breeze.linalg.sum(dist)

    @annotation.tailrec
    def loop(index: Int, accum: Double): Int = {
      if (index == dist.length) {
        dist.length - 1
      } else {
        val rest = accum - dist(index)
        if (rest <= 0) index else loop(index + 1, rest)
      }
    }

    loop(0, distSum)
  }
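
  /**
   * Editorial sketch (not part of the original patch): an empirical sanity
   * check for multinomialDistSampler. Draws `draws` samples and returns the
   * observed frequency of each index, which should approach dist / sum(dist)
   * as `draws` grows.
   */
  def sampleFrequencies(rand: Random, dist: BDV[Double], draws: Int = 10000): BDV[Double] = {
    val counts = BDV.zeros[Double](dist.length)
    var i = 0
    while (i < draws) {
      counts(multinomialDistSampler(rand, dist)) += 1
      i += 1
    }
    counts :/= draws.toDouble
    counts
  }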
}