Skip to content

Commit 48d5e72

Browse files
committed
Functionality improvement
1 parent 0ad3ac1 commit 48d5e72

File tree

4 files changed

+139
-37
lines changed

4 files changed

+139
-37
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,12 @@ class PythonMLLibAPI extends Serializable {
291291
}
292292

293293
/**
294-
* Java stub for Python mllib Word2Vec fit().
295-
* @param dataBytesJRDD Input
294+
* Java stub for Python mllib Word2Vec fit(). This stub returns a
295+
* handle to the Java object instead of the content of the Java object.
296+
* Extra care needs to be taken in the Python code to ensure it gets freed on
297+
* exit; see the Py4J documentation.
298+
* @param dataBytesJRDD Input JavaRDD
299+
 * @return A handle to the Java Word2VecModel instance on the Python side
296300
*/
297301
def trainWord2Vec(
298302
dataBytesJRDD: JavaRDD[Array[Byte]]
@@ -304,19 +308,60 @@ class PythonMLLibAPI extends Serializable {
304308
}
305309

306310
/**
307-
* Java stub for Python mllib Word2VecModel
311+
* Java stub for Python mllib Word2VecModel transform
312+
* @param model Word2VecModel instance
313+
* @param word a word
314+
* @return serialized vector representation of word
308315
*/
309-
def Word2VecSynonynms(
316+
def Word2VecModelTransform(
317+
model: Word2VecModel,
318+
word: String
319+
): Array[Byte] = {
320+
SerDe.serializeDoubleVector(model.transform(word))
321+
}
322+
323+
/**
324+
* Java stub for Python mllib Word2VecModel findSynonyms
325+
* @param model Word2VecModel instance
326+
* @param word a word
327+
* @param num number of synonyms to find
328+
* @return a java LinkedList containing serialized version of
329+
* synonyms and similarities
330+
*/
331+
def Word2VecModelSynonyms(
310332
model: Word2VecModel,
311333
word: String,
312334
num: Int
313-
) = {
335+
): java.util.List[java.lang.Object] = {
314336
val result = model.findSynonyms(word, num)
315-
val vec = Vectors.dense(result.map(_._2))
316-
val words = result.map(_._1).toArray
337+
val similarity = Vectors.dense(result.map(_._2))
338+
val words = result.map(_._1)
339+
val ret = new java.util.LinkedList[java.lang.Object]()
340+
ret.add(SerDe.serializeSeqString(words))
341+
ret.add(SerDe.serializeDoubleVector(similarity))
342+
ret
343+
}
344+
345+
/**
346+
* Java stub for Python mllib Word2VecModel findSynonyms
347+
* @param model Word2VecModel instance
348+
* @param vecBytes serialization of vector representation of words
349+
* @param num number of synonyms to find
350+
* @return a java LinkedList containing serialized version of
351+
* synonyms and similarities
352+
*/
353+
def Word2VecModelSynonyms(
354+
model: Word2VecModel,
355+
vecBytes: Array[Byte],
356+
num: Int
357+
): java.util.List[java.lang.Object] = {
358+
val vec = SerDe.deserializeDoubleVector(vecBytes)
359+
val result = model.findSynonyms(vec, num)
360+
val similarity = Vectors.dense(result.map(_._2))
361+
val words = result.map(_._1)
317362
val ret = new java.util.LinkedList[java.lang.Object]()
318363
ret.add(SerDe.serializeSeqString(words))
319-
ret.add(SerDe.serializeDoubleVector(vec))
364+
ret.add(SerDe.serializeDoubleVector(similarity))
320365
ret
321366
}
322367

@@ -713,7 +758,7 @@ private[spark] object SerDe extends Serializable {
713758
}
714759

715760
private[python] def deserializeSeqString(bytes:Array[Byte]):Seq[String] = {
716-
require(bytes.length >=0, "Byte array too short")
761+
require(bytes.length >= 8, "Byte array too short")
717762
val seqLengthBytes = ByteBuffer.wrap(bytes, 0, 8)
718763
seqLengthBytes.order(ByteOrder.nativeOrder())
719764
val ib = seqLengthBytes.asIntBuffer()

mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,11 @@ class PythonMLLibAPISuite extends FunSuite {
7979
val empty2D = SerDe.to2dArray(emptyMatrix)
8080
assert(empty2D === Array[Array[Double]]())
8181
}
82+
83+
test("string seq serialization") {
84+
val original = Array[String]("abc", "def", "ghi")
85+
val bytes = SerDe.serializeSeqString(original)
86+
val ss = SerDe.deserializeSeqString(bytes)
87+
assert(ss === original)
88+
}
8289
}

python/pyspark/mllib/Word2Vec.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@
2020
"""
2121

2222
from pyspark.mllib._common import \
23-
_get_unmangled_double_vector_rdd, _get_unmangled_rdd, \
24-
_serialize_double, _deserialize_double_matrix, _deserialize_double_vector, \
23+
_serialize_double_vector, \
24+
_deserialize_double_vector, \
2525
_deserialize_string_seq, \
2626
_get_unmangled_string_seq_rdd
2727

2828
__all__ = ['Word2Vec', 'Word2VecModel']
2929

3030
class Word2VecModel(object):
31-
31+
"""
32+
class for Word2Vec model
33+
"""
3234
def __init__(self, sc, java_model):
3335
"""
3436
:param sc: Spark context
@@ -40,23 +42,38 @@ def __init__(self, sc, java_model):
4042
def __del__(self):
4143
self._sc._gateway.detach(self._java_model)
4244

43-
#def transform(self, word):
45+
def transform(self, word):
46+
pythonAPI = self._sc._jvm.PythonMLLibAPI()
47+
result = pythonAPI.Word2VecModelTransform(self._java_model, word)
48+
return _deserialize_double_vector(result)
4449

45-
#def findSynonyms(self, vector, num):
46-
47-
def findSynonyms(self, word, num):
50+
def findSynonyms(self, x, num):
4851
pythonAPI = self._sc._jvm.PythonMLLibAPI()
49-
result = pythonAPI.Word2VecSynonynms(self._java_model, word, num)
50-
similarity = _deserialize_double_vector(result[1])
52+
if type(x) == str:
53+
result = pythonAPI.Word2VecModelSynonyms(self._java_model, x, num)
54+
else:
55+
xSer = _serialize_double_vector(x)
56+
result = pythonAPI.Word2VecModelSynonyms(self._java_model, xSer, num)
5157
words = _deserialize_string_seq(result[0])
52-
ret = []
53-
for w,s in zip(words, similarity):
54-
ret.append((w,s))
55-
return ret
58+
similarity = _deserialize_double_vector(result[1])
59+
return zip(words, similarity)
5660

5761
class Word2Vec(object):
5862
"""
59-
data:RDD[Array[String]]
63+
Word2Vec creates vector representation of words in a text corpus.
64+
The algorithm first constructs a vocabulary from the corpus
65+
and then learns vector representation of words in the vocabulary.
66+
The vector representation can be used as features in
67+
natural language processing and machine learning algorithms.
68+
69+
We used skip-gram model in our implementation and hierarchical softmax
70+
method to train the model. The variable names in the implementation
71+
 match the original C implementation.
72+
For original C implementation, see https://code.google.com/p/word2vec/
73+
For research papers, see
74+
Efficient Estimation of Word Representations in Vector Space
75+
and
76+
Distributed Representations of Words and Phrases and their Compositionality.
6077
"""
6178
def __init__(self):
6279
self.vectorSize = 100
@@ -81,8 +98,23 @@ def setNumIterations(self, numIterations):
8198
return self
8299

83100
def fit(self, data):
101+
"""
102+
:param data: Input RDD
103+
"""
84104
sc = data.context
85105
dataBytes = _get_unmangled_string_seq_rdd(data)
86106
model = sc._jvm.PythonMLLibAPI().trainWord2Vec(dataBytes._jrdd)
87107
return Word2VecModel(sc, model)
88108

109+
def _test():
110+
import doctest
111+
from pyspark import SparkContext
112+
globs = globals().copy()
113+
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
114+
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
115+
globs['sc'].stop()
116+
if failure_count:
117+
exit(-1)
118+
119+
if __name__ == "__main__":
120+
_test()

python/pyspark/mllib/_common.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _serialize_double_vector(v):
144144
"wanted ndarray or SparseVector" % type(v))
145145

146146
def _serialize_string_seq(ss):
147-
"""Serialize a sequence of string"""
147+
"""Serialize a sequence of string."""
148148
seqLength = len(ss)
149149
totalLength = 0
150150
lengthArray = ndarray(shape=[seqLength], dtype=int32)
@@ -200,6 +200,31 @@ def _serialize_sparse_vector(v):
200200
return ba
201201

202202

203+
def _deserialize_string_seq(ba, offset=0):
204+
"""Deserialize a string sequence from a mutually understood format.
205+
>>> import sys
206+
    >>> _deserialize_string_seq(_serialize_string_seq(['abc'])) == ['abc']
207+
True
208+
"""
209+
if type(ba) != bytearray:
210+
raise TypeError("__deserialize_string_seq called on a %s; "
211+
"wanted bytearray" % type(ba))
212+
nb = len(ba) - offset
213+
if nb < 8:
214+
raise TypeError("__deserialize_string_seq called on a %d-byte array, "
215+
"which is too short" % nb)
216+
headers = ndarray(shape=[2], buffer=ba, offset=offset, dtype=int32)
217+
seqLength = headers[0]
218+
totalLength = headers[1]
219+
lengthArray = ndarray(shape=[seqLength], buffer=ba, offset=offset + 8, dtype=int32)
220+
offset = offset + 8 + 4 * seqLength
221+
ret = []
222+
for i in range(0, seqLength):
223+
curLen = lengthArray[i]
224+
ret.append(str(ba[offset: offset + curLen]))
225+
offset = offset + curLen
226+
return ret
227+
203228
def _deserialize_double(ba, offset=0):
204229
"""Deserialize a double from a mutually understood format.
205230
@@ -226,19 +251,6 @@ def _deserialize_double(ba, offset=0):
226251
return _unpack("d", ba[offset:])[0]
227252

228253

229-
def _deserialize_string_seq(ba, offset=0):
230-
nb = len(ba) - offset
231-
headers = ndarray(shape=[2], buffer=ba, offset=offset, dtype=int32)
232-
seqLength = headers[0]
233-
totalLength = headers[1]
234-
lengthArray = ndarray(shape=[seqLength], buffer=ba, offset=offset + 8, dtype=int32)
235-
offset = offset + 8 + 4 * seqLength
236-
ret = []
237-
for i in range(0, seqLength):
238-
ret.append(str(ba[offset: offset + lengthArray[i]]))
239-
offset = offset + lengthArray[i]
240-
return ret
241-
242254
def _deserialize_double_vector(ba, offset=0):
243255
"""Deserialize a double vector from a mutually understood format.
244256
@@ -400,6 +412,12 @@ def _get_unmangled_rdd(data, serializer, cache=True):
400412
return dataBytes
401413

402414
def _get_unmangled_string_seq_rdd(data, cache=True):
415+
"""
416+
Map a pickled Python RDD of Python string sequence to a Java RDD of
417+
Array[Byte]
418+
:param cache: If True, the serialized RDD is cached. (default = True)
419+
WARNING: Users should unpersist() this later!
420+
"""
403421
return _get_unmangled_rdd(data, _serialize_string_seq, cache)
404422

405423
def _get_unmangled_double_vector_rdd(data, cache=True):

0 commit comments

Comments
 (0)