Skip to content

Commit 89490bf

Browse files
committed
add tests and Word2VecModelWrapper
1 parent 78bbb53 commit 89490bf

File tree

2 files changed

+41
-78
lines changed

2 files changed

+41
-78
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 23 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import org.apache.spark.annotation.DeveloperApi
2929
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
3030
import org.apache.spark.mllib.classification._
3131
import org.apache.spark.mllib.clustering._
32+
import org.apache.spark.mllib.feature.Word2Vec
33+
import org.apache.spark.mllib.feature.Word2VecModel
3234
import org.apache.spark.mllib.optimization._
3335
import org.apache.spark.mllib.linalg._
3436
import org.apache.spark.mllib.random.{RandomRDDs => RG}
@@ -40,8 +42,6 @@ import org.apache.spark.mllib.tree.impurity._
4042
import org.apache.spark.mllib.tree.model.DecisionTreeModel
4143
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
4244
import org.apache.spark.mllib.stat.correlation.CorrelationNames
43-
import org.apache.spark.mllib.feature.Word2Vec
44-
import org.apache.spark.mllib.feature.Word2VecModel
4545
import org.apache.spark.mllib.util.MLUtils
4646
import org.apache.spark.rdd.RDD
4747
import org.apache.spark.util.Utils
@@ -290,72 +290,34 @@ class PythonMLLibAPI extends Serializable {
290290
* Extra care needs to be taken in the Python code to ensure it gets freed on
291291
* exit; see the Py4J documentation.
292292
* @param dataJRDD Input JavaRDD
293-
* @return A handle to java Word2VecModel instance at python side
293+
* @return A handle to the Java Word2VecModelWrapper instance on the Python side
294294
*/
295-
def trainWord2Vec(
296-
dataJRDD: JavaRDD[java.util.ArrayList[String]]
297-
): Word2VecModel = {
298-
val data = dataJRDD.rdd.map(_.toArray(new Array[String](0)).toSeq).cache()
295+
def trainWord2Vec(dataJRDD: JavaRDD[java.util.ArrayList[String]]): Word2VecModelWrapper = {
296+
val data = dataJRDD.rdd.cache()
299297
val word2vec = new Word2Vec()
300298
val model = word2vec.fit(data)
301-
model
299+
new Word2VecModelWrapper(model)
302300
}
303301

304-
/**
305-
* Java stub for Python mllib Word2VecModel transform
306-
* @param model Word2VecModel instance
307-
* @param word a word
308-
* @return serialized vector representation of word
309-
*/
310-
def Word2VecModelTransform(
311-
model: Word2VecModel,
312-
word: String
313-
): Vector = {
314-
model.transform(word)
315-
}
302+
private[python] class Word2VecModelWrapper(model: Word2VecModel) {
303+
def transform(word: String): Vector = {
304+
model.transform(word)
305+
}
316306

317-
/**
318-
* Java stub for Python mllib Word2VecModel findSynonyms
319-
* @param model Word2VecModel instance
320-
* @param word a word
321-
* @param num number of synonyms to find
322-
* @return a java LinkedList containing serialized version of
323-
* synonyms and similarities
324-
*/
325-
def Word2VecModelSynonyms(
326-
model: Word2VecModel,
327-
word: String,
328-
num: Int
329-
): java.util.List[java.lang.Object] = {
330-
val result = model.findSynonyms(word, num)
331-
val similarity = Vectors.dense(result.map(_._2))
332-
val words = result.map(_._1)
333-
val ret = new java.util.LinkedList[java.lang.Object]()
334-
ret.add(words)
335-
ret.add(similarity)
336-
ret
337-
}
307+
def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
308+
val vec = transform(word)
309+
findSynonyms(vec, num)
310+
}
338311

339-
/**
340-
* Java stub for Python mllib Word2VecModel findSynonyms
341-
* @param model Word2VecModel instance
342-
* @param vecBytes serialization of vector representation of words
343-
* @param num number of synonyms to find
344-
* @return a java LinkedList containing serialized version of
345-
* synonyms and similarities
346-
*/
347-
def Word2VecModelSynonyms(
348-
model: Word2VecModel,
349-
vec: Vector,
350-
num: Int
351-
): java.util.List[java.lang.Object] = {
352-
val result = model.findSynonyms(vec, num)
353-
val similarity = Vectors.dense(result.map(_._2))
354-
val words = result.map(_._1)
355-
val ret = new java.util.LinkedList[java.lang.Object]()
356-
ret.add(words)
357-
ret.add(similarity)
358-
ret
312+
def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
313+
val result = model.findSynonyms(vector, num)
314+
val similarity = Vectors.dense(result.map(_._2))
315+
val words = result.map(_._1)
316+
val ret = new java.util.LinkedList[java.lang.Object]()
317+
ret.add(words)
318+
ret.add(similarity)
319+
ret
320+
}
359321
}
360322

361323
/**

python/pyspark/mllib/Word2Vec.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
Python package for Word2Vec in MLlib.
2020
"""
2121

22-
from functools import wraps
23-
2422
from pyspark import PickleSerializer
2523

2624
from pyspark.mllib.linalg import _convert_to_vector
@@ -44,21 +42,13 @@ def __del__(self):
4442
self._sc._gateway.detach(self._java_model)
4543

4644
def transform(self, word):
47-
pythonAPI = self._sc._jvm.PythonMLLibAPI()
48-
result = pythonAPI.Word2VecModelTransform(self._java_model, word)
45+
result = self._java_model.transform(word)
4946
return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result)))
5047

5148
def findSynonyms(self, x, num):
52-
SerDe = self._sc._jvm.SerDe
53-
ser = PickleSerializer()
54-
pythonAPI = self._sc._jvm.PythonMLLibAPI()
55-
if type(x) == str:
56-
jlist = pythonAPI.Word2VecModelSynonyms(self._java_model, x, num)
57-
else:
58-
bytes = bytearray(ser.dumps(_convert_to_vector(x)))
59-
vec = self._sc._jvm.SerDe.loads(bytes)
60-
jlist = pythonAPI.Word2VecModelSynonyms(self._java_model, vec, num)
61-
return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(jlist)))
49+
jlist = self._java_model.findSynonyms(x, num)
50+
words, similarity = PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(jlist)))
51+
return zip(words, similarity)
6252

6353

6454
class Word2Vec(object):
@@ -77,12 +67,22 @@ class Word2Vec(object):
7767
Efficient Estimation of Word Representations in Vector Space
7868
and
7969
Distributed Representations of Words and Phrases and their Compositionality.
70+
>>> sentence = "a b " * 100 + "a c " * 10
71+
>>> localDoc = [sentence, sentence]
72+
>>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
73+
>>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
74+
>>> syms = model.findSynonyms("a", 2)
75+
>>> str(syms[0][0])
76+
'b'
77+
>>> str(syms[1][0])
78+
'c'
8079
"""
8180
def __init__(self):
8281
self.vectorSize = 100
8382
self.startingAlpha = 0.025
8483
self.numPartitions = 1
8584
self.numIterations = 1
85+
self.seed = 42L
8686

8787
def setVectorSize(self, vectorSize):
8888
self.vectorSize = vectorSize
@@ -100,10 +100,11 @@ def setNumIterations(self, numIterations):
100100
self.numIterations = numIterations
101101
return self
102102

103+
def setSeed(self, seed):
104+
self.seed = seed
105+
return self
106+
103107
def fit(self, data):
104-
"""
105-
:param data: Input RDD
106-
"""
107108
sc = data.context
108109
model = sc._jvm.PythonMLLibAPI().trainWord2Vec(data._to_java_object_rdd())
109110
return Word2VecModel(sc, model)

0 commit comments

Comments (0)