
Commit 58abcab

save load for ML localLDA
1 parent ff442bb commit 58abcab

3 files changed: 130 additions, 7 deletions

mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala

Lines changed: 100 additions & 4 deletions
@@ -17,12 +17,13 @@

 package org.apache.spark.ml.clustering

+import org.apache.hadoop.fs.Path
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter}
 import org.apache.spark.ml.param._
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
   EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
   LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
@@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] (
     @Since("1.6.0") override val uid: String,
     @Since("1.6.0") val vocabSize: Int,
     @Since("1.6.0") @transient protected val sqlContext: SQLContext)
-  extends Model[LDAModel] with LDAParams with Logging {
+  extends Model[LDAModel] with LDAParams with Logging with MLWritable {

   // NOTE to developers:
   // This abstraction should contain all important functionality for basic LDA usage.
@@ -486,6 +487,61 @@ class LocalLDAModel private[ml] (

   @Since("1.6.0")
   override def isDistributed: Boolean = false
+
+  @Since("1.6.0")
+  override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
+}
+
+
+@Since("1.6.0")
+object LocalLDAModel extends MLReadable[LocalLDAModel] {
+
+  private[LocalLDAModel]
+  class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter {
+
+    private case class Data(vocabSize: Int,
+        topicsMatrix: Matrix,
+        docConcentration: Vector,
+        topicConcentration: Double)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val oldModel = instance.oldLocalModel
+      val data = Data(instance.vocabSize,
+        oldModel.topicsMatrix,
+        oldModel.docConcentration,
+        oldModel.topicConcentration)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class LocalLDAModelReader extends MLReader[LocalLDAModel] {
+
+    private val className = classOf[LocalLDAModel].getName
+
+    override def load(path: String): LocalLDAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration")
+        .head()
+      val vocabSize = data.getAs[Int](0)
+      val topicsMatrix = data.getAs[Matrix](1)
+      val docConcentration = data.getAs[Vector](2)
+      val topicConcentration = data.getAs[Double](3)
+      val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration)
+      val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): LocalLDAModel = super.load(path)
 }
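
Note: a minimal usage sketch of the persistence support added in this hunk, assuming a DataFrame named dataset with a "features" vector column and the default online optimizer (so fit() returns a LocalLDAModel); the save path is illustrative.

// Hypothetical round trip for the spark.ml LocalLDAModel (Spark 1.6-era API).
import org.apache.spark.ml.clustering.{LDA, LocalLDAModel}

val lda = new LDA().setK(10).setMaxIter(20)
val model = lda.fit(dataset).asInstanceOf[LocalLDAModel]

// write saves the Params metadata plus a data/ Parquet file holding
// vocabSize, topicsMatrix, docConcentration and topicConcentration.
model.write.overwrite().save("/tmp/ml-local-lda")

val restored = LocalLDAModel.load("/tmp/ml-local-lda")
assert(restored.vocabSize == model.vocabSize)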
@@ -562,6 +618,43 @@ class DistributedLDAModel private[ml] (
    */
   @Since("1.6.0")
   lazy val logPrior: Double = oldDistributedModel.logPrior
+
+  @Since("1.6.0")
+  override def write: MLWriter = new DistributedLDAModel.DistributedLDAModelWriter(this)
+}
+
+
+@Since("1.6.0")
+object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
+
+  private[DistributedLDAModel]
+  class DistributedLDAModelWriter(instance: DistributedLDAModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      instance.oldDistributedModel.save(sc, path)
+    }
+  }
+
+  private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] {
+
+    private val className = classOf[DistributedLDAModel].getName
+
+    override def load(path: String): DistributedLDAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val oldModel = OldDistributedLDAModel.load(sc, path)
+      val model = new DistributedLDAModel(
+        metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): DistributedLDAModel = super.load(path)
 }
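
Note: a comparable sketch for the distributed path, assuming the "em" optimizer so that fit() returns a DistributedLDAModel; the dataset and path are illustrative. The writer above stores the ML metadata and then delegates the model data to the old spark.mllib DistributedLDAModel's save at the same path.

// Hypothetical round trip for the spark.ml DistributedLDAModel.
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

val emLDA = new LDA().setK(5).setMaxIter(10).setOptimizer("em")
val distModel = emLDA.fit(dataset).asInstanceOf[DistributedLDAModel]

distModel.write.overwrite().save("/tmp/ml-distributed-lda")
val restored = DistributedLDAModel.load("/tmp/ml-distributed-lda")
assert(restored.isDistributed)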
@@ -593,7 +686,8 @@ class DistributedLDAModel private[ml] (
 @Since("1.6.0")
 @Experimental
 class LDA @Since("1.6.0") (
-    @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams {
+    @Since("1.6.0") override val uid: String) extends Estimator[LDAModel]
+    with LDAParams with DefaultParamsWritable {

   @Since("1.6.0")
   def this() = this(Identifiable.randomUID("lda"))
@@ -695,7 +789,7 @@ class LDA @Since("1.6.0") (
 }


-private[clustering] object LDA {
+private[clustering] object LDA extends DefaultParamsReadable[LDA] {

   /** Get dataset for spark.mllib LDA */
   def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
@@ -706,4 +800,6 @@ private[clustering] object LDA {
       (docId, features)
     }
   }
+
+  override def load(path: String): LDA = super.load(path)
 }
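
Note: with DefaultParamsWritable/DefaultParamsReadable mixed in above, the unfitted LDA estimator itself can also be round-tripped (only its Params are persisted). A hedged sketch; the path is illustrative.

// Hypothetical save/load of the estimator's Params.
import org.apache.spark.ml.clustering.LDA

val estimator = new LDA().setK(3).setMaxIter(10)
estimator.write.overwrite().save("/tmp/lda-estimator")
val sameParams = LDA.load("/tmp/lda-estimator")
assert(sameParams.getK == 3)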

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

Lines changed: 1 addition & 1 deletion
@@ -187,7 +187,7 @@ abstract class LDAModel private[clustering] extends Saveable {
  * @param topics Inferred topics (vocabSize x k matrix).
  */
 @Since("1.3.0")
-class LocalLDAModel private[clustering] (
+class LocalLDAModel private[spark] (
     @Since("1.3.0") val topics: Matrix,
     @Since("1.5.0") override val docConcentration: Vector,
     @Since("1.5.0") override val topicConcentration: Double,

mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala

Lines changed: 29 additions & 2 deletions
@@ -18,9 +18,10 @@
 package org.apache.spark.ml.clustering

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}

@@ -39,10 +40,24 @@ object LDASuite {
     }.map(v => new TestRow(v))
     sql.createDataFrame(rdd)
   }
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "k" -> 3,
+    "maxIter" -> 10,
+    "checkpointInterval" -> 30,
+    "learningOffset" -> 1023.0,
+    "learningDecay" -> 0.52,
+    "subsamplingRate" -> 0.051
+  )
 }


-class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
+class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

   val k: Int = 5
   val vocabSize: Int = 30
@@ -218,4 +233,16 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     val lp = model.logPrior
     assert(lp <= 0.0 && lp != Double.NegativeInfinity)
   }
+
+  test("read/write LocalLDAModel") {
+    def checkModelData(model: LDAModel, model2: LDAModel): Unit = {
+      assert(model.vocabSize === model2.vocabSize)
+      assert(Vectors.dense(model.topicsMatrix.toArray) ~==
+        Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6)
+      assert(Vectors.dense(model.getDocConcentration) ~==
+        Vectors.dense(model2.getDocConcentration) absTol 1e-6)
+    }
+    val lda = new LDA()
+    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+  }
 }
