Commit c867fdf

add Word2Vec to pyspark
1 parent 7db5339 commit c867fdf

3 files changed: +204 -0 lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

78 additions & 0 deletions
@@ -36,6 +36,8 @@ import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model.DecisionTreeModel
 import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
 import org.apache.spark.mllib.stat.correlation.CorrelationNames
+import org.apache.spark.mllib.feature.Word2Vec
+import org.apache.spark.mllib.feature.Word2VecModel
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
@@ -288,6 +290,37 @@ class PythonMLLibAPI extends Serializable {
     ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
   }

+  /**
+   * Java stub for Python mllib Word2Vec fit().
+   * @param dataBytesJRDD input RDD of serialized word sequences
+   */
+  def trainWord2Vec(dataBytesJRDD: JavaRDD[Array[Byte]]): Word2VecModel = {
+    val data = dataBytesJRDD.rdd.map(SerDe.deserializeSeqString)
+    val word2vec = new Word2Vec()
+    word2vec.fit(data)
+  }
+
+  /**
+   * Java stub for Python mllib Word2VecModel.findSynonyms().
+   * Returns a two-element list: the serialized synonym words, then the
+   * serialized vector of their similarity scores.
+   */
+  def Word2VecSynonyms(
+      model: Word2VecModel,
+      word: String,
+      num: Int): java.util.List[java.lang.Object] = {
+    val result = model.findSynonyms(word, num)
+    val words = result.map(_._1)
+    val similarity = Vectors.dense(result.map(_._2))
+    val ret = new java.util.LinkedList[java.lang.Object]()
+    ret.add(SerDe.serializeSeqString(words))
+    ret.add(SerDe.serializeDoubleVector(similarity))
+    ret
+  }
+
   /**
    * Java stub for Python mllib DecisionTree.train().
    * This stub returns a handle to the Java object instead of the content of the Java object.
@@ -659,6 +692,51 @@ private[spark] object SerDe extends Serializable {
     bytes
   }

+  /**
+   * Serialize a sequence of strings.
+   * Layout: [seqLength: Int][totalLength: Int][length of each string: Int * seqLength][string bytes].
+   * Note: uses String.length as the byte count, so single-byte characters are assumed.
+   */
+  private[python] def serializeSeqString(ss: Seq[String]): Array[Byte] = {
+    val seqLength = ss.length
+    var totalLength = 0
+    for (s <- ss) {
+      totalLength += s.length
+    }
+    val bytes = new Array[Byte](8 + 4 * seqLength + totalLength)
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    bb.putInt(seqLength)
+    bb.putInt(totalLength)
+    for (i <- 0 until seqLength) {
+      bb.putInt(ss(i).length)
+    }
+    for (s <- ss) {
+      bb.put(s.getBytes())
+    }
+    bytes
+  }
+
+  private[python] def deserializeSeqString(bytes: Array[Byte]): Seq[String] = {
+    require(bytes.length >= 8, "Byte array too short")
+    val headerBytes = ByteBuffer.wrap(bytes, 0, 8)
+    headerBytes.order(ByteOrder.nativeOrder())
+    val ib = headerBytes.asIntBuffer()
+    val seqLength = ib.get()
+    val totalLength = ib.get()
+    val lengthBytes = ByteBuffer.wrap(bytes, 8, 4 * seqLength)
+    lengthBytes.order(ByteOrder.nativeOrder())
+    val stringBytes = ByteBuffer.wrap(bytes, 8 + 4 * seqLength, totalLength)
+    stringBytes.order(ByteOrder.nativeOrder())
+    val ss = new Array[String](seqLength)
+    val lengthBuffer = lengthBytes.asIntBuffer()
+    var index = 0
+    while (lengthBuffer.hasRemaining()) {
+      val curLen = lengthBuffer.get()
+      val content = new Array[Byte](curLen)
+      stringBytes.get(content, 0, curLen)
+      ss(index) = new String(content)
+      index += 1
+    }
+    ss.toSeq
+  }
+
   private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = {
     val fb = serializeDoubleVector(p.features)
     val bytes = new Array[Byte](1 + 8 + fb.length)
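
To make the wire format concrete: serializeSeqString writes two int32 header fields (element count and total payload length), then one int32 length per string, then the raw string bytes, all in native byte order. A minimal Python sketch of the same layout (an illustration only; encode_seq_string is hypothetical and assumes ASCII strings, as the Scala code does):

import struct

def encode_seq_string(strings):
    # Mirrors SerDe.serializeSeqString: header ints, per-string lengths, payload.
    payload = b"".join(s.encode("ascii") for s in strings)
    header = struct.pack("=ii", len(strings), len(payload))
    lengths = struct.pack("=%di" % len(strings), *[len(s) for s in strings])
    return header + lengths + payload

# ["hi", "spark"] -> ints 2, 7, then 2, 5, then b"hispark": 23 bytes total.
print(repr(encode_seq_string(["hi", "spark"])))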

python/pyspark/mllib/Word2Vec.py

88 additions & 0 deletions
@@ -0,0 +1,88 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python package for Word2Vec in MLlib.
+"""
+
+from pyspark.mllib._common import \
+    _deserialize_double_vector, _deserialize_string_seq, \
+    _get_unmangled_string_seq_rdd
+
+__all__ = ['Word2Vec', 'Word2VecModel']
+
+
+class Word2VecModel(object):
+
+    def __init__(self, sc, java_model):
+        """
+        :param sc: Spark context
+        :param java_model: Handle to Java model object
+        """
+        self._sc = sc
+        self._java_model = java_model
+
+    def __del__(self):
+        self._sc._gateway.detach(self._java_model)
+
+    # TODO: transform(self, word) and findSynonyms(self, vector, num)
+
+    def findSynonyms(self, word, num):
+        """Find num words closest in similarity to the given word."""
+        pythonAPI = self._sc._jvm.PythonMLLibAPI()
+        result = pythonAPI.Word2VecSynonyms(self._java_model, word, num)
+        words = _deserialize_string_seq(result[0])
+        similarity = _deserialize_double_vector(result[1])
+        return zip(words, similarity)
+
+
+class Word2Vec(object):
+    """
+    Word2Vec creates vector representations of words.
+    fit() expects an RDD whose elements are lists of words.
+    """
+    def __init__(self):
+        self.vectorSize = 100
+        self.startingAlpha = 0.025
+        self.numPartitions = 1
+        self.numIterations = 1
+
+    def setVectorSize(self, vectorSize):
+        self.vectorSize = vectorSize
+        return self
+
+    def setLearningRate(self, learningRate):
+        self.startingAlpha = learningRate
+        return self
+
+    def setNumPartitions(self, numPartitions):
+        self.numPartitions = numPartitions
+        return self
+
+    def setNumIterations(self, numIterations):
+        self.numIterations = numIterations
+        return self
+
+    def fit(self, data):
+        sc = data.context
+        dataBytes = _get_unmangled_string_seq_rdd(data)
+        model = sc._jvm.PythonMLLibAPI().trainWord2Vec(dataBytes._jrdd)
+        return Word2VecModel(sc, model)
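
A minimal usage sketch for the new API (assuming a live SparkContext sc, e.g. from the pyspark shell; the corpus is a toy example and the returned synonyms depend entirely on the training data). Note also that in this commit the setter values live on the Python side only; trainWord2Vec does not yet forward them, so training uses the Scala defaults.

from pyspark.mllib.Word2Vec import Word2Vec

# A toy corpus: an RDD whose elements are lists of words, repeated so
# every word occurs often enough to be kept in the vocabulary.
sentences = sc.parallelize(
    [["apache", "spark", "mllib"], ["spark", "word2vec", "embedding"]] * 100)

model = Word2Vec().setVectorSize(10).setNumIterations(5).fit(sentences)
# Up to two (word, similarity) pairs closest to the query word.
for word, similarity in model.findSynonyms("spark", 2):
    print word, similarity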

python/pyspark/mllib/_common.py

38 additions & 0 deletions
@@ -143,6 +143,29 @@ def _serialize_double_vector(v):
         raise TypeError("_serialize_double_vector called on a %s; "
                         "wanted ndarray or SparseVector" % type(v))

+
+def _serialize_string_seq(ss):
+    """Serialize a sequence of strings.
+
+    Layout: [seqLength: int32][totalLength: int32]
+            [per-string lengths: int32 * seqLength][string bytes].
+    """
+    seqLength = len(ss)
+    totalLength = 0
+    lengthArray = ndarray(shape=[seqLength], dtype=int32)
+    for i, s in enumerate(ss):
+        lengthArray[i] = len(s)
+        totalLength += len(s)
+    ba = bytearray(4 + 4 + 4 * seqLength + totalLength)
+    header_bytes = ndarray(shape=[2], buffer=ba, offset=0, dtype=int32)
+    header_bytes[0] = seqLength
+    header_bytes[1] = totalLength
+    _copyto(lengthArray, buffer=ba, offset=8, shape=[seqLength], dtype=int32)
+    offset = 4 + 4 + 4 * seqLength
+    for i, s in enumerate(ss):
+        ba[offset:offset + lengthArray[i]] = bytes(s)
+        offset += lengthArray[i]
+    return ba

 def _serialize_dense_vector(v):
     """Serialize a dense vector given as a NumPy array."""
@@ -203,6 +226,19 @@ def _deserialize_double(ba, offset=0):
     return _unpack("d", ba[offset:])[0]


+def _deserialize_string_seq(ba, offset=0):
+    """Deserialize a sequence of strings written by _serialize_string_seq."""
+    headers = ndarray(shape=[2], buffer=ba, offset=offset, dtype=int32)
+    seqLength = headers[0]
+    lengthArray = ndarray(shape=[seqLength], buffer=ba, offset=offset + 8, dtype=int32)
+    offset = offset + 8 + 4 * seqLength
+    ret = []
+    for i in range(seqLength):
+        ret.append(str(ba[offset:offset + lengthArray[i]]))
+        offset += lengthArray[i]
+    return ret
+
 def _deserialize_double_vector(ba, offset=0):
     """Deserialize a double vector from a mutually understood format.
@@ -363,6 +399,8 @@ def _get_unmangled_rdd(data, serializer, cache=True):
363399
dataBytes.cache()
364400
return dataBytes
365401

402+
def _get_unmangled_string_seq_rdd(data, cache=True):
403+
return _get_unmangled_rdd(data, _serialize_string_seq, cache)
366404

367405
def _get_unmangled_double_vector_rdd(data, cache=True):
368406
"""
