Skip to content

Commit 4cab972

Browse files
author
Feynman Liang
committed
Default gammaShape
1 parent 0f3366a commit 4cab972

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import org.json4s.jackson.JsonMethods._
2727
import org.apache.spark.SparkContext
2828
import org.apache.spark.annotation.Experimental
2929
import org.apache.spark.api.java.JavaPairRDD
30-
import org.apache.spark.broadcast.Broadcast
3130
import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
3231
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
3332
import org.apache.spark.mllib.util.{Loader, Saveable}
@@ -190,7 +189,8 @@ class LocalLDAModel private[clustering] (
190189
val topics: Matrix,
191190
override val docConcentration: Vector,
192191
override val topicConcentration: Double,
193-
override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable {
192+
override protected[clustering] val gammaShape: Double = 100
193+
) extends LDAModel with Serializable {
194194

195195
override def k: Int = topics.numCols
196196

@@ -455,8 +455,9 @@ class DistributedLDAModel private[clustering] (
455455
val vocabSize: Int,
456456
override val docConcentration: Vector,
457457
override val topicConcentration: Double,
458-
override protected[clustering] val gammaShape: Double,
459-
private[spark] val iterationTimes: Array[Double]) extends LDAModel {
458+
private[spark] val iterationTimes: Array[Double],
459+
override protected[clustering] val gammaShape: Double = 100
460+
) extends LDAModel {
460461

461462
import LDA._
462463

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
209209
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
210210
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
211211
this.graphCheckpointer.deleteAllCheckpoints()
212-
// This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal
213-
// conversion
212+
// gammaShape is omitted: the constructor default (100) is assumed to match OnlineLDAOptimizer
213+
// so that LDAModel.toLocal conversion yields an equivalent model
214214
new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
215215
Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
216-
100, iterationTimes)
216+
iterationTimes)
217217
}
218218
}
219219

0 commit comments

Comments
 (0)