1717
1818package org .apache .spark .mllib .feature
1919
20- import scala .util .Random
21- import scala .collection .mutable .ArrayBuffer
2220import scala .collection .mutable
21+ import scala .collection .mutable .ArrayBuffer
22+ import scala .util .Random
2323
2424import com .github .fommil .netlib .BLAS .{getInstance => blas }
25-
26- import org .apache .spark .annotation .Experimental
27- import org .apache .spark .Logging
28- import org .apache .spark .rdd ._
25+ import org .apache .spark .{HashPartitioner , Logging }
2926import org .apache .spark .SparkContext ._
27+ import org .apache .spark .annotation .Experimental
3028import org .apache .spark .mllib .linalg .{Vector , Vectors }
31- import org .apache .spark .HashPartitioner
32- import org .apache .spark .storage .StorageLevel
3329import org .apache .spark .mllib .rdd .RDDFunctions ._
30+ import org .apache .spark .rdd ._
31+ import org .apache .spark .storage .StorageLevel
3432
3533/**
3634 * Entry in vocabulary
@@ -53,7 +51,7 @@ private case class VocabWord(
5351 *
5452 * We used skip-gram model in our implementation and hierarchical softmax
5553 * method to train the model. The variable names in the implementation
56- * mathes the original C implementation.
54+ * matches the original C implementation.
5755 *
5856 * For original C implementation, see https://code.google.com/p/word2vec/
5957 * For research papers, see
@@ -69,10 +67,14 @@ private case class VocabWord(
6967class Word2Vec (
7068 val size : Int ,
7169 val startingAlpha : Double ,
72- val parallelism : Int = 1 ,
73- val numIterations : Int = 1 )
74- extends Serializable with Logging {
75-
70+ val parallelism : Int ,
71+ val numIterations : Int ) extends Serializable with Logging {
72+
/**
 * Word2Vec with a single thread (parallelism = 1) and a single training iteration.
 *
 * @param size dimensionality of the learned word vectors
 * @param startingAlpha initial learning rate; Double to match the primary
 *                      constructor (an Int here would reject fractional rates
 *                      such as 0.025, the typical default)
 */
def this(size: Int, startingAlpha: Double) = this(size, startingAlpha, 1, 1)
7678 private val EXP_TABLE_SIZE = 1000
7779 private val MAX_EXP = 6
7880 private val MAX_CODE_LENGTH = 40
@@ -92,7 +94,7 @@ class Word2Vec(
9294 private var vocabHash = mutable.HashMap .empty[String , Int ]
9395 private var alpha = startingAlpha
9496
95- private def learnVocab (words: RDD [String ]){
97+ private def learnVocab (words: RDD [String ]): Unit = {
9698 vocab = words.map(w => (w, 1 ))
9799 .reduceByKey(_ + _)
98100 .map(x => VocabWord (
@@ -126,7 +128,7 @@ class Word2Vec(
126128 expTable
127129 }
128130
129- private def createBinaryTree () {
131+ private def createBinaryTree (): Unit = {
130132 val count = new Array [Long ](vocabSize * 2 + 1 )
131133 val binary = new Array [Int ](vocabSize * 2 + 1 )
132134 val parentNode = new Array [Int ](vocabSize * 2 + 1 )
@@ -208,7 +210,6 @@ class Word2Vec(
208210 * @param dataset an RDD of words
209211 * @return a Word2VecModel
210212 */
211-
212213 def fit [S <: Iterable [String ]](dataset : RDD [S ]): Word2VecModel = {
213214
214215 val words = dataset.flatMap(x => x)
@@ -339,7 +340,7 @@ class Word2Vec(
339340/**
340341* Word2Vec model
341342*/
342- class Word2VecModel (private val model : RDD [(String , Array [Float ])]) extends Serializable {
343+ class Word2VecModel (private val model : RDD [(String , Array [Float ])]) extends Serializable {
343344
344345 private def cosineSimilarity (v1 : Array [Float ], v2 : Array [Float ]): Double = {
345346 require(v1.length == v2.length, " Vectors should have the same length" )
@@ -358,7 +359,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Seri
/**
 * Transforms a word into its vector representation.
 *
 * @param word the word to look up in the model
 * @return the word's vector
 * @throws IllegalStateException if the word is not in the vocabulary
 */
def transform(word: String): Vector = {
  model.lookup(word).headOption match {
    case Some(vec) => Vectors.dense(vec.map(_.toDouble))
    case None => throw new IllegalStateException(s"$word not in vocabulary")
  }
}
0 commit comments