Skip to content

Commit 154f45d

Browse files
committed
Update docs, name some magic values
1 parent 881fef7 commit 154f45d

File tree

2 files changed

+25
-15
lines changed

2 files changed

+25
-15
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ import org.apache.spark.rdd.RDD
3636
*/
3737
@DeveloperApi
3838
class PythonMLLibAPI extends Serializable {
39+
private val DENSE_VECTOR_MAGIC = 1
40+
private val SPARSE_VECTOR_MAGIC = 2
41+
private val DENSE_MATRIX_MAGIC = 3
42+
3943
private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
4044
val packetLength = bytes.length
4145
if (packetLength < 5) {
@@ -44,7 +48,7 @@ class PythonMLLibAPI extends Serializable {
4448
val bb = ByteBuffer.wrap(bytes)
4549
bb.order(ByteOrder.nativeOrder())
4650
val magic = bb.get()
47-
if (magic != 1) {
51+
if (magic != DENSE_VECTOR_MAGIC) {
4852
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
4953
}
5054
val length = bb.getInt()
@@ -77,7 +81,7 @@ class PythonMLLibAPI extends Serializable {
7781
val bb = ByteBuffer.wrap(bytes)
7882
bb.order(ByteOrder.nativeOrder())
7983
val magic = bb.get()
80-
if (magic != 2) {
84+
if (magic != DENSE_MATRIX_MAGIC) {
8185
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
8286
}
8387
val rows = bb.getInt()

python/pyspark/mllib/_common.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,33 @@
2424

2525
# Dense double vector format:
2626
#
27-
# [8-byte 1] [8-byte length] [length*8 bytes of data]
27+
# [1-byte 1] [4-byte length] [length*8 bytes of data]
2828
#
2929
# Sparse double vector format:
3030
#
31-
# [8-byte 2] [8-byte size] [8-byte entries] [entries*4 bytes of indices] [entries*8 bytes of values]
31+
# [1-byte 2] [4-byte length] [4-byte nonzeros] [nonzeros*4 bytes of indices] [nonzeros*8 bytes of values]
3232
#
3333
# Double matrix format:
3434
#
35-
# [8-byte 3] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data]
35+
# [1-byte 3] [4-byte rows] [4-byte cols] [rows*cols*8 bytes of data]
3636
#
3737
# This is all in machine-endian (native byte order). That means the JVM and the
3838
# Python interpreter must agree on the machine's endianness.
3939

40-
def _deserialize_byte_array(shape, ba, offset):
41-
"""Wrapper around ndarray aliasing hack.
40+
DENSE_VECTOR_MAGIC = 1
41+
SPARSE_VECTOR_MAGIC = 2
42+
DENSE_MATRIX_MAGIC = 3
43+
44+
def _deserialize_numpy_array(shape, ba, offset):
45+
"""
46+
Deserialize a numpy array of float64s from a given offset in
47+
bytearray ba, assigning it the given shape.
4248
4349
>>> x = array([1.0, 2.0, 3.0, 4.0, 5.0])
44-
>>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
50+
>>> array_equal(x, _deserialize_numpy_array(x.shape, x.data, 0))
4551
True
4652
>>> x = array([1.0, 2.0, 3.0, 4.0]).reshape(2,2)
47-
>>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
53+
>>> array_equal(x, _deserialize_numpy_array(x.shape, x.data, 0))
4854
True
4955
"""
5056
ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", order='C')
@@ -71,7 +77,7 @@ def _serialize_double_vector(v):
7177
v = v.astype('float64')
7278
length = v.shape[0]
7379
ba = bytearray(5 + 8 * length)
74-
ba[0] = 1
80+
ba[0] = DENSE_VECTOR_MAGIC
7581
length_bytes = ndarray(shape=[1], buffer=ba, offset=1, dtype="int32")
7682
length_bytes[0] = length
7783
arr_mid = ndarray(shape=[length], buffer=ba, offset=5, dtype="float64")
@@ -91,14 +97,14 @@ def _deserialize_double_vector(ba):
9197
if len(ba) < 5:
9298
raise TypeError("_deserialize_double_vector called on a %d-byte array, "
9399
"which is too short" % len(ba))
94-
if ba[0] != 1:
100+
if ba[0] != DENSE_VECTOR_MAGIC:
95101
raise TypeError("_deserialize_double_vector called on bytearray "
96102
"with wrong magic")
97103
length = ndarray(shape=[1], buffer=ba, offset=1, dtype="int32")[0]
98104
if len(ba) != 8*length + 5:
99105
raise TypeError("_deserialize_double_vector called on bytearray "
100106
"with wrong length")
101-
return _deserialize_byte_array([length], ba, 5)
107+
return _deserialize_numpy_array([length], ba, 5)
102108

103109
def _serialize_double_matrix(m):
104110
"""Serialize a double matrix into a mutually understood format."""
@@ -111,7 +117,7 @@ def _serialize_double_matrix(m):
111117
rows = m.shape[0]
112118
cols = m.shape[1]
113119
ba = bytearray(9 + 8 * rows * cols)
114-
ba[0] = 2
120+
ba[0] = DENSE_MATRIX_MAGIC
115121
lengths = ndarray(shape=[3], buffer=ba, offset=1, dtype="int32")
116122
lengths[0] = rows
117123
lengths[1] = cols
@@ -130,7 +136,7 @@ def _deserialize_double_matrix(ba):
130136
if len(ba) < 9:
131137
raise TypeError("_deserialize_double_matrix called on a %d-byte array, "
132138
"which is too short" % len(ba))
133-
if ba[0] != 2:
139+
if ba[0] != DENSE_MATRIX_MAGIC:
134140
raise TypeError("_deserialize_double_matrix called on bytearray "
135141
"with wrong magic")
136142
lengths = ndarray(shape=[2], buffer=ba, offset=1, dtype="int32")
@@ -139,7 +145,7 @@ def _deserialize_double_matrix(ba):
139145
if (len(ba) != 8 * rows * cols + 9):
140146
raise TypeError("_deserialize_double_matrix called on bytearray "
141147
"with wrong length")
142-
return _deserialize_byte_array([rows, cols], ba, 9)
148+
return _deserialize_numpy_array([rows, cols], ba, 9)
143149

144150
def _linear_predictor_typecheck(x, coeffs):
145151
"""Check that x is a one-dimensional vector of the right shape.

0 commit comments

Comments
 (0)