Skip to content

Commit 3d8007b

Browse files
committed
Fix findSynonyms when the query is a vector rather than a word string
1 parent 1bdcd2e commit 3d8007b

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

python/pyspark/mllib/Word2Vec.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -64,8 +64,14 @@ def findSynonyms(self, x, num):
6464
Note: local use only
6565
TODO: make findSynonyms usable in RDD operations from python side
6666
"""
67-
jlist = self._java_model.findSynonyms(x, num)
68-
words, similarity = PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(jlist)))
67+
ser = PickleSerializer()
68+
if type(x) == str:
69+
jlist = self._java_model.findSynonyms(x, num)
70+
else:
71+
bytes = bytearray(ser.dumps(_convert_to_vector(x)))
72+
vec = self._sc._jvm.SerDe.loads(bytes)
73+
jlist = self._java_model.findSynonyms(vec, num)
74+
words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist)))
6975
return zip(words, similarity)
7076

7177

@@ -100,6 +106,13 @@ class Word2Vec(object):
100106
>>> vec = model.transform("a")
101107
>>> len(vec)
102108
10
109+
>>> syms = model.findSynonyms(vec, 2)
110+
>>> str(syms[0][0])
111+
'b'
112+
>>> str(syms[1][0])
113+
'c'
114+
>>> len(syms)
115+
2
103116
"""
104117
def __init__(self):
105118
"""

0 commit comments

Comments
 (0)