@@ -29,6 +29,8 @@ import org.apache.spark.annotation.DeveloperApi
2929import org .apache .spark .api .java .{JavaRDD , JavaSparkContext }
3030import org .apache .spark .mllib .classification ._
3131import org .apache .spark .mllib .clustering ._
32+ import org .apache .spark .mllib .feature .Word2Vec
33+ import org .apache .spark .mllib .feature .Word2VecModel
3234import org .apache .spark .mllib .optimization ._
3335import org .apache .spark .mllib .linalg ._
3436import org .apache .spark .mllib .random .{RandomRDDs => RG }
@@ -40,8 +42,6 @@ import org.apache.spark.mllib.tree.impurity._
4042import org .apache .spark .mllib .tree .model .DecisionTreeModel
4143import org .apache .spark .mllib .stat .{MultivariateStatisticalSummary , Statistics }
4244import org .apache .spark .mllib .stat .correlation .CorrelationNames
43- import org .apache .spark .mllib .feature .Word2Vec
44- import org .apache .spark .mllib .feature .Word2VecModel
4545import org .apache .spark .mllib .util .MLUtils
4646import org .apache .spark .rdd .RDD
4747import org .apache .spark .util .Utils
@@ -290,72 +290,34 @@ class PythonMLLibAPI extends Serializable {
290290 * Extra care needs to be taken in the Python code to ensure it gets freed on
291291 * exit; see the Py4J documentation.
292292 * @param dataJRDD Input JavaRDD
293- * @return A handle to java Word2VecModel instance at python side
293+ * @return A handle to java Word2VecModelWrapper instance at python side
294294 */
def trainWord2Vec(dataJRDD: JavaRDD[java.util.ArrayList[String]]): Word2VecModelWrapper = {
  val data = dataJRDD.rdd
  // Only take ownership of caching when the caller has not persisted the RDD already;
  // this keeps us from evicting (or trying to re-level) a dataset the caller manages.
  val cachedHere = data.getStorageLevel == org.apache.spark.storage.StorageLevel.NONE
  if (cachedHere) {
    data.cache()
  }
  try {
    val word2vec = new Word2Vec()
    val model = word2vec.fit(data)
    // Hand Python a wrapper so it can call transform/findSynonyms on the handle directly.
    new Word2VecModelWrapper(model)
  } finally {
    // Release the cached input once training is done (or has failed); previously the
    // cached RDD was never unpersisted and stayed in executor storage.
    if (cachedHere) {
      data.unpersist(blocking = false)
    }
  }
}
303301
/**
 * Wrapper around a [[Word2VecModel]] that exposes `transform` and `findSynonyms`
 * as instance methods, returning the synonym results packed into a `java.util.List`.
 * @param model the Word2VecModel instance every call is delegated to
 */
private[python] class Word2VecModelWrapper(model: Word2VecModel) {

  /**
   * Transforms a word with the wrapped model.
   * @param word a word
   * @return the vector the wrapped model's `transform` returns for `word`
   */
  def transform(word: String): Vector = model.transform(word)

  /**
   * Finds synonyms of a word by first transforming it to its vector and then
   * delegating to the vector-based overload.
   * @param word a word
   * @param num number of synonyms to find
   * @return a java LinkedList holding the synonyms followed by their similarities
   */
  def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] =
    findSynonyms(transform(word), num)

  /**
   * Finds synonyms of the given vector using the wrapped model.
   * @param vector vector representation to search around
   * @param num number of synonyms to find
   * @return a java LinkedList with two entries: the array of synonym words, then a
   *         dense vector of the matching similarities
   */
  def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
    // Split the (word, similarity) pairs into two parallel arrays.
    val (synonyms, scores) = model.findSynonyms(vector, num).unzip
    val packed = new java.util.LinkedList[java.lang.Object]()
    packed.add(synonyms)
    packed.add(Vectors.dense(scores))
    packed
  }
}
360322
361323 /**
0 commit comments