Skip to content

Commit ce9a475

Browse files
committed
add deserialize_labeled_point to pyspark with tests
1 parent e9fcd49 commit ce9a475

File tree

1 file changed

+51
-21
lines changed

1 file changed

+51
-21
lines changed

python/pyspark/mllib/_common.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pyspark.mllib.linalg import SparseVector
2323
from pyspark.serializers import Serializer
2424

25+
2526
"""
2627
Common utilities shared throughout MLlib, primarily for dealing with
2728
different data types. These include:
@@ -146,7 +147,7 @@ def _serialize_sparse_vector(v):
146147
return ba
147148

148149

149-
def _deserialize_double_vector(ba):
150+
def _deserialize_double_vector(ba, offset=0):
150151
"""Deserialize a double vector from a mutually understood format.
151152
152153
>>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0])
@@ -159,43 +160,46 @@ def _deserialize_double_vector(ba):
159160
if type(ba) != bytearray:
160161
raise TypeError("_deserialize_double_vector called on a %s; "
161162
"wanted bytearray" % type(ba))
162-
if len(ba) < 5:
163+
nb = len(ba) - offset
164+
if nb < 5:
163165
raise TypeError("_deserialize_double_vector called on a %d-byte array, "
164-
"which is too short" % len(ba))
165-
if ba[0] == DENSE_VECTOR_MAGIC:
166-
return _deserialize_dense_vector(ba)
167-
elif ba[0] == SPARSE_VECTOR_MAGIC:
168-
return _deserialize_sparse_vector(ba)
166+
"which is too short" % nb)
167+
if ba[offset] == DENSE_VECTOR_MAGIC:
168+
return _deserialize_dense_vector(ba, offset)
169+
elif ba[offset] == SPARSE_VECTOR_MAGIC:
170+
return _deserialize_sparse_vector(ba, offset)
169171
else:
170172
raise TypeError("_deserialize_double_vector called on bytearray "
171173
"with wrong magic")
172174

173175

174-
def _deserialize_dense_vector(ba):
176+
def _deserialize_dense_vector(ba, offset=0):
175177
"""Deserialize a dense vector into a numpy array."""
176-
if len(ba) < 5:
178+
nb = len(ba) - offset
179+
if nb < 5:
177180
raise TypeError("_deserialize_dense_vector called on a %d-byte array, "
178-
"which is too short" % len(ba))
179-
length = ndarray(shape=[1], buffer=ba, offset=1, dtype=int32)[0]
180-
if len(ba) != 8 * length + 5:
181+
"which is too short" % nb)
182+
length = ndarray(shape=[1], buffer=ba, offset=offset + 1, dtype=int32)[0]
183+
if nb < 8 * length + 5:
181184
raise TypeError("_deserialize_dense_vector called on bytearray "
182185
"with wrong length")
183-
return _deserialize_numpy_array([length], ba, 5)
186+
return _deserialize_numpy_array([length], ba, offset + 5)
184187

185188

186-
def _deserialize_sparse_vector(ba):
189+
def _deserialize_sparse_vector(ba, offset=0):
187190
"""Deserialize a sparse vector into a MLlib SparseVector object."""
188-
if len(ba) < 9:
191+
nb = len(ba) - offset
192+
if nb < 9:
189193
raise TypeError("_deserialize_sparse_vector called on a %d-byte array, "
190-
"which is too short" % len(ba))
191-
header = ndarray(shape=[2], buffer=ba, offset=1, dtype=int32)
194+
"which is too short" % l)
195+
header = ndarray(shape=[2], buffer=ba, offset=offset + 1, dtype=int32)
192196
size = header[0]
193197
nonzeros = header[1]
194-
if len(ba) != 9 + 12 * nonzeros:
198+
if nb < 9 + 12 * nonzeros:
195199
raise TypeError("_deserialize_sparse_vector called on bytearray "
196200
"with wrong length")
197-
indices = _deserialize_numpy_array([nonzeros], ba, 9, dtype=int32)
198-
values = _deserialize_numpy_array([nonzeros], ba, 9 + 4 * nonzeros, dtype=float64)
201+
indices = _deserialize_numpy_array([nonzeros], ba, offset + 9, dtype=int32)
202+
values = _deserialize_numpy_array([nonzeros], ba, offset + 9 + 4 * nonzeros, dtype=float64)
199203
return SparseVector(int(size), indices, values)
200204

201205

@@ -242,7 +246,23 @@ def _deserialize_double_matrix(ba):
242246

243247

244248
def _serialize_labeled_point(p):
245-
"""Serialize a LabeledPoint with a features vector of any type."""
249+
"""
250+
Serialize a LabeledPoint with a features vector of any type.
251+
252+
>>> from pyspark.mllib.regression import LabeledPoint
253+
>>> dp0 = LabeledPoint(0.5, array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0]))
254+
>>> dp1 = _deserialize_labeled_point(_serialize_labeled_point(dp0))
255+
>>> dp1.label == dp0.label
256+
True
257+
>>> array_equal(dp1.features, dp0.features)
258+
True
259+
>>> sp0 = LabeledPoint(0.0, SparseVector(4, [1, 3], [3.0, 5.5]))
260+
>>> sp1 = _deserialize_labeled_point(_serialize_labeled_point(sp0))
261+
>>> sp1.label == sp1.label
262+
True
263+
>>> sp1.features == sp0.features
264+
True
265+
"""
246266
from pyspark.mllib.regression import LabeledPoint
247267
serialized_features = _serialize_double_vector(p.features)
248268
header = bytearray(9)
@@ -251,6 +271,16 @@ def _serialize_labeled_point(p):
251271
header_float[0] = p.label
252272
return header + serialized_features
253273

274+
def _deserialize_labeled_point(ba, offset=0):
275+
"""Deserialize a LabeledPoint from a mutually understood format."""
276+
from pyspark.mllib.regression import LabeledPoint
277+
if type(ba) != bytearray:
278+
raise TypeError("Expecting a bytearray but got %s" % type(ba))
279+
if ba[offset] != LABELED_POINT_MAGIC:
280+
raise TypeError("Expecting magic number %d but got %d" % (LABELED_POINT_MAGIC, ba[0]))
281+
label = ndarray(shape=[1], buffer=ba, offset=offset + 1, dtype=float64)[0]
282+
features = _deserialize_double_vector(ba, offset + 9)
283+
return LabeledPoint(label, features)
254284

255285
def _copyto(array, buffer, offset, shape, dtype):
256286
"""

0 commit comments

Comments
 (0)