@@ -19,16 +19,17 @@ package org.apache.spark.mllib.feature
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
-import org.apache.spark.{HashPartitioner, Logging}
+
+import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
 
@@ -58,29 +59,63 @@ private case class VocabWord(
  * Efficient Estimation of Word Representations in Vector Space
  * and
  * Distributed Representations of Words and Phrases and their Compositionality.
- * @param size vector dimension
- * @param startingAlpha initial learning rate
- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
- * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
  */
 @Experimental
-class Word2Vec(
-    val size: Int,
-    val startingAlpha: Double,
-    val parallelism: Int,
-    val numIterations: Int) extends Serializable with Logging {
+class Word2Vec extends Serializable with Logging {
+
+  private var vectorSize = 100
+  private var startingAlpha = 0.025
+  private var numPartitions = 1
+  private var numIterations = 1
+  private var seed = Utils.random.nextLong()
+
+  /**
+   * Sets vector size (default: 100).
+   */
+  def setVectorSize(vectorSize: Int): this.type = {
+    this.vectorSize = vectorSize
+    this
+  }
+
+  /**
+   * Sets initial learning rate (default: 0.025).
+   */
+  def setLearningRate(learningRate: Double): this.type = {
+    this.startingAlpha = learningRate
+    this
+  }
 
   /**
-   * Word2Vec with a single thread.
+   * Sets number of partitions (default: 1). Use a small number for accuracy.
    */
-  def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
+  def setNumPartitions(numPartitions: Int): this.type = {
+    require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
+    this.numPartitions = numPartitions
+    this
+  }
+
+  /**
+   * Sets number of iterations (default: 1), which should be smaller than or equal to number of
+   * partitions.
+   */
+  def setNumIterations(numIterations: Int): this.type = {
+    this.numIterations = numIterations
+    this
+  }
+
+  /**
+   * Sets random seed (default: a random long integer).
+   */
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
 
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
-  private val layer1Size = size
-  private val modelPartitionNum = 100
+  private val layer1Size = vectorSize
 
   /** context words from [-window, window] */
   private val window = 5
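Taken together, these setters replace the old four-argument constructor with a fluent builder-style API. A minimal usage sketch, under assumptions: the corpus path, tokenization, and `sc` are illustrative, and `fit` is the existing training entry point that this patch keeps.

```scala
import org.apache.spark.mllib.feature.Word2Vec
import org.apache.spark.rdd.RDD

// Hypothetical corpus: one sentence per line, whitespace-tokenized.
val sentences: RDD[Seq[String]] = sc.textFile("corpus.txt").map(_.split(" ").toSeq)

// Every setter returns this.type, so configuration chains fluently and a new
// parameter (e.g. seed) no longer breaks the constructor signature.
val model = new Word2Vec()
  .setVectorSize(200)     // dimension of the learned word vectors
  .setLearningRate(0.025) // initial alpha, decayed during training
  .setNumPartitions(4)    // keep small: parallelism trades accuracy for speed
  .setNumIterations(1)    // should be <= number of partitions
  .setSeed(42L)           // makes runs reproducible
  .fit(sentences)
```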
@@ -94,12 +129,12 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words:RDD[String]): Unit = {
+  private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
-        x._1, 
-        x._2, 
+        x._1,
+        x._2,
         new Array[Int](MAX_CODE_LENGTH),
         new Array[Int](MAX_CODE_LENGTH),
         0))
@@ -245,32 +280,32 @@ class Word2Vec(
       }
     }
 
-    val newSentences = sentences.repartition(parallelism).cache()
+    val newSentences = sentences.repartition(numPartitions).cache()
+    val initRandom = new XORShiftRandom(seed)
     var syn0Global =
-      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+      Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
     var syn1Global = new Array[Float](vocabSize * layer1Size)
- 
-    for (iter <- 1 to numIterations) {
-      val (aggSyn0, aggSyn1, _, _) =
-        // TODO: broadcast temp instead of serializing it directly
-        // or initialize the model in each executor
-        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
-          seqOp = (c, v) => (c, v) match {
+
+    for (k <- 1 to numIterations) {
+      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
+        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
-            var wc = wordCount 
+            var wc = wordCount
             if (wordCount - lastWordCount > 10000) {
               lwc = wordCount
-              alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+              // TODO: discount by iteration?
+              alpha =
+                startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
               if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
               logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
             }
             wc += sentence.size
             var pos = 0
             while (pos < sentence.size) {
               val word = sentence(pos)
-              // TODO: fix random seed
-              val b = Random.nextInt(window)
+              val b = random.nextInt(window)
               // Train Skip-gram
               var a = b
               while (a < window * 2 + 1 - b) {
@@ -280,7 +315,7 @@ class Word2Vec(
                     val lastWord = sentence(c)
                     val l1 = lastWord * layer1Size
                     val neu1e = new Array[Float](layer1Size)
-                    // Hierarchical softmax 
+                    // Hierarchical softmax
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
                       val l2 = bcVocab.value(word).point(d) * layer1Size
@@ -303,44 +338,44 @@ class Word2Vec(
               pos += 1
             }
             (syn0, syn1, lwc, wc)
-          },
-          combOp = (c1, c2) => (c1, c2) match {
-            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-              val n = syn0_1.length
-              val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-              val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-              blas.sscal(n, weight1, syn0_1, 1)
-              blas.sscal(n, weight1, syn1_1, 1)
-              blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-              blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-          })
+        }
+        Iterator(model)
+      }
+      val (aggSyn0, aggSyn1, _, _) =
+        partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+          val n = syn0_1.length
+          val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+          val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+          blas.sscal(n, weight1, syn0_1, 1)
+          blas.sscal(n, weight1, syn1_1, 1)
+          blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+          blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+        }
       syn0Global = aggSyn0
       syn1Global = aggSyn1
     }
     newSentences.unpersist()
 
-    val wordMap = new Array[(String, Array[Float])](vocabSize)
+    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
       val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
-      wordMap(i) = (word, vector)
+      word2VecMap += word -> vector
       i += 1
     }
-    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
-      .partitionBy(new HashPartitioner(modelPartitionNum))
-      .persist(StorageLevel.MEMORY_AND_DISK)
-
-    new Word2VecModel(modelRDD)
+
+    new Word2VecModel(word2VecMap.toMap)
   }
 }
 
 /**
  * Word2Vec model
- */
-class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
+ */
+class Word2VecModel private[mllib] (
+    private val model: Map[String, Array[Float]]) extends Serializable {
 
   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
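The `treeReduce` block above merges two partial models by averaging their weight arrays, weighting each side by the number of words it processed: `sscal` scales an array in place and `saxpy` adds a scaled array, so together they compute `syn0_1(i) = w1 * syn0_1(i) + w2 * syn0_2(i)`. A plain-Scala sketch of the same merge for readers without the netlib-java BLAS context (the helper name `merge` is mine, not the patch's):

```scala
// Word-count-weighted average of two partial weight arrays, done in place
// on `a`, matching the blas.sscal + blas.saxpy pair in the patch.
def merge(a: Array[Float], wcA: Int, b: Array[Float], wcB: Int): Array[Float] = {
  val w1 = 1.0f * wcA / (wcA + wcB)
  val w2 = 1.0f * wcB / (wcA + wcB)
  var i = 0
  while (i < a.length) {
    a(i) = w1 * a(i) + w2 * b(i)
    i += 1
  }
  a
}
```

Because this weighted merge is associative (up to floating-point rounding), `treeReduce` can combine the per-partition models in a balanced tree instead of funneling every partial model through the driver.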
@@ -357,11 +392,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
    * @return vector representation of word
    */
   def transform(word: String): Vector = {
-    val result = model.lookup(word)
-    if (result.isEmpty) {
-      throw new IllegalStateException(s"$word not in vocabulary")
+    model.get(word) match {
+      case Some(vec) =>
+        Vectors.dense(vec.map(_.toDouble))
+      case None =>
+        throw new IllegalStateException(s"$word not in vocabulary")
     }
-    else Vectors.dense(result(0).map(_.toDouble))
   }
 
   /**
@@ -392,33 +428,13 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
    */
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
-    val topK = model.map { case (w, vec) =>
-      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
-      .sortByKey(ascending = false)
-      .take(num + 1)
-      .map(_.swap)
-      .tail
-
-    topK
-  }
-}
-
-object Word2Vec {
-  /**
-   * Train Word2Vec model
-   * @param input RDD of words
-   * @param size vector dimension
-   * @param startingAlpha initial learning rate
-   * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
-   * @param numIterations number of iterations, should be smaller than or equal to parallelism
-   * @return Word2Vec model
-   */
-  def train[S <: Iterable[String]](
-      input: RDD[S],
-      size: Int,
-      startingAlpha: Double,
-      parallelism: Int = 1,
-      numIterations: Int = 1): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
+    // TODO: optimize top-k
+    val fVector = vector.toArray.map(_.toFloat)
+    model.mapValues(vec => cosineSimilarity(fVector, vec))
+      .toSeq
+      .sortBy(- _._2)
+      .take(num + 1)
+      .tail
+      .toArray
   }
 }
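Since the model is now backed by a driver-side `Map` instead of an RDD, `transform` and `findSynonyms` are cheap local lookups rather than cluster jobs. A usage sketch continuing the hypothetical `model` from the earlier example; note that `take(num + 1).tail` assumes the query vector's own word is the top hit and drops it:

```scala
import org.apache.spark.mllib.linalg.Vector

// Vector for a word; throws IllegalStateException if the word
// never made it into the vocabulary.
val vec: Vector = model.transform("spark")

// Five nearest words by cosine similarity, excluding the query word itself.
model.findSynonyms(vec, 5).foreach { case (synonym, similarity) =>
  println(s"$synonym\t$similarity")
}
```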