@@ -50,24 +50,27 @@ private case class VocabWord(
  * natural language processing and machine learning algorithms.
  *
  * We used the skip-gram model in our implementation and hierarchical softmax
- * method to train the model.
+ * method to train the model. The variable names in the implementation
+ * match those in the original C implementation.
  *
  * For original C implementation, see https://code.google.com/p/word2vec/
  * For research papers, see
  * Efficient Estimation of Word Representations in Vector Space
  * and
- * Distributed Representations of Words and Phrases and their Compositionality
+ * Distributed Representations of Words and Phrases and their Compositionality.
  * @param size vector dimension
  * @param startingAlpha initial learning rate
  * @param window context words from [-window, window]
  * @param minCount minimum frequency to consider a vocabulary word
+ * @param parallelism number of partitions to run Word2Vec
  */
 @Experimental
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
     val window: Int,
-    val minCount: Int)
+    val minCount: Int,
+    val parallelism: Int = 1)
   extends Serializable with Logging {

   private val EXP_TABLE_SIZE = 1000
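The new `parallelism` parameter defaults to 1, so existing callers keep the old single-partition behavior. A minimal usage sketch, assuming a `SparkContext` named `sc` and a hypothetical corpus file; the training entry point itself is not shown in this diff:

```scala
import org.apache.spark.rdd.RDD

// Hypothetical corpus: one whitespace-tokenized sentence per line.
val sentences: RDD[Seq[String]] =
  sc.textFile("corpus.txt").map(_.split(" ").toSeq)

val word2vec = new Word2Vec(
  size = 100,            // vector dimension
  startingAlpha = 0.025, // initial learning rate
  window = 5,            // context words from [-5, 5]
  minCount = 5,          // drop words seen fewer than 5 times
  parallelism = 4)       // train on 4 partitions instead of 1
```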
@@ -237,7 +240,7 @@ class Word2Vec(
       }
     }

-    val newSentences = sentences.repartition(1).cache()
+    val newSentences = sentences.repartition(parallelism).cache()
     val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
     val (aggSyn0, _, _, _) =
       // TODO: broadcast temp instead of serializing it directly
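Repartitioning to `parallelism` partitions instead of forcing a single one is what allows several copies of the model to train concurrently: each partition starts from the same serialized `temp` weights, updates them locally, and the partial results are merged later. A schematic sketch of that pattern, where `initialWeights` and `trainOnSentence` are illustrative names, not from this code:

```scala
// Data-parallel training pattern, schematically: every partition clones the
// initial weights, updates them on its own share of sentences, and emits the
// local result for the merge step shown further down.
val partials = newSentences.mapPartitions { iter =>
  val localSyn0 = initialWeights.clone()               // per-partition copy
  iter.foreach(sentence => trainOnSentence(sentence, localSyn0))
  Iterator(localSyn0)
}
```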
@@ -248,7 +251,7 @@ class Word2Vec(
         var wc = wordCount
         if (wordCount - lastWordCount > 10000) {
           lwc = wordCount
-          alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1))
+          alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
           if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
           logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
         }
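With `parallelism` partitions, each partition only observes roughly `trainWordsCount / parallelism` words, so without the extra factor alpha would barely decay before a partition ran out of data. Scaling the local word count by `parallelism` keeps the decay schedule close to the single-partition one. A standalone sketch with illustrative numbers:

```scala
// Illustrative numbers only: 1,000,000 training words, 4 partitions.
val startingAlpha = 0.025
val trainWordsCount = 1000000
val parallelism = 4

// Learning rate after a partition has processed `localWordCount` words.
def alphaAt(localWordCount: Int): Double = {
  val a = startingAlpha * (1 - parallelism * localWordCount.toDouble / (trainWordsCount + 1))
  math.max(a, startingAlpha * 0.0001) // same floor as the code above
}

alphaAt(125000) // ≈ 0.0125: halfway through the partition's ~250,000-word share
```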
@@ -296,7 +299,7 @@ class Word2Vec(
           val n = syn0_1.length
           blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
           blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
-          (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
+          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
         })

     val wordMap = new Array[(String, Array[Double])](vocabSize)
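This hunk fixes a real bug in the merge step: the combiner accumulated `syn1_2` into `syn1_1` but then returned `syn0_2` in the tuple slot where the merged `syn1` belongs, silently discarding the accumulated hidden-layer weights. A self-contained sketch of the corrected combiner; the tuple element types are assumed, and `daxpy` computes y := a*x + y:

```scala
import com.github.fommil.netlib.BLAS.{getInstance => blas}

// (syn0, syn1, lastWordCount, wordCount) from one partition.
type Partial = (Array[Double], Array[Double], Int, Int)

def combine(a: Partial, b: Partial): Partial = {
  val (syn0_1, syn1_1, lwc_1, wc_1) = a
  val (syn0_2, syn1_2, lwc_2, wc_2) = b
  val n = syn0_1.length
  blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) // syn0_1 += syn0_2
  blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) // syn1_1 += syn1_2
  // The buggy version returned (syn0_1, syn0_2, ...), so the merged syn1_1
  // was thrown away and a raw per-partition syn0 took its place.
  (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
}
```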
@@ -309,7 +312,7 @@ class Word2Vec(
       i += 1
     }
     val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
-      .partitionBy(new HashPartitioner(modelPartitionNum))
+      .partitionBy(new HashPartitioner(modelPartitionNum)).cache()
     new Word2VecModel(modelRDD)
   }
 }
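Adding `.cache()` to the hash-partitioned model RDD matters because the model is queried repeatedly: without it, each lookup would recompute the shuffle that built the partitioning. A hypothetical usage, assuming `Word2VecModel` exposes lookup methods along the lines of the eventual MLlib API:

```scala
val model = new Word2VecModel(modelRDD)

// The first query materializes the cached, hash-partitioned RDD; later
// queries reuse those partitions instead of re-shuffling wordMap.
val vector = model.transform("china")
val synonyms = model.findSynonyms("china", 40)
```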