@@ -70,7 +70,8 @@ class Word2Vec(
     val startingAlpha: Double,
     val window: Int,
     val minCount: Int,
-    val parallelism: Int = 1)
+    val parallelism: Int = 1,
+    val numIterations: Int = 1)
   extends Serializable with Logging {
 
   private val EXP_TABLE_SIZE = 1000
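For background on the `EXP_TABLE_SIZE` constant above: like the original word2vec C code, the training loop further down avoids calling `exp` per update by looking the sigmoid up in a precomputed table (`f = expTable.value(ind)`). A minimal sketch of how such a table is typically built and queried follows; the helper names and the `MAX_EXP = 6` value are assumptions for illustration, not read from this patch.

```scala
object ExpTableSketch {
  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6 // assumed value; the constant is defined outside this hunk

  // Precompute sigmoid(x) for x sampled uniformly over [-MAX_EXP, MAX_EXP].
  def createExpTable(): Array[Double] = {
    val table = new Array[Double](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val e = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      table(i) = e / (e + 1.0) // sigmoid
      i += 1
    }
    table
  }

  // Same index arithmetic as the lookup in the training loop below.
  def sigmoid(table: Array[Double], f: Double): Double = {
    val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
    table(ind)
  }
}
```

The table trades a small amount of accuracy for removing `math.exp` from the innermost hierarchical-softmax loop.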
@@ -241,73 +242,80 @@ class Word2Vec(
     }
 
     val newSentences = sentences.repartition(parallelism).cache()
-    val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    val (aggSyn0, _, _, _) =
-      // TODO: broadcast temp instead of serializing it directly
-      // or initialize the model in each executor
-      newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
-        seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
-          var lwc = lastWordCount
-          var wc = wordCount
-          if (wordCount - lastWordCount > 10000) {
-            lwc = wordCount
-            alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
-            if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
-            logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
-          }
-          wc += sentence.size
-          var pos = 0
-          while (pos < sentence.size) {
-            val word = sentence(pos)
-            // TODO: fix random seed
-            val b = Random.nextInt(window)
-            // Train Skip-gram
-            var a = b
-            while (a < window * 2 + 1 - b) {
-              if (a != window) {
-                val c = pos - window + a
-                if (c >= 0 && c < sentence.size) {
-                  val lastWord = sentence(c)
-                  val l1 = lastWord * layer1Size
-                  val neu1e = new Array[Double](layer1Size)
-                  // Hierarchical softmax
-                  var d = 0
-                  while (d < vocab(word).codeLen) {
-                    val l2 = vocab(word).point(d) * layer1Size
-                    // Propagate hidden -> output
-                    var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
-                    if (f > -MAX_EXP && f < MAX_EXP) {
-                      val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
-                      f = expTable.value(ind)
-                      val g = (1 - vocab(word).code(d) - f) * alpha
-                      blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                      blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+    var syn0Global =
+      Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
+    var syn1Global = new Array[Double](vocabSize * layer1Size)
+
+    for (iter <- 1 to numIterations) {
+      val (aggSyn0, aggSyn1, _, _) =
+        // TODO: broadcast temp instead of serializing it directly
+        // or initialize the model in each executor
+        newSentences.aggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+          seqOp = (c, v) => (c, v) match {
+            case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
+              var lwc = lastWordCount
+              var wc = wordCount
+              if (wordCount - lastWordCount > 10000) {
+                lwc = wordCount
+                alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+                if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+                logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
+              }
+              wc += sentence.size
+              var pos = 0
+              while (pos < sentence.size) {
+                val word = sentence(pos)
+                // TODO: fix random seed
+                val b = Random.nextInt(window)
+                // Train Skip-gram
+                var a = b
+                while (a < window * 2 + 1 - b) {
+                  if (a != window) {
+                    val c = pos - window + a
+                    if (c >= 0 && c < sentence.size) {
+                      val lastWord = sentence(c)
+                      val l1 = lastWord * layer1Size
+                      val neu1e = new Array[Double](layer1Size)
+                      // Hierarchical softmax
+                      var d = 0
+                      while (d < vocab(word).codeLen) {
+                        val l2 = vocab(word).point(d) * layer1Size
+                        // Propagate hidden -> output
+                        var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                        if (f > -MAX_EXP && f < MAX_EXP) {
+                          val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
+                          f = expTable.value(ind)
+                          val g = (1 - vocab(word).code(d) - f) * alpha
+                          blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                          blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                        }
+                        d += 1
                       }
-                      d += 1
+                      blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
                     }
-                    blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
                   }
+                  a += 1
                 }
-                a += 1
+                pos += 1
               }
-              pos += 1
-            }
-            (syn0, syn1, lwc, wc)
-          },
-          combOp = (c1, c2) => (c1, c2) match {
-            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-              val n = syn0_1.length
-              blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
-              blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
-              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-          })
-
+              (syn0, syn1, lwc, wc)
+          },
+          combOp = (c1, c2) => (c1, c2) match {
+            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+              val n = syn0_1.length
+              blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
+              blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
+              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+          })
+      syn0Global = aggSyn0
+      syn1Global = aggSyn1
+    }
     val wordMap = new Array[(String, Array[Double])](vocabSize)
     var i = 0
     while (i < vocabSize) {
      val word = vocab(i).word
      val vector = new Array[Double](layer1Size)
-      Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size)
+      Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
      wordMap(i) = (word, vector)
      i += 1
     }
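One detail worth calling out in the seqOp above is the learning-rate schedule: `alpha` decays linearly with the number of words a partition has processed, scaled by `parallelism`, and is floored at `0.0001 * startingAlpha`. The following self-contained sketch just replays that formula with made-up numbers (a starting rate of 0.025, 4 partitions, a corpus of one million words) to show when the floor kicks in; none of these values come from the patch.

```scala
object AlphaScheduleDemo {
  // Mirrors the update inside the seqOp: linear decay with a lower bound.
  def nextAlpha(startingAlpha: Double, parallelism: Int,
      wordCount: Long, trainWordsCount: Long): Double = {
    val alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
    math.max(alpha, startingAlpha * 0.0001)
  }

  def main(args: Array[String]): Unit = {
    val startingAlpha = 0.025
    // With parallelism = 4 and 1M training words, the rate reaches its floor
    // once a single partition has processed about a quarter of the corpus.
    for (wc <- Seq(0L, 50000L, 150000L, 250000L)) {
      println(s"wordCount = $wc, alpha = ${nextAlpha(startingAlpha, 4, wc, 1000000L)}")
    }
  }
}
```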
@@ -398,7 +406,9 @@ object Word2Vec{
       size: Int,
       startingAlpha: Double,
       window: Int,
-      minCount: Int): Word2VecModel = {
-    new Word2Vec(size, startingAlpha, window, minCount).fit[S](input)
+      minCount: Int,
+      parallelism: Int = 1,
+      numIterations: Int = 1): Word2VecModel = {
+    new Word2Vec(size, startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
   }
 }
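For completeness, here is a hedged usage sketch of the class with the two new parameters. The `SparkContext` setup, the toy corpus, and the assumption that `fit` accepts an `RDD` of tokenized sentences (each an `Iterable[String]`, as the `fit[S](input)` call above suggests) are illustrative only and not part of this patch.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object Word2VecUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[4]", "Word2VecUsageSketch")
    // Toy corpus; real training needs far more text for minCount and the
    // learned vectors to be meaningful.
    val sentences: RDD[Seq[String]] = sc.parallelize(Seq(
      Seq("spark", "is", "a", "fast", "cluster", "computing", "engine"),
      Seq("word2vec", "learns", "dense", "vector", "representations", "of", "words")))
    // Constructor order follows the class definition: size, startingAlpha,
    // window, minCount, then the parallelism and numIterations added in this patch.
    val model = new Word2Vec(100, 0.025, 5, 1, 4, 3).fit(sentences)
    // `model` can now be queried for the learned word vectors.
    sc.stop()
  }
}
```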