@@ -17,12 +17,13 @@
 
 package org.apache.spark.ml.clustering
 
+import org.apache.hadoop.fs.Path
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter}
 import org.apache.spark.ml.param._
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
   EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
   LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
@@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] (
     @Since("1.6.0") override val uid: String,
     @Since("1.6.0") val vocabSize: Int,
     @Since("1.6.0") @transient protected val sqlContext: SQLContext)
-  extends Model[LDAModel] with LDAParams with Logging {
+  extends Model[LDAModel] with LDAParams with Logging with MLWritable {
 
   // NOTE to developers:
   // This abstraction should contain all important functionality for basic LDA usage.
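
Mixing MLWritable into the abstract LDAModel only obligates each concrete subclass to supply a writer; save itself comes with the trait. For context, a simplified sketch of the trait's shape in org.apache.spark.ml.util (simplified here, not part of this patch):

    trait MLWritable {
      // each model implementation supplies its own writer
      def write: MLWriter

      // convenience entry point layered on top of write
      def save(path: String): Unit = write.save(path)
    }
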
@@ -486,6 +487,61 @@ class LocalLDAModel private[ml] (
 
   @Since("1.6.0")
   override def isDistributed: Boolean = false
+
+  @Since("1.6.0")
+  override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
+}
+
+
+@Since("1.6.0")
+object LocalLDAModel extends MLReadable[LocalLDAModel] {
+
+  private[LocalLDAModel]
+  class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter {
+
+    private case class Data(vocabSize: Int,
+        topicsMatrix: Matrix,
+        docConcentration: Vector,
+        topicConcentration: Double)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val oldModel = instance.oldLocalModel
+      val data = Data(instance.vocabSize,
+        oldModel.topicsMatrix,
+        oldModel.docConcentration,
+        oldModel.topicConcentration)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class LocalLDAModelReader extends MLReader[LocalLDAModel] {
+
+    private val className = classOf[LocalLDAModel].getName
+
+    override def load(path: String): LocalLDAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration")
+        .head()
+      val vocabSize = data.getAs[Int](0)
+      val topicsMatrix = data.getAs[Matrix](1)
+      val docConcentration = data.getAs[Vector](2)
+      val topicConcentration = data.getAs[Double](3)
+      val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration)
+      val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): LocalLDAModel = super.load(path)
 }
 
 
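
To exercise the new persistence end to end, a minimal round-trip sketch; dataset stands in for any DataFrame with a vector-typed features column, and the path is illustrative:

    import org.apache.spark.ml.clustering.{LDA, LocalLDAModel}

    // the default (online) optimizer produces a LocalLDAModel
    val model = new LDA().setK(10).setMaxIter(5).fit(dataset)
    model.write.overwrite().save("/tmp/lda-local")

    // load reconstructs the model, then restores its Params from the saved metadata
    val restored = LocalLDAModel.load("/tmp/lda-local")
    assert(restored.vocabSize == model.vocabSize)
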
@@ -562,6 +618,43 @@ class DistributedLDAModel private[ml] (
    */
   @Since("1.6.0")
   lazy val logPrior: Double = oldDistributedModel.logPrior
+
+  @Since("1.6.0")
+  override def write: MLWriter = new DistributedLDAModel.DistributedLDAModelWriter(this)
+}
+
+
+@Since("1.6.0")
+object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
+
+  private[DistributedLDAModel]
+  class DistributedLDAModelWriter(instance: DistributedLDAModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      instance.oldDistributedModel.save(sc, new Path(path, "oldModel").toString)
+    }
+  }
+
+  private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] {
+
+    private val className = classOf[DistributedLDAModel].getName
+
+    override def load(path: String): DistributedLDAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val oldModel = OldDistributedLDAModel.load(sc, new Path(path, "oldModel").toString)
+      val model = new DistributedLDAModel(
+        metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): DistributedLDAModel = super.load(path)
 }
 
 
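
The distributed writer introduces no new on-disk format: it stores the wrapped spark.mllib model under an oldModel subdirectory via that model's existing save/load code, keeping it clear of the ml metadata directory, and only layers the Params metadata on top. A round-trip sketch under the same assumptions as above (dataset and path are illustrative):

    import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

    // the EM optimizer keeps topic counts distributed, yielding a DistributedLDAModel
    val model = new LDA().setK(10).setOptimizer("em").fit(dataset)
    model.write.overwrite().save("/tmp/lda-em")

    val restored = DistributedLDAModel.load("/tmp/lda-em")
    // logPrior is a lazy val, so it is recomputed from the reloaded distributed model
    val prior = restored.logPrior
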
@@ -593,7 +686,8 @@ class DistributedLDAModel private[ml] (
 @Since("1.6.0")
 @Experimental
 class LDA @Since("1.6.0") (
-    @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams {
+    @Since("1.6.0") override val uid: String) extends Estimator[LDAModel]
+    with LDAParams with DefaultParamsWritable {
 
   @Since("1.6.0")
   def this() = this(Identifiable.randomUID("lda"))
@@ -695,7 +789,7 @@ class LDA @Since("1.6.0") (
695789}
696790
697791
698- private [clustering] object LDA {
792+ private [clustering] object LDA extends DefaultParamsReadable [ LDA ] {
699793
700794 /** Get dataset for spark.mllib LDA */
701795 def getOldDataset (dataset : DataFrame , featuresCol : String ): RDD [(Long , Vector )] = {
@@ -706,4 +800,6 @@ private[clustering] object LDA {
       (docId, features)
     }
   }
+
+  override def load(path: String): LDA = super.load(path)
 }
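
Because the companion object stays private[clustering] in this patch, the load that DefaultParamsReadable adds to the estimator is only reachable from inside the package (e.g., from tests); saving a configured but unfitted LDA still works from user code through MLWritable. A sketch, with an illustrative path:

    import org.apache.spark.ml.clustering.LDA

    val lda = new LDA().setK(25).setSeed(42L)
    lda.save("/tmp/lda-estimator")  // DefaultParamsWritable: Params-only metadata, no data files

    // package-internal, since object LDA is private[clustering]:
    val sameLda = LDA.load("/tmp/lda-estimator")
    assert(sameLda.getK == 25)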