Skip to content

Commit 78bbb53

Browse files
committed
use pickle for seq string SerDe
1 parent a264b08 commit 78bbb53

File tree

3 files changed

+24
-33
lines changed

3 files changed

+24
-33
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ import org.apache.spark.mllib.util.MLUtils
4646
import org.apache.spark.rdd.RDD
4747
import org.apache.spark.util.Utils
4848

49-
5049
/**
5150
* :: DeveloperApi ::
5251
* The Java stubs necessary for the Python mllib bindings.
@@ -290,13 +289,13 @@ class PythonMLLibAPI extends Serializable {
290289
* handle to the Java object instead of the content of the Java object.
291290
* Extra care needs to be taken in the Python code to ensure it gets freed on
292291
* exit; see the Py4J documentation.
293-
* @param dataBytesJRDD Input JavaRDD
292+
* @param dataJRDD Input JavaRDD
294293
* @return A handle to java Word2VecModel instance at python side
295294
*/
296295
def trainWord2Vec(
297-
dataBytesJRDD: JavaRDD[Array[Byte]]
296+
dataJRDD: JavaRDD[java.util.ArrayList[String]]
298297
): Word2VecModel = {
299-
val data = dataBytesJRDD.rdd.map(SerDe.deserializeSeqString)
298+
val data = dataJRDD.rdd.map(_.toArray(new Array[String](0)).toSeq).cache()
300299
val word2vec = new Word2Vec()
301300
val model = word2vec.fit(data)
302301
model
@@ -311,8 +310,8 @@ class PythonMLLibAPI extends Serializable {
311310
def Word2VecModelTransform(
312311
model: Word2VecModel,
313312
word: String
314-
): Array[Byte] = {
315-
SerDe.serializeDoubleVector(model.transform(word))
313+
): Vector = {
314+
model.transform(word)
316315
}
317316

318317
/**
@@ -332,8 +331,8 @@ class PythonMLLibAPI extends Serializable {
332331
val similarity = Vectors.dense(result.map(_._2))
333332
val words = result.map(_._1)
334333
val ret = new java.util.LinkedList[java.lang.Object]()
335-
ret.add(SerDe.serializeSeqString(words))
336-
ret.add(SerDe.serializeDoubleVector(similarity))
334+
ret.add(words)
335+
ret.add(similarity)
337336
ret
338337
}
339338

@@ -347,16 +346,15 @@ class PythonMLLibAPI extends Serializable {
347346
*/
348347
def Word2VecModelSynonyms(
349348
model: Word2VecModel,
350-
vecBytes: Array[Byte],
349+
vec: Vector,
351350
num: Int
352351
): java.util.List[java.lang.Object] = {
353-
val vec = SerDe.deserializeDoubleVector(vecBytes)
354352
val result = model.findSynonyms(vec, num)
355353
val similarity = Vectors.dense(result.map(_._2))
356354
val words = result.map(_._1)
357355
val ret = new java.util.LinkedList[java.lang.Object]()
358-
ret.add(SerDe.serializeSeqString(words))
359-
ret.add(SerDe.serializeDoubleVector(similarity))
356+
ret.add(words)
357+
ret.add(similarity)
360358
ret
361359
}
362360

mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,4 @@ class PythonMLLibAPISuite extends FunSuite {
9191
assert(bytes.length / 10 < 25) // 25 bytes per rating
9292

9393
}
94-
95-
test("string seq serialization") {
96-
val original = Array[String]("abc", "def", "ghi")
97-
val bytes = SerDe.serializeSeqString(original)
98-
val ss = SerDe.deserializeSeqString(bytes)
99-
assert(ss === original)
100-
}
10194
}

python/pyspark/mllib/Word2Vec.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
Python package for Word2Vec in MLlib.
2020
"""
2121

22-
from pyspark.mllib._common import \
23-
_serialize_double_vector, \
24-
_deserialize_double_vector, \
25-
_deserialize_string_seq, \
26-
_get_unmangled_string_seq_rdd
22+
from functools import wraps
23+
24+
from pyspark import PickleSerializer
25+
26+
from pyspark.mllib.linalg import _convert_to_vector
2727

2828
__all__ = ['Word2Vec', 'Word2VecModel']
2929

@@ -46,18 +46,19 @@ def __del__(self):
4646
def transform(self, word):
4747
pythonAPI = self._sc._jvm.PythonMLLibAPI()
4848
result = pythonAPI.Word2VecModelTransform(self._java_model, word)
49-
return _deserialize_double_vector(result)
49+
return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result)))
5050

5151
def findSynonyms(self, x, num):
52+
SerDe = self._sc._jvm.SerDe
53+
ser = PickleSerializer()
5254
pythonAPI = self._sc._jvm.PythonMLLibAPI()
5355
if type(x) == str:
54-
result = pythonAPI.Word2VecModelSynonyms(self._java_model, x, num)
56+
jlist = pythonAPI.Word2VecModelSynonyms(self._java_model, x, num)
5557
else:
56-
xSer = _serialize_double_vector(x)
57-
result = pythonAPI.Word2VecModelSynonyms(self._java_model, xSer, num)
58-
words = _deserialize_string_seq(result[0])
59-
similarity = _deserialize_double_vector(result[1])
60-
return zip(words, similarity)
58+
bytes = bytearray(ser.dumps(_convert_to_vector(x)))
59+
vec = self._sc._jvm.SerDe.loads(bytes)
60+
jlist = pythonAPI.Word2VecModelSynonyms(self._java_model, vec, num)
61+
return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(jlist)))
6162

6263

6364
class Word2Vec(object):
@@ -104,8 +105,7 @@ def fit(self, data):
104105
:param data: Input RDD
105106
"""
106107
sc = data.context
107-
dataBytes = _get_unmangled_string_seq_rdd(data)
108-
model = sc._jvm.PythonMLLibAPI().trainWord2Vec(dataBytes._jrdd)
108+
model = sc._jvm.PythonMLLibAPI().trainWord2Vec(data._to_java_object_rdd())
109109
return Word2VecModel(sc, model)
110110

111111

0 commit comments

Comments
 (0)