Skip to content

Commit 77fd1b7

Browse files
committed
Not use LabeledPoint
1 parent 68f0653 commit 77fd1b7

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ private[python] class PythonMLLibAPI extends Serializable {
506506
* Java stub for Python mllib LDA.run()
507507
*/
508508
def trainLDAModel(
509-
data: JavaRDD[LabeledPoint],
509+
data: JavaRDD[java.util.List[Any]],
510510
k: Int,
511511
maxIterations: Int,
512512
docConcentration: Double,
@@ -524,11 +524,14 @@ private[python] class PythonMLLibAPI extends Serializable {
524524

525525
if (seed != null) algo.setSeed(seed)
526526

527-
try {
528-
algo.run(data.rdd.map(x => (x.label.toLong, x.features)))
529-
} finally {
530-
data.rdd.unpersist(blocking = false)
527+
val documents = data.rdd.map(_.asScala.toArray).map { r =>
528+
r(0).getClass.getSimpleName match {
529+
case "Integer" => (r(0).asInstanceOf[java.lang.Integer].toLong, r(1).asInstanceOf[Vector])
530+
case "Long" => (r(0).asInstanceOf[java.lang.Long].toLong, r(1).asInstanceOf[Vector])
531+
case _ => throw new IllegalArgumentException("input values contains invalid type value.")
532+
}
531533
}
534+
algo.run(documents)
532535
}
533536

534537

python/pyspark/mllib/clustering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,8 +593,8 @@ class LDAModel(JavaModelWrapper):
593593
>>> from collections import namedtuple
594594
>>> from numpy.testing import assert_almost_equal
595595
>>> data = [
596-
... LabeledPoint(1, [0.0, 1.0]),
597-
... LabeledPoint(2, [1.0, 0.0]),
596+
... [1, Vectors.dense([0.0, 1.0])],
597+
... [2, SparseVector(2, {0: 1.0})],
598598
... ]
599599
>>> rdd = sc.parallelize(data)
600600
>>> model = LDA.train(rdd, k=2)

0 commit comments

Comments (0)