File tree Expand file tree Collapse file tree 2 files changed +10
-7
lines changed
mllib/src/main/scala/org/apache/spark/mllib/api/python Expand file tree Collapse file tree 2 files changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -506,7 +506,7 @@ private[python] class PythonMLLibAPI extends Serializable {
506506 * Java stub for Python mllib LDA.run()
507507 */
508508 def trainLDAModel (
509- data : JavaRDD [LabeledPoint ],
509+ data : JavaRDD [java.util. List [ Any ] ],
510510 k : Int ,
511511 maxIterations : Int ,
512512 docConcentration : Double ,
@@ -524,11 +524,14 @@ private[python] class PythonMLLibAPI extends Serializable {
524524
525525 if (seed != null ) algo.setSeed(seed)
526526
527- try {
528- algo.run(data.rdd.map(x => (x.label.toLong, x.features)))
529- } finally {
530- data.rdd.unpersist(blocking = false )
527+ val documents = data.rdd.map(_.asScala.toArray).map { r =>
528+ r(0 ).getClass.getSimpleName match {
529+ case " Integer" => (r(0 ).asInstanceOf [java.lang.Integer ].toLong, r(1 ).asInstanceOf [Vector ])
530+ case " Long" => (r(0 ).asInstanceOf [java.lang.Long ].toLong, r(1 ).asInstanceOf [Vector ])
531+ case _ => throw new IllegalArgumentException (" input values contains invalid type value." )
532+ }
531533 }
534+ algo.run(documents)
532535 }
533536
534537
Original file line number Diff line number Diff line change @@ -593,8 +593,8 @@ class LDAModel(JavaModelWrapper):
593593 >>> from collections import namedtuple
594594 >>> from numpy.testing import assert_almost_equal
595595 >>> data = [
596- ... LabeledPoint( 1, [0.0, 1.0]),
597- ... LabeledPoint( 2, [1.0, 0.0]) ,
596+ ... [ 1, Vectors.dense( [0.0, 1.0])] ,
597+ ... [ 2, SparseVector(2, {0: 1.0})] ,
598598 ... ]
599599 >>> rdd = sc.parallelize(data)
600600 >>> model = LDA.train(rdd, k=2)
You can’t perform that action at this time.
0 commit comments