2222from pyspark .mllib .linalg import SparseVector
2323from pyspark .serializers import Serializer
2424
25+
2526"""
2627Common utilities shared throughout MLlib, primarily for dealing with
2728different data types. These include:
@@ -146,7 +147,7 @@ def _serialize_sparse_vector(v):
146147 return ba
147148
148149
149- def _deserialize_double_vector (ba ):
150+ def _deserialize_double_vector (ba , offset = 0 ):
150151 """Deserialize a double vector from a mutually understood format.
151152
152153 >>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0])
@@ -159,43 +160,46 @@ def _deserialize_double_vector(ba):
159160 if type (ba ) != bytearray :
160161 raise TypeError ("_deserialize_double_vector called on a %s; "
161162 "wanted bytearray" % type (ba ))
162- if len (ba ) < 5 :
163+ nb = len (ba ) - offset
164+ if nb < 5 :
163165 raise TypeError ("_deserialize_double_vector called on a %d-byte array, "
164- "which is too short" % len ( ba ) )
165- if ba [0 ] == DENSE_VECTOR_MAGIC :
166- return _deserialize_dense_vector (ba )
167- elif ba [0 ] == SPARSE_VECTOR_MAGIC :
168- return _deserialize_sparse_vector (ba )
166+ "which is too short" % nb )
167+ if ba [offset ] == DENSE_VECTOR_MAGIC :
168+ return _deserialize_dense_vector (ba , offset )
169+ elif ba [offset ] == SPARSE_VECTOR_MAGIC :
170+ return _deserialize_sparse_vector (ba , offset )
169171 else :
170172 raise TypeError ("_deserialize_double_vector called on bytearray "
171173 "with wrong magic" )
172174
173175
174- def _deserialize_dense_vector (ba ):
176+ def _deserialize_dense_vector (ba , offset = 0 ):
175177 """Deserialize a dense vector into a numpy array."""
176- if len (ba ) < 5 :
178+ nb = len (ba ) - offset
179+ if nb < 5 :
177180 raise TypeError ("_deserialize_dense_vector called on a %d-byte array, "
178- "which is too short" % len ( ba ) )
179- length = ndarray (shape = [1 ], buffer = ba , offset = 1 , dtype = int32 )[0 ]
180- if len ( ba ) != 8 * length + 5 :
181+ "which is too short" % nb )
182+ length = ndarray (shape = [1 ], buffer = ba , offset = offset + 1 , dtype = int32 )[0 ]
183+ if nb < 8 * length + 5 :
181184 raise TypeError ("_deserialize_dense_vector called on bytearray "
182185 "with wrong length" )
183- return _deserialize_numpy_array ([length ], ba , 5 )
186+ return _deserialize_numpy_array ([length ], ba , offset + 5 )
184187
185188
186- def _deserialize_sparse_vector (ba ):
189+ def _deserialize_sparse_vector (ba , offset = 0 ):
187190 """Deserialize a sparse vector into a MLlib SparseVector object."""
188- if len (ba ) < 9 :
191+ nb = len (ba ) - offset
192+ if nb < 9 :
189193 raise TypeError ("_deserialize_sparse_vector called on a %d-byte array, "
190- "which is too short" % len ( ba ) )
191- header = ndarray (shape = [2 ], buffer = ba , offset = 1 , dtype = int32 )
194+ "which is too short" % l )
195+ header = ndarray (shape = [2 ], buffer = ba , offset = offset + 1 , dtype = int32 )
192196 size = header [0 ]
193197 nonzeros = header [1 ]
194- if len ( ba ) != 9 + 12 * nonzeros :
198+ if nb < 9 + 12 * nonzeros :
195199 raise TypeError ("_deserialize_sparse_vector called on bytearray "
196200 "with wrong length" )
197- indices = _deserialize_numpy_array ([nonzeros ], ba , 9 , dtype = int32 )
198- values = _deserialize_numpy_array ([nonzeros ], ba , 9 + 4 * nonzeros , dtype = float64 )
201+ indices = _deserialize_numpy_array ([nonzeros ], ba , offset + 9 , dtype = int32 )
202+ values = _deserialize_numpy_array ([nonzeros ], ba , offset + 9 + 4 * nonzeros , dtype = float64 )
199203 return SparseVector (int (size ), indices , values )
200204
201205
@@ -242,7 +246,23 @@ def _deserialize_double_matrix(ba):
242246
243247
244248def _serialize_labeled_point (p ):
245- """Serialize a LabeledPoint with a features vector of any type."""
249+ """
250+ Serialize a LabeledPoint with a features vector of any type.
251+
252+ >>> from pyspark.mllib.regression import LabeledPoint
253+ >>> dp0 = LabeledPoint(0.5, array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0]))
254+ >>> dp1 = _deserialize_labeled_point(_serialize_labeled_point(dp0))
255+ >>> dp1.label == dp0.label
256+ True
257+ >>> array_equal(dp1.features, dp0.features)
258+ True
259+ >>> sp0 = LabeledPoint(0.0, SparseVector(4, [1, 3], [3.0, 5.5]))
260+ >>> sp1 = _deserialize_labeled_point(_serialize_labeled_point(sp0))
261+ >>> sp1.label == sp1.label
262+ True
263+ >>> sp1.features == sp0.features
264+ True
265+ """
246266 from pyspark .mllib .regression import LabeledPoint
247267 serialized_features = _serialize_double_vector (p .features )
248268 header = bytearray (9 )
@@ -251,6 +271,16 @@ def _serialize_labeled_point(p):
251271 header_float [0 ] = p .label
252272 return header + serialized_features
253273
274+ def _deserialize_labeled_point (ba , offset = 0 ):
275+ """Deserialize a LabeledPoint from a mutually understood format."""
276+ from pyspark .mllib .regression import LabeledPoint
277+ if type (ba ) != bytearray :
278+ raise TypeError ("Expecting a bytearray but got %s" % type (ba ))
279+ if ba [offset ] != LABELED_POINT_MAGIC :
280+ raise TypeError ("Expecting magic number %d but got %d" % (LABELED_POINT_MAGIC , ba [0 ]))
281+ label = ndarray (shape = [1 ], buffer = ba , offset = offset + 1 , dtype = float64 )[0 ]
282+ features = _deserialize_double_vector (ba , offset + 9 )
283+ return LabeledPoint (label , features )
254284
255285def _copyto (array , buffer , offset , shape , dtype ):
256286 """
0 commit comments