Skip to content

Commit b51a4cd

Browse files
viirya authored and davies committed
[SPARK-12016] [MLLIB] [PYSPARK] Wrap Word2VecModel when loading it in pyspark
JIRA: https://issues.apache.org/jira/browse/SPARK-12016

We should not directly use Word2VecModel in pyspark. We need to wrap it in a Word2VecModelWrapper when loading it in pyspark.

Author: Liang-Chi Hsieh <[email protected]>

Closes #10100 from viirya/fix-load-py-wordvecmodel.
1 parent e25f1fe commit b51a4cd

File tree

3 files changed

+67
-34
lines changed

3 files changed

+67
-34
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -680,39 +680,6 @@ private[python] class PythonMLLibAPI extends Serializable {
680680
}
681681
}
682682

683-
private[python] class Word2VecModelWrapper(model: Word2VecModel) {
684-
def transform(word: String): Vector = {
685-
model.transform(word)
686-
}
687-
688-
/**
689-
* Transforms an RDD of words to its vector representation
690-
* @param rdd an RDD of words
691-
* @return an RDD of vector representations of words
692-
*/
693-
def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
694-
rdd.rdd.map(model.transform)
695-
}
696-
697-
def findSynonyms(word: String, num: Int): JList[Object] = {
698-
val vec = transform(word)
699-
findSynonyms(vec, num)
700-
}
701-
702-
def findSynonyms(vector: Vector, num: Int): JList[Object] = {
703-
val result = model.findSynonyms(vector, num)
704-
val similarity = Vectors.dense(result.map(_._2))
705-
val words = result.map(_._1)
706-
List(words, similarity).map(_.asInstanceOf[Object]).asJava
707-
}
708-
709-
def getVectors: JMap[String, JList[Float]] = {
710-
model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
711-
}
712-
713-
def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
714-
}
715-
716683
/**
717684
* Java stub for Python mllib DecisionTree.train().
718685
* This stub returns a handle to the Java object instead of the content of the Java object.
mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.api.python
19+
20+
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
21+
import scala.collection.JavaConverters._
22+
23+
import org.apache.spark.SparkContext
24+
import org.apache.spark.api.java.JavaRDD
25+
import org.apache.spark.mllib.feature.Word2VecModel
26+
import org.apache.spark.mllib.linalg.{Vector, Vectors}
27+
28+
/**
29+
* Wrapper around Word2VecModel to provide helper methods in Python
30+
*/
31+
private[python] class Word2VecModelWrapper(model: Word2VecModel) {
32+
def transform(word: String): Vector = {
33+
model.transform(word)
34+
}
35+
36+
/**
37+
* Transforms an RDD of words to its vector representation
38+
* @param rdd an RDD of words
39+
* @return an RDD of vector representations of words
40+
*/
41+
def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
42+
rdd.rdd.map(model.transform)
43+
}
44+
45+
def findSynonyms(word: String, num: Int): JList[Object] = {
46+
val vec = transform(word)
47+
findSynonyms(vec, num)
48+
}
49+
50+
def findSynonyms(vector: Vector, num: Int): JList[Object] = {
51+
val result = model.findSynonyms(vector, num)
52+
val similarity = Vectors.dense(result.map(_._2))
53+
val words = result.map(_._1)
54+
List(words, similarity).map(_.asInstanceOf[Object]).asJava
55+
}
56+
57+
def getVectors: JMap[String, JList[Float]] = {
58+
model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
59+
}
60+
61+
def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
62+
}

python/pyspark/mllib/feature.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,8 @@ def load(cls, sc, path):
504504
"""
505505
jmodel = sc._jvm.org.apache.spark.mllib.feature \
506506
.Word2VecModel.load(sc._jsc.sc(), path)
507-
return Word2VecModel(jmodel)
507+
model = sc._jvm.Word2VecModelWrapper(jmodel)
508+
return Word2VecModel(model)
508509

509510

510511
@ignore_unicode_prefix
@@ -546,6 +547,9 @@ class Word2Vec(object):
546547
>>> sameModel = Word2VecModel.load(sc, path)
547548
>>> model.transform("a") == sameModel.transform("a")
548549
True
550+
>>> syms = sameModel.findSynonyms("a", 2)
551+
>>> [s[0] for s in syms]
552+
[u'b', u'c']
549553
>>> from shutil import rmtree
550554
>>> try:
551555
... rmtree(path)

0 commit comments

Comments
 (0)