
Commit 098c734

Ishiihara authored and mengxr committed
[SPARK-3486][MLlib][PySpark] PySpark support for Word2Vec
Added PySpark support for Word2Vec. Change list:

(1) PySpark support for Word2Vec
(2) SerDe support for string sequences, on both the Python side and the JVM side
(3) A test for SerDe of string sequences on the JVM side

Author: Liquan Pei <[email protected]>

Closes apache#2356 from Ishiihara/Word2Vec-python and squashes the following commits:

476ea34 [Liquan Pei] style fixes
b13a0b9 [Liquan Pei] resolve merge conflicts and minor fixes
8671eba [Liquan Pei] Merge remote-tracking branch 'upstream/master' into Word2Vec-python
daf88a6 [Liquan Pei] modification according to feedback
a73fa19 [Liquan Pei] clean up
3d8007b [Liquan Pei] fix findSynonyms for vector
1bdcd2e [Liquan Pei] minor fixes
cdef9f4 [Liquan Pei] add missing comments
b7447eb [Liquan Pei] modify according to feedback
b9a7383 [Liquan Pei] cache words RDD in fit
89490bf [Liquan Pei] add tests and Word2VecModelWrapper
78bbb53 [Liquan Pei] use pickle for seq string SerDe
a264b08 [Liquan Pei] Merge remote-tracking branch 'upstream/master' into Word2Vec-python
ca1e5ff [Liquan Pei] fix test
68e7276 [Liquan Pei] minor style fixes
48d5e72 [Liquan Pei] Functionality improvement
0ad3ac1 [Liquan Pei] minor fix
c867fdf [Liquan Pei] add Word2Vec to pyspark
1 parent 3d7b36e commit 098c734

File tree

5 files changed: +264 −7 lines changed
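Taken together, the changes in this commit give PySpark users a complete train-and-query path for Word2Vec. A minimal usage sketch, adapted from the doctest this commit adds in python/pyspark/mllib/feature.py (it assumes a running SparkContext `sc`):

from pyspark.mllib.feature import Word2Vec

sentence = "a b " * 100 + "a c " * 10
doc = sc.parallelize([sentence, sentence]).map(lambda line: line.split(" "))

# fit() calls the new trainWord2Vec JVM stub and wraps the returned
# Py4J handle in a Python Word2VecModel.
model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)

vec = model.transform("a")           # vector representation of "a"
syms = model.findSynonyms("a", 2)    # [(word, cosineSimilarity), ...]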

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 56 additions & 1 deletion
@@ -29,6 +29,8 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.feature.Word2Vec
+import org.apache.spark.mllib.feature.Word2VecModel
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.random.{RandomRDDs => RG}
@@ -42,9 +44,9 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
 import org.apache.spark.mllib.stat.correlation.CorrelationNames
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
-
 /**
  * :: DeveloperApi ::
  * The Java stubs necessary for the Python mllib bindings.
@@ -287,6 +289,59 @@ class PythonMLLibAPI extends Serializable {
     ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)
   }
 
+  /**
+   * Java stub for Python mllib Word2Vec fit(). This stub returns a
+   * handle to the Java object instead of the content of the Java object.
+   * Extra care needs to be taken in the Python code to ensure it gets freed on
+   * exit; see the Py4J documentation.
+   * @param dataJRDD input JavaRDD
+   * @param vectorSize size of vector
+   * @param learningRate initial learning rate
+   * @param numPartitions number of partitions
+   * @param numIterations number of iterations
+   * @param seed initial seed for random generator
+   * @return A handle to java Word2VecModelWrapper instance at python side
+   */
+  def trainWord2Vec(
+      dataJRDD: JavaRDD[java.util.ArrayList[String]],
+      vectorSize: Int,
+      learningRate: Double,
+      numPartitions: Int,
+      numIterations: Int,
+      seed: Long): Word2VecModelWrapper = {
+    val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
+    val word2vec = new Word2Vec()
+      .setVectorSize(vectorSize)
+      .setLearningRate(learningRate)
+      .setNumPartitions(numPartitions)
+      .setNumIterations(numIterations)
+      .setSeed(seed)
+    val model = word2vec.fit(data)
+    data.unpersist()
+    new Word2VecModelWrapper(model)
+  }
+
+  private[python] class Word2VecModelWrapper(model: Word2VecModel) {
+    def transform(word: String): Vector = {
+      model.transform(word)
+    }
+
+    def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
+      val vec = transform(word)
+      findSynonyms(vec, num)
+    }
+
+    def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
+      val result = model.findSynonyms(vector, num)
+      val similarity = Vectors.dense(result.map(_._2))
+      val words = result.map(_._1)
+      val ret = new java.util.LinkedList[java.lang.Object]()
+      ret.add(words)
+      ret.add(similarity)
+      ret
+    }
+  }
+
 /**
  * Java stub for Python mllib DecisionTree.train().
  * This stub returns a handle to the Java object instead of the content of the Java object.
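As the scaladoc notes, what crosses back to Python is only a Py4J handle to the Word2VecModelWrapper, and the Python side is responsible for freeing it. A minimal sketch of that lifecycle, mirroring what python/pyspark/mllib/feature.py (below) does under the hood (assumes a live SparkContext `sc` and an RDD `doc` of token lists):

# Obtain a handle to the JVM-side wrapper via the stub above.
api = sc._jvm.PythonMLLibAPI()
handle = api.trainWord2Vec(doc._to_java_object_rdd(),
                           100,     # vectorSize
                           0.025,   # learningRate
                           1,       # numPartitions
                           1,       # numIterations
                           42L)     # seed
# ... call handle.transform(word) / handle.findSynonyms(word, num) ...
# When done, detach so the JVM-side object can be garbage-collected (see
# the Py4J docs); Word2VecModel.__del__ in feature.py does exactly this.
sc._gateway.detach(handle)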

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 6 additions & 6 deletions
@@ -67,7 +67,7 @@ private case class VocabWord(
 class Word2Vec extends Serializable with Logging {
 
   private var vectorSize = 100
-  private var startingAlpha = 0.025
+  private var learningRate = 0.025
   private var numPartitions = 1
   private var numIterations = 1
   private var seed = Utils.random.nextLong()
@@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging {
    * Sets initial learning rate (default: 0.025).
    */
   def setLearningRate(learningRate: Double): this.type = {
-    this.startingAlpha = learningRate
+    this.learningRate = learningRate
    this
   }
 
@@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging {
     val syn0Global =
       Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
     val syn1Global = new Array[Float](vocabSize * vectorSize)
-    var alpha = startingAlpha
+    var alpha = learningRate
     for (k <- 1 to numIterations) {
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
@@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging {
               lwc = wordCount
               // TODO: discount by iteration?
               alpha =
-                startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
-              if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+                learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
+              if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
               logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
             }
             wc += sentence.size
@@ -437,7 +437,7 @@ class Word2VecModel private[mllib] (
    * Find synonyms of a word
    * @param word a word
    * @param num number of synonyms to find
-   * @return array of (word, similarity)
+   * @return array of (word, cosineSimilarity)
    */
   def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
     val vector = transform(word)
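The rename from startingAlpha to learningRate also makes the decay schedule above easier to follow: alpha falls linearly with the number of words processed and is floored at 0.01% of the initial rate. A small illustrative transcription of that update in Python (a hypothetical standalone helper, not part of this patch):

def decayed_alpha(learningRate, numPartitions, wordCount, trainWordsCount):
    # Mirrors the update in Word2Vec.scala: linear decay in words seen,
    # floored at 1e-4 of the initial learning rate.
    alpha = learningRate * (1 - numPartitions * wordCount / float(trainWordsCount + 1))
    return max(alpha, learningRate * 0.0001)

# e.g. halfway through a 1,000,000-word corpus on one partition:
# decayed_alpha(0.025, 1, 500000, 1000000) -> ~0.0125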

python/docs/pyspark.mllib.rst

Lines changed: 8 additions & 0 deletions
@@ -20,6 +20,14 @@ pyspark.mllib.clustering module
     :undoc-members:
     :show-inheritance:
 
+pyspark.mllib.feature module
+-------------------------------
+
+.. automodule:: pyspark.mllib.feature
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 pyspark.mllib.linalg module
 ---------------------------
 
python/pyspark/mllib/feature.py

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python package for feature in MLlib.
+"""
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+
+from pyspark.mllib.linalg import _convert_to_vector
+
+__all__ = ['Word2Vec', 'Word2VecModel']
+
+
+class Word2VecModel(object):
+    """
+    class for Word2Vec model
+    """
+    def __init__(self, sc, java_model):
+        """
+        :param sc: Spark context
+        :param java_model: Handle to Java model object
+        """
+        self._sc = sc
+        self._java_model = java_model
+
+    def __del__(self):
+        self._sc._gateway.detach(self._java_model)
+
+    def transform(self, word):
+        """
+        :param word: a word
+        :return: vector representation of word
+        Transforms a word to its vector representation
+
+        Note: local use only
+        """
+        # TODO: make transform usable in RDD operations from python side
+        result = self._java_model.transform(word)
+        return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result)))
+
+    def findSynonyms(self, x, num):
+        """
+        :param x: a word or a vector representation of word
+        :param num: number of synonyms to find
+        :return: array of (word, cosineSimilarity)
+        Find synonyms of a word
+
+        Note: local use only
+        """
+        # TODO: make findSynonyms usable in RDD operations from python side
+        ser = PickleSerializer()
+        if type(x) == str:
+            jlist = self._java_model.findSynonyms(x, num)
+        else:
+            bytes = bytearray(ser.dumps(_convert_to_vector(x)))
+            vec = self._sc._jvm.SerDe.loads(bytes)
+            jlist = self._java_model.findSynonyms(vec, num)
+        words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist)))
+        return zip(words, similarity)
+
+
+class Word2Vec(object):
+    """
+    Word2Vec creates vector representation of words in a text corpus.
+    The algorithm first constructs a vocabulary from the corpus
+    and then learns vector representation of words in the vocabulary.
+    The vector representation can be used as features in
+    natural language processing and machine learning algorithms.
+
+    We used skip-gram model in our implementation and hierarchical softmax
+    method to train the model. The variable names in the implementation
+    matches the original C implementation.
+    For original C implementation, see https://code.google.com/p/word2vec/
+    For research papers, see
+    Efficient Estimation of Word Representations in Vector Space
+    and
+    Distributed Representations of Words and Phrases and their Compositionality.
+
+    >>> sentence = "a b " * 100 + "a c " * 10
+    >>> localDoc = [sentence, sentence]
+    >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
+    >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
+    >>> syms = model.findSynonyms("a", 2)
+    >>> str(syms[0][0])
+    'b'
+    >>> str(syms[1][0])
+    'c'
+    >>> len(syms)
+    2
+    >>> vec = model.transform("a")
+    >>> len(vec)
+    10
+    >>> syms = model.findSynonyms(vec, 2)
+    >>> str(syms[0][0])
+    'b'
+    >>> str(syms[1][0])
+    'c'
+    >>> len(syms)
+    2
+    """
+    def __init__(self):
+        """
+        Construct Word2Vec instance
+        """
+        self.vectorSize = 100
+        self.learningRate = 0.025
+        self.numPartitions = 1
+        self.numIterations = 1
+        self.seed = 42L
+
+    def setVectorSize(self, vectorSize):
+        """
+        Sets vector size (default: 100).
+        """
+        self.vectorSize = vectorSize
+        return self
+
+    def setLearningRate(self, learningRate):
+        """
+        Sets initial learning rate (default: 0.025).
+        """
+        self.learningRate = learningRate
+        return self
+
+    def setNumPartitions(self, numPartitions):
+        """
+        Sets number of partitions (default: 1). Use a small number for accuracy.
+        """
+        self.numPartitions = numPartitions
+        return self
+
+    def setNumIterations(self, numIterations):
+        """
+        Sets number of iterations (default: 1), which should be smaller than or equal to number of
+        partitions.
+        """
+        self.numIterations = numIterations
+        return self
+
+    def setSeed(self, seed):
+        """
+        Sets random seed.
+        """
+        self.seed = seed
+        return self
+
+    def fit(self, data):
+        """
+        Computes the vector representation of each word in vocabulary.
+
+        :param data: training data. RDD of subtype of Iterable[String]
+        :return: python Word2VecModel instance
+        """
+        sc = data.context
+        ser = PickleSerializer()
+        vectorSize = self.vectorSize
+        learningRate = self.learningRate
+        numPartitions = self.numPartitions
+        numIterations = self.numIterations
+        seed = self.seed
+
+        model = sc._jvm.PythonMLLibAPI().trainWord2Vec(
+            data._to_java_object_rdd(), vectorSize,
+            learningRate, numPartitions, numIterations, seed)
+        return Word2VecModel(sc, model)
+
+
+def _test():
+    import doctest
+    from pyspark import SparkContext
+    globs = globals().copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()
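Change (2) in the commit message, the string-sequence SerDe, is what carries the token lists into trainWord2Vec and the (words, similarity) pair back out, using pickle on both sides (see findSynonyms above). A rough sketch of that round-trip as the Python side drives it, assuming the JVM-side SerDe object accepts a pickled string sequence as the change list describes (it is not an API this patch documents directly):

from pyspark.serializers import PickleSerializer

ser = PickleSerializer()
# Python -> JVM: pickle locally, unpickle in the JVM.
jobj = sc._jvm.SerDe.loads(bytearray(ser.dumps(["a", "b", "c"])))
# JVM -> Python: pickle in the JVM, unpickle locally.
roundtrip = ser.loads(str(sc._jvm.SerDe.dumps(jobj)))
assert roundtrip == ["a", "b", "c"]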

python/run-tests

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ function run_mllib_tests() {
     echo "Run mllib tests ..."
     run_test "pyspark/mllib/classification.py"
     run_test "pyspark/mllib/clustering.py"
+    run_test "pyspark/mllib/feature.py"
     run_test "pyspark/mllib/linalg.py"
     run_test "pyspark/mllib/random.py"
     run_test "pyspark/mllib/recommendation.py"
