Skip to content

Commit 2980569

Browse files
committed
[SPARK-6263] Python MLlib API missing items: Utils
1 parent 9a5bbe0 commit 2980569

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ private[python] class PythonMLLibAPI extends Serializable {
7171
minPartitions: Int): JavaRDD[LabeledPoint] =
7272
MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions)
7373

74+
/**
 * Python-facing wrapper around `MLUtils.appendBias`.
 *
 * The return type is spelled out (rather than inferred) so the public
 * signature of this API class does not silently follow changes in MLUtils.
 *
 * @param data vector to extend
 * @return the result of `MLUtils.appendBias(data)`
 */
def appendBias(
    data: org.apache.spark.mllib.linalg.Vector): org.apache.spark.mllib.linalg.Vector =
  MLUtils.appendBias(data)
76+
77+
/**
 * Python-facing wrapper around `MLUtils.loadVectors`.
 *
 * The return type is spelled out (rather than inferred) so the public
 * signature of this API class does not silently follow changes in MLUtils.
 *
 * @param jsc  Java Spark context; its underlying `sc` is what gets passed on
 * @param path location handed unchanged to `MLUtils.loadVectors`
 * @return the RDD produced by `MLUtils.loadVectors(jsc.sc, path)`
 */
def loadVectors(
    jsc: JavaSparkContext,
    path: String): org.apache.spark.rdd.RDD[org.apache.spark.mllib.linalg.Vector] =
  MLUtils.loadVectors(jsc.sc, path)
79+
7480
private def trainRegressionModel(
7581
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
7682
data: JavaRDD[LabeledPoint],

python/pyspark/mllib/tests.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from pyspark.mllib.feature import Word2Vec
4747
from pyspark.mllib.feature import IDF
4848
from pyspark.mllib.feature import StandardScaler
49+
from pyspark.mllib.util import MLUtils
4950
from pyspark.serializers import PickleSerializer
5051
from pyspark.sql import SQLContext
5152

@@ -789,6 +790,29 @@ def test_model_transform(self):
789790
self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0]))
790791

791792

793+
class MLUtilsTests(MLlibTestCase):
    """Tests for the ``MLUtils.appendBias`` and ``MLUtils.loadVectors`` helpers."""

    def test_append_bias(self):
        """The element following the three original entries must be 1.0."""
        data = [1.0, 2.0, 3.0]
        ret = MLUtils.appendBias(data)
        self.assertEqual(ret[3], 1.0)

    def test_load_vectors(self):
        """Vectors written with ``saveAsTextFile`` must round-trip through ``loadVectors``."""
        import os
        import shutil
        import tempfile
        data = [
            [1.0, 2.0, 3.0],
            [1.0, 2.0, 3.0],
        ]
        # Write below a fresh temporary directory instead of a fixed path in
        # the current working directory: a leftover "test_load_vectors" from
        # an aborted run would otherwise make saveAsTextFile fail.
        temp_dir = tempfile.mkdtemp()
        load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
        try:
            self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
            ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
            ret = ret_rdd.collect()
            self.assertEqual(len(ret), 2)
            self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
            self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
        finally:
            # ignore_errors keeps a cleanup problem from replacing the real
            # test failure (e.g. when the output directory was never created).
            shutil.rmtree(temp_dir, ignore_errors=True)
814+
815+
792816
if __name__ == "__main__":
793817
if not _have_scipy:
794818
print("NOTE: Skipping SciPy tests as it does not seem to be installed")

python/pyspark/mllib/util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,14 @@ def loadLabeledPoints(sc, path, minPartitions=None):
169169
minPartitions = minPartitions or min(sc.defaultParallelism, 2)
170170
return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
171171

172+
@staticmethod
def appendBias(data):
    """
    Append a bias term to `data` using the JVM-side "appendBias" helper.

    `data` is first normalised with `_convert_to_vector`, then passed to
    `callMLlibFunc`; whatever that call yields is returned unchanged.
    NOTE(review): the unit test for this method expects the appended
    element to be 1.0 -- confirm against the Scala implementation.
    """
    vector = _convert_to_vector(data)
    return callMLlibFunc("appendBias", vector)
175+
176+
@staticmethod
def loadVectors(sc, path):
    """
    Load vectors from `path` through the JVM-side "loadVectors" helper.

    `sc` and `path` are forwarded as-is to `callMLlibFunc`, and its result
    is returned unchanged.
    NOTE(review): the unit test treats the result as an RDD of vectors
    written earlier with `saveAsTextFile` -- confirm against the Scala side.
    """
    loaded = callMLlibFunc("loadVectors", sc, path)
    return loaded
179+
172180

173181
class Saveable(object):
174182
"""

0 commit comments

Comments
 (0)