Skip to content

Commit 62a9c7e

Browse files
committed
Fix appendBias return type
1 parent 454c73d commit 62a9c7e

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ private[python] class PythonMLLibAPI extends Serializable {
7777
* @param path file or directory path in any Hadoop-supported file system URI
7878
* @return serialized vectors in a RDD
7979
*/
80-
def loadVectors(jsc: JavaSparkContext,
81-
path: String): RDD[Vector] =
80+
def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] =
8281
MLUtils.loadVectors(jsc.sc, path)
8382

8483
private def trainRegressionModel(

python/pyspark/mllib/tests.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,13 @@ def test_append_bias(self):
824824
data = [2.0, 2.0, 2.0]
825825
ret = MLUtils.appendBias(data)
826826
self.assertEqual(ret[3], 1.0)
827+
self.assertEqual(type(ret), list)
828+
829+
def test_append_bias_with_vector(self):
830+
data = Vectors.dense([2.0, 2.0, 2.0])
831+
ret = MLUtils.appendBias(data)
832+
self.assertEqual(ret[3], 1.0)
833+
self.assertEqual(type(ret), list)
827834

828835
def test_load_vectors(self):
829836
import shutil

python/pyspark/mllib/util.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
xrange = range
2424

2525
from pyspark.mllib.common import callMLlibFunc, inherit_doc
26-
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
26+
from pyspark.mllib.linalg import Vector, Vectors, SparseVector, _convert_to_vector
2727

2828

2929
class MLUtils(object):
@@ -172,10 +172,13 @@ def loadLabeledPoints(sc, path, minPartitions=None):
172172
@staticmethod
173173
def appendBias(data):
174174
"""
175-
Returns a new vector with `1.0` (bias) appended to the input vector.
175+
Returns a new vector with `1.0` (bias) appended to
176+
the end of the input vector.
176177
"""
177178
vec = _convert_to_vector(data)
178-
return np.append(vec, 1.0)
179+
if isinstance(vec, Vector):
180+
vec = vec.toArray()
181+
return np.append(vec, 1.0).tolist()
179182

180183
@staticmethod
181184
def loadVectors(sc, path):

0 commit comments

Comments
 (0)