Skip to content

Commit 68f0653

Browse files
committed
Support some parameters for LDA.train() in Python
1 parent 25ef2ac commit 68f0653

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -506,11 +506,21 @@ private[python] class PythonMLLibAPI extends Serializable {
506506
* Java stub for Python mllib LDA.train()
507507
*/
508508
def trainLDAModel(
509-
data: JavaRDD[LabeledPoint],
510-
k: Int,
511-
seed: java.lang.Long): LDAModel = {
509+
data: JavaRDD[LabeledPoint],
510+
k: Int,
511+
maxIterations: Int,
512+
docConcentration: Double,
513+
topicConcentration: Double,
514+
seed: java.lang.Long,
515+
checkpointInterval: Int,
516+
optimizer: String): LDAModel = {
512517
val algo = new LDA()
513518
.setK(k)
519+
.setMaxIterations(maxIterations)
520+
.setDocConcentration(docConcentration)
521+
.setTopicConcentration(topicConcentration)
522+
.setCheckpointInterval(checkpointInterval)
523+
.setOptimizer(optimizer)
514524

515525
if (seed != null) algo.setSeed(seed)
516526

python/pyspark/mllib/clustering.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ class LDAModel(JavaModelWrapper):
597597
... LabeledPoint(2, [1.0, 0.0]),
598598
... ]
599599
>>> rdd = sc.parallelize(data)
600-
>>> model = LDA.train(rdd, 2)
600+
>>> model = LDA.train(rdd, k=2)
601601
>>> model.vocabSize()
602602
2
603603
>>> topics = model.topicsMatrix()
@@ -625,8 +625,12 @@ def describeTopics(self, maxTermsPerTopic=None):
625625
class LDA():
626626

627627
@classmethod
628-
def train(cls, rdd, k, seed=None):
629-
model = callMLlibFunc("trainLDAModel", rdd, k, seed)
628+
def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
629+
topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"):
630+
"""Train a LDA model."""
631+
model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
632+
docConcentration, topicConcentration, seed,
633+
checkpointInterval, optimizer)
630634
return LDAModel(model)
631635

632636

0 commit comments

Comments
 (0)