@@ -667,7 +667,7 @@ def predictOnValues(self, dstream):
         return dstream.mapValues(lambda x: self._model.predict(x))
 
 
-class LDAModel(JavaModelWrapper):
+class LDAModel(JavaModelWrapper, JavaSaveable, Loader):
 
     """ A clustering model derived from the LDA method.
 
@@ -690,6 +690,21 @@ class LDAModel(JavaModelWrapper):
     >>> model = LDA.train(rdd, k=2)
     >>> model.vocabSize()
     2
+    >>> topics = model.describeTopics()
+    >>> len(topics)
+    2
+    >>> len(list(topics[0])[0])
+    2
+    >>> len(list(topics[0])[1])
+    2
+    >>> topics = model.describeTopics(1)
+    >>> len(topics)
+    2
+    >>> len(list(topics[0])[0])
+    1
+    >>> len(list(topics[0])[1])
+    1
+
     >>> topics = model.topicsMatrix()
     >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
     >>> assert_almost_equal(topics, topics_expect, 1)
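These doctests pin down the shape of `describeTopics()`: a list with one entry per topic, each entry a pair of equal-length sequences (term indices, then term weights). A minimal consumption sketch, assuming a trained `model` as in the doctests; `vocab` is a hypothetical index-to-word lookup, not part of this patch:

```python
# Sketch only: pretty-print each topic's single top term.
# `model` is a trained LDAModel; `vocab` is a hypothetical lookup
# list for a 2-word vocabulary, not provided by this patch.
vocab = ["apple", "banana"]
for i, (termIndices, termWeights) in enumerate(model.describeTopics(1)):
    terms = ", ".join("%s=%.3f" % (vocab[idx], weight)
                      for idx, weight in zip(termIndices, termWeights))
    print("topic %d: %s" % (i, terms))
```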
@@ -720,18 +735,27 @@ def vocabSize(self):
         """Vocabulary size (number of terms or terms in the vocabulary)"""
         return self.call("vocabSize")
 
-    @since('1.5.0')
-    def save(self, sc, path):
-        """Save the LDAModel on to disk.
+    def describeTopics(self, maxTermsPerTopic=None):
+        """Return the topics described by weighted terms.
 
-        :param sc: SparkContext
-        :param path: str, path to where the model needs to be stored.
+        WARNING: If vocabSize and k are large, this can return a large object!
         """
-        if not isinstance(sc, SparkContext):
-            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
-        if not isinstance(path, basestring):
-            raise TypeError("path should be a basestring, got type %s" % type(path))
-        self._java_model.save(sc._jsc.sc(), path)
+        if maxTermsPerTopic is None:
+            topics = self.call("describeTopics")
+        else:
+            topics = self.call("describeTopics", maxTermsPerTopic)
+
+        # Convert the result to the Scala format: each row packs a topic's
+        # term indices in its first half and the topic weights in its second.
+        converted = []
+        for elms in [list(elms) for elms in topics]:
+            half_len = len(elms) // 2
+            termIndices = elms[:half_len]
+            topicWeights = elms[half_len:]
+            if len(termIndices) != len(topicWeights):
+                raise TypeError("Unexpected return value from describeTopics: %s" % elms)
+            converted.append((termIndices, topicWeights))
+        return converted
 
     @classmethod
     @since('1.5.0')
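The loop above undoes the flattening done on the JVM side: each row comes back with a topic's term indices packed in front of its weights, and splitting the row in half recovers the `(termIndices, topicWeights)` pair. A toy illustration of that split, independent of Spark:

```python
# Toy illustration of the half-split performed in describeTopics().
elms = [4, 7, 0.6, 0.4]         # two term indices followed by two weights
half_len = len(elms) // 2
termIndices = elms[:half_len]   # [4, 7]
topicWeights = elms[half_len:]  # [0.6, 0.4]
assert (termIndices, topicWeights) == ([4, 7], [0.6, 0.4])
```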
@@ -745,9 +769,8 @@ def load(cls, sc, path):
             raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
         if not isinstance(path, basestring):
             raise TypeError("path should be a basestring, got type %s" % type(path))
-        java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load(
-            sc._jsc.sc(), path)
-        return cls(java_model)
+        wrapper_model = callMLlibFunc("loadLDAModel", sc, path)
+        return LDAModel(wrapper_model)
 
 
 class LDA(object):
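With `JavaSaveable` mixed in (first hunk), `save()` is now inherited rather than hand-written, and `load()` goes through the Python-side wrapper so the loaded object keeps wrapper methods such as `describeTopics()`. A hedged round-trip sketch, assuming a trained `model` and an existing `SparkContext` named `sc`; the path is illustrative:

```python
# Round-trip sketch: save() comes from the JavaSaveable mixin,
# load() from the classmethod patched above. The path is illustrative.
path = "/tmp/lda_model"
model.save(sc, path)
sameModel = LDAModel.load(sc, path)
assert sameModel.vocabSize() == model.vocabSize()
```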
@@ -773,10 +796,10 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
         :param optimizer: LDAOptimizer used to perform the actual calculation.
             Currently "em", "online" are supported. Default to "em".
         """
-        model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
-                              docConcentration, topicConcentration, seed,
-                              checkpointInterval, optimizer)
-        return LDAModel(model)
+        wrapper_model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
+                                      docConcentration, topicConcentration, seed,
+                                      checkpointInterval, optimizer)
+        return LDAModel(wrapper_model)
 
 
 def _test():
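Putting the hunks together, an end-to-end sketch mirroring the class doctests: train a two-topic model on two tiny documents, then inspect the topics. Assumes an existing `SparkContext` named `sc`; the data values follow the doctests above.

```python
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.clustering import LDA

# Two (document id, term-count vector) pairs over a 2-word vocabulary.
data = [
    [1, Vectors.dense([0.0, 1.0])],
    [2, Vectors.dense([1.0, 0.0])],
]
rdd = sc.parallelize(data)
model = LDA.train(rdd, k=2)
print(model.vocabSize())  # 2
for termIndices, termWeights in model.describeTopics(maxTermsPerTopic=1):
    print("%s -> %s" % (termIndices, termWeights))
```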