110 changes: 106 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,12 +17,13 @@

package org.apache.spark.ml.clustering

import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
  EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
  LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
@@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] (
@Since("1.6.0") override val uid: String,
@Since("1.6.0") val vocabSize: Int,
@Since("1.6.0") @transient protected val sqlContext: SQLContext)
extends Model[LDAModel] with LDAParams with Logging {
extends Model[LDAModel] with LDAParams with Logging with MLWritable {

  // NOTE to developers:
  // This abstraction should contain all important functionality for basic LDA usage.
@@ -486,6 +487,64 @@ class LocalLDAModel private[ml] (

@Since("1.6.0")
override def isDistributed: Boolean = false

@Since("1.6.0")
override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
}


@Since("1.6.0")
object LocalLDAModel extends MLReadable[LocalLDAModel] {

private[LocalLDAModel]
class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter {

private case class Data(
vocabSize: Int,
topicsMatrix: Matrix,
docConcentration: Vector,
topicConcentration: Double,
gammaShape: Double)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val oldModel = instance.oldLocalModel
val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
oldModel.topicConcentration, oldModel.gammaShape)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class LocalLDAModelReader extends MLReader[LocalLDAModel] {

private val className = classOf[LocalLDAModel].getName

override def load(path: String): LocalLDAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath)
.select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
"gammaShape")
.head()
val vocabSize = data.getAs[Int](0)
val topicsMatrix = data.getAs[Matrix](1)
val docConcentration = data.getAs[Vector](2)
val topicConcentration = data.getAs[Double](3)
val gammaShape = data.getAs[Double](4)
val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
gammaShape)
val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader

@Since("1.6.0")
override def load(path: String): LocalLDAModel = super.load(path)
}
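For orientation, the writer/reader pair above is what backs the public persistence calls on the ml LocalLDAModel: params and uid go under path/metadata via DefaultParamsWriter.saveMetadata, and the topic data goes to a single-row Parquet file under path/data. A minimal round-trip sketch, assuming a fitted model `ldaModel` and a scratch path (both hypothetical, not part of this patch):

  // Save the params plus the one-row Parquet "data" file written by saveImpl above.
  ldaModel.write.overwrite().save("/tmp/spark-lda-local")
  // Reload through the companion object's MLReader defined above.
  val restored = LocalLDAModel.load("/tmp/spark-lda-local")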


@@ -562,6 +621,45 @@ class DistributedLDAModel private[ml] (
   */
  @Since("1.6.0")
  lazy val logPrior: Double = oldDistributedModel.logPrior

  @Since("1.6.0")
  override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
}


@Since("1.6.0")
object DistributedLDAModel extends MLReadable[DistributedLDAModel] {

private[DistributedLDAModel]
class DistributedWriter(instance: DistributedLDAModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val modelPath = new Path(path, "oldModel").toString
instance.oldDistributedModel.save(sc, modelPath)
}
}

private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] {

private val className = classOf[DistributedLDAModel].getName

override def load(path: String): DistributedLDAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val modelPath = new Path(path, "oldModel").toString
val oldModel = OldDistributedLDAModel.load(sc, modelPath)
val model = new DistributedLDAModel(
metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader

@Since("1.6.0")
override def load(path: String): DistributedLDAModel = super.load(path)
}
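Note the contrast with the local case: a distributed model's state lives in the RDD-backed structures produced by the EM optimizer, so the writer above delegates to the existing spark.mllib DistributedLDAModel.save under the "oldModel" subdirectory rather than collecting everything into one Parquet row. A round-trip sketch, with hypothetical names and path (training with the EM optimizer is what yields a DistributedLDAModel in the first place):

  val distModel = new LDA().setK(5).setOptimizer("em").fit(dataset).asInstanceOf[DistributedLDAModel]
  distModel.write.save("/tmp/spark-lda-em")  // writes metadata/ plus oldModel/
  val restored = DistributedLDAModel.load("/tmp/spark-lda-em")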


@@ -593,7 +691,8 @@ class DistributedLDAModel private[ml] (
@Since("1.6.0")
@Experimental
class LDA @Since("1.6.0") (
@Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams {
@Since("1.6.0") override val uid: String)
extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable {

@Since("1.6.0")
def this() = this(Identifiable.randomUID("lda"))
@@ -695,7 +794,7 @@ class LDA @Since("1.6.0") (
}


private[clustering] object LDA {
private[clustering] object LDA extends DefaultParamsReadable[LDA] {

  /** Get dataset for spark.mllib LDA */
  def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
@@ -706,4 +805,7 @@ private[clustering] object LDA {
        (docId, features)
      }
  }

  @Since("1.6.0")
  override def load(path: String): LDA = super.load(path)
Review comment (Member): add Since version
}
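Because LDA itself now mixes in DefaultParamsWritable and its companion object is DefaultParamsReadable, a configured but unfitted estimator round-trips too; a minimal sketch with hypothetical settings and path:

  val lda = new LDA().setK(10).setMaxIter(20)
  lda.save("/tmp/spark-lda-estimator")  // params-only metadata
  val sameLda = LDA.load("/tmp/spark-lda-estimator")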
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -187,11 +187,11 @@ abstract class LDAModel private[clustering] extends Saveable {
 * @param topics Inferred topics (vocabSize x k matrix).
 */
@Since("1.3.0")
class LocalLDAModel private[clustering] (
class LocalLDAModel private[spark] (
    @Since("1.3.0") val topics: Matrix,
    @Since("1.5.0") override val docConcentration: Vector,
    @Since("1.5.0") override val topicConcentration: Double,
    override protected[clustering] val gammaShape: Double = 100)
    override protected[spark] val gammaShape: Double = 100)
  extends LDAModel with Serializable {

  @Since("1.3.0")
44 changes: 42 additions & 2 deletions mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -18,9 +18,10 @@
package org.apache.spark.ml.clustering

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}


@@ -39,10 +40,24 @@ object LDASuite {
    }.map(v => new TestRow(v))
    sql.createDataFrame(rdd)
  }

  /**
   * Mapping from all Params to valid settings which differ from the defaults.
   * This is useful for tests which need to exercise all Params, such as save/load.
   * This excludes input columns to simplify some tests.
   */
  val allParamSettings: Map[String, Any] = Map(
    "k" -> 3,
    "maxIter" -> 2,
    "checkpointInterval" -> 30,
    "learningOffset" -> 1023.0,
    "learningDecay" -> 0.52,
    "subsamplingRate" -> 0.051
  )
}
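As the scaladoc says, these non-default settings let the save/load tests exercise every Param. Broadly, testEstimatorAndModelReadWrite (from the DefaultReadWriteTest trait mixed into the suite below) applies them to the estimator, round-trips the estimator and the fitted model through save/load, and passes the original and reloaded models to the supplied checkModelData.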


class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  val k: Int = 5
  val vocabSize: Int = 30
@@ -218,4 +233,29 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
    val lp = model.logPrior
    assert(lp <= 0.0 && lp != Double.NegativeInfinity)
  }

test("read/write LocalLDAModel") {
def checkModelData(model: LDAModel, model2: LDAModel): Unit = {
assert(model.vocabSize === model2.vocabSize)
assert(Vectors.dense(model.topicsMatrix.toArray) ~==
Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6)
assert(Vectors.dense(model.getDocConcentration) ~==
Vectors.dense(model2.getDocConcentration) absTol 1e-6)
}
val lda = new LDA()
testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
}

test("read/write DistributedLDAModel") {
def checkModelData(model: LDAModel, model2: LDAModel): Unit = {
assert(model.vocabSize === model2.vocabSize)
assert(Vectors.dense(model.topicsMatrix.toArray) ~==
Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6)
assert(Vectors.dense(model.getDocConcentration) ~==
Vectors.dense(model2.getDocConcentration) absTol 1e-6)
}
val lda = new LDA()
testEstimatorAndModelReadWrite(lda, dataset,
LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
}
}