Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ class LocalLDAModel private[clustering] (
override protected def formatVersion = "1.0"

override def save(sc: SparkContext, path: String): Unit = {
LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
gammaShape)
}
// TODO
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
Expand Down Expand Up @@ -312,16 +313,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
// as a Row in data.
case class Data(topic: Vector, index: Int)

// TODO: explicitly save docConcentration, topicConcentration, and gammaShape for use in
// model.predict()
def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
def save(
sc: SparkContext,
path: String,
topicsMatrix: Matrix,
docConcentration: Vector,
topicConcentration: Double,
gammaShape: Double): Unit = {
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

val k = topicsMatrix.numCols
val metadata = compact(render
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~
("docConcentration" -> docConcentration.toArray.toSeq) ~
("topicConcentration" -> topicConcentration) ~
("gammaShape" -> gammaShape)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
Expand All @@ -331,7 +339,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
}

def load(sc: SparkContext, path: String): LocalLDAModel = {
def load(
sc: SparkContext,
path: String,
docConcentration: Vector,
topicConcentration: Double,
gammaShape: Double): LocalLDAModel = {
val dataPath = Loader.dataPath(path)
val sqlContext = SQLContext.getOrCreate(sc)
val dataFrame = sqlContext.read.parquet(dataPath)
Expand All @@ -348,8 +361,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
val topicsMat = Matrices.fromBreeze(brzTopics)

// TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940
new LocalLDAModel(topicsMat,
Vectors.dense(Array.fill(topicsMat.numRows)(1.0 / topicsMat.numRows)), 1D, 100D)
new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape)
}
}

Expand All @@ -358,11 +370,15 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
implicit val formats = DefaultFormats
val expectedK = (metadata \ "k").extract[Int]
val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
val docConcentration =
Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
val topicConcentration = (metadata \ "topicConcentration").extract[Double]
val gammaShape = (metadata \ "gammaShape").extract[Double]
val classNameV1_0 = SaveLoadV1_0.thisClassName

val model = (loadedClassName, loadedVersion) match {
case (className, "1.0") if className == classNameV1_0 =>
SaveLoadV1_0.load(sc, path)
SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape)
case _ => throw new Exception(
s"LocalLDAModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $loadedVersion). Supported:\n" +
Expand Down Expand Up @@ -565,7 +581,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {

val thisFormatVersion = "1.0"

val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel"

// Store globalTopicTotals as a Vector.
case class Data(globalTopicTotals: Vector)
Expand All @@ -591,7 +607,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
import sqlContext.implicits._

val metadata = compact(render
(("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("k" -> k) ~ ("vocabSize" -> vocabSize) ~
("docConcentration" -> docConcentration.toArray.toSeq) ~
("topicConcentration" -> topicConcentration) ~
Expand Down Expand Up @@ -660,7 +676,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
val topicConcentration = (metadata \ "topicConcentration").extract[Double]
val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
val gammaShape = (metadata \ "gammaShape").extract[Double]
val classNameV1_0 = SaveLoadV1_0.classNameV1_0
val classNameV1_0 = SaveLoadV1_0.thisClassName

val model = (loadedClassName, loadedVersion) match {
case (className, "1.0") if className == classNameV1_0 => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
test("model save/load") {
// Test for LocalLDAModel.
val localModel = new LocalLDAModel(tinyTopics,
Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D)
Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D)
val tempDir1 = Utils.createTempDir()
val path1 = tempDir1.toURI.toString

Expand All @@ -360,6 +360,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
assert(samelocalModel.k === localModel.k)
assert(samelocalModel.vocabSize === localModel.vocabSize)
assert(samelocalModel.docConcentration === localModel.docConcentration)
assert(samelocalModel.topicConcentration === localModel.topicConcentration)
assert(samelocalModel.gammaShape === localModel.gammaShape)

val sameDistributedModel = DistributedLDAModel.load(sc, path2)
assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
Expand All @@ -368,6 +371,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
assert(distributedModel.gammaShape === sameDistributedModel.gammaShape)
assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)

val graph = distributedModel.graph
Expand Down