Skip to content

Commit 5d555b1

Browse files
committed
Construct scipy.sparse matrix
1 parent c345a44 commit 5d555b1

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

python/pyspark/mllib/tests.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import array as pyarray
2626

2727
from numpy import array, array_equal, zeros, inf
28+
import scipy.sparse as sp
2829
from py4j.protocol import Py4JJavaError
2930

3031
if sys.version_info[:2] <= (2, 6):
@@ -832,6 +833,14 @@ def test_append_bias_with_vector(self):
832833
self.assertEqual(ret[3], 1.0)
833834
self.assertEqual(type(ret), list)
834835

836+
def test_append_bias_with_sp_vector(self):
837+
data = Vectors.sparse(3, {0: 2.0, 2: 2.0})
838+
# Returned value must be scipy.sparse matrix
839+
ret = MLUtils.appendBias(data)
840+
self.assertEqual(ret.shape, (1, 4))
841+
self.assertEqual(ret.toarray()[0][3], 1.0)
842+
self.assertEqual(type(ret), sp.csc_matrix)
843+
835844
def test_load_vectors(self):
836845
import shutil
837846
data = [

python/pyspark/mllib/util.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717

1818
import sys
1919
import numpy as np
20+
import scipy.sparse as sp
2021
import warnings
2122

2223
if sys.version > '3':
2324
xrange = range
2425

2526
from pyspark.mllib.common import callMLlibFunc, inherit_doc
26-
from pyspark.mllib.linalg import Vector, Vectors, SparseVector, _convert_to_vector
27+
from pyspark.mllib.linalg import Vector, Vectors, DenseVector, SparseVector, _convert_to_vector
2728

2829

2930
class MLUtils(object):
@@ -176,7 +177,9 @@ def appendBias(data):
176177
the end of the input vector.
177178
"""
178179
vec = _convert_to_vector(data)
179-
if isinstance(vec, Vector):
180+
if isinstance(vec, SparseVector):
181+
return sp.csc_matrix(np.append(vec.toArray(), 1.0))
182+
elif isinstance(vec, Vector):
180183
vec = vec.toArray()
181184
return np.append(vec, 1.0).tolist()
182185

0 commit comments

Comments
 (0)