Commit 6bcc8be

Author: Liquan Pei
Commit message: add multiple iteration support
Parent: 720b5a3

1 file changed, 70 additions, 60 deletions:
mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -70,7 +70,8 @@ class Word2Vec(
     val startingAlpha: Double,
     val window: Int,
     val minCount: Int,
-    val parallelism:Int = 1)
+    val parallelism:Int = 1,
+    val numIterations:Int = 1)
   extends Serializable with Logging {

   private val EXP_TABLE_SIZE = 1000
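
The new constructor parameter defaults to a single pass, so existing callers are unaffected. A minimal sketch of how a caller might request several passes (the training RDD `input` and its preparation are assumed, not part of this hunk; all values are illustrative):

    // Hypothetical caller: size = 100, startingAlpha = 0.025, window = 5,
    // minCount = 5, plus the parallelism and new numIterations parameters.
    val w2v = new Word2Vec(100, 0.025, 5, 5, parallelism = 4, numIterations = 3)
    val model = w2v.fit(input)  // input: RDD[Seq[String]], assumed prepared upstream
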
@@ -241,73 +242,80 @@ class Word2Vec(
     }

     val newSentences = sentences.repartition(parallelism).cache()
-    val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    val (aggSyn0, _, _, _) =
-      // TODO: broadcast temp instead of serializing it directly
-      // or initialize the model in each executor
-      newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
-        seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
-          var lwc = lastWordCount
-          var wc = wordCount
-          if (wordCount - lastWordCount > 10000) {
-            lwc = wordCount
-            alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
-            if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
-            logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
-          }
-          wc += sentence.size
-          var pos = 0
-          while (pos < sentence.size) {
-            val word = sentence(pos)
-            // TODO: fix random seed
-            val b = Random.nextInt(window)
-            // Train Skip-gram
-            var a = b
-            while (a < window * 2 + 1 - b) {
-              if (a != window) {
-                val c = pos - window + a
-                if (c >= 0 && c < sentence.size) {
-                  val lastWord = sentence(c)
-                  val l1 = lastWord * layer1Size
-                  val neu1e = new Array[Double](layer1Size)
-                  // Hierarchical softmax
-                  var d = 0
-                  while (d < vocab(word).codeLen) {
-                    val l2 = vocab(word).point(d) * layer1Size
-                    // Propagate hidden -> output
-                    var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
-                    if (f > -MAX_EXP && f < MAX_EXP) {
-                      val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
-                      f = expTable.value(ind)
-                      val g = (1 - vocab(word).code(d) - f) * alpha
-                      blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                      blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
-                    }
-                    d += 1
-                  }
-                  blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
-                }
-              }
-              a += 1
-            }
-            pos += 1
-          }
-          (syn0, syn1, lwc, wc)
-        },
-        combOp = (c1, c2) => (c1, c2) match {
-          case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-            val n = syn0_1.length
-            blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
-            blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
-            (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-        })
-
+    var syn0Global
+      = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
+    var syn1Global = new Array[Double](vocabSize * layer1Size)
+
+    for(iter <- 1 to numIterations) {
+      val (aggSyn0, aggSyn1, _, _) =
+        // TODO: broadcast temp instead of serializing it directly
+        // or initialize the model in each executor
+        newSentences.aggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+          seqOp = (c, v) => (c, v) match {
+            case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
+              var lwc = lastWordCount
+              var wc = wordCount
+              if (wordCount - lastWordCount > 10000) {
+                lwc = wordCount
+                alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+                if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+                logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
+              }
+              wc += sentence.size
+              var pos = 0
+              while (pos < sentence.size) {
+                val word = sentence(pos)
+                // TODO: fix random seed
+                val b = Random.nextInt(window)
+                // Train Skip-gram
+                var a = b
+                while (a < window * 2 + 1 - b) {
+                  if (a != window) {
+                    val c = pos - window + a
+                    if (c >= 0 && c < sentence.size) {
+                      val lastWord = sentence(c)
+                      val l1 = lastWord * layer1Size
+                      val neu1e = new Array[Double](layer1Size)
+                      // Hierarchical softmax
+                      var d = 0
+                      while (d < vocab(word).codeLen) {
+                        val l2 = vocab(word).point(d) * layer1Size
+                        // Propagate hidden -> output
+                        var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                        if (f > -MAX_EXP && f < MAX_EXP) {
+                          val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
+                          f = expTable.value(ind)
+                          val g = (1 - vocab(word).code(d) - f) * alpha
+                          blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                          blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                        }
+                        d += 1
+                      }
+                      blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
+                    }
+                  }
+                  a += 1
+                }
+                pos += 1
+              }
+              (syn0, syn1, lwc, wc)
+          },
+          combOp = (c1, c2) => (c1, c2) match {
+            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+              val n = syn0_1.length
+              blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
+              blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
+              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+          })
+      syn0Global = aggSyn0
+      syn1Global = aggSyn1
+    }
     val wordMap = new Array[(String, Array[Double])](vocabSize)
     var i = 0
     while (i < vocabSize) {
       val word = vocab(i).word
       val vector = new Array[Double](layer1Size)
-      Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size)
+      Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
       wordMap(i) = (word, vector)
       i += 1
     }
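
The loop introduced above is a driver-side iterate-and-aggregate pattern: each pass folds sentences into partition-local copies of the weights (seqOp), sums those copies element-wise (combOp), and the aggregated result seeds the next pass through syn0Global and syn1Global. A stripped-down sketch of that pattern, with a toy element-wise update standing in for the skip-gram step (the function name and update rule are illustrative, not from the commit):

    import org.apache.spark.rdd.RDD

    // Minimal sketch: iterate an RDD aggregation, feeding each result back in.
    def iterateAggregate(data: RDD[Array[Double]], init: Array[Double], iters: Int): Array[Double] = {
      var global = init
      for (iter <- 1 to iters) {
        global = data.aggregate(global.clone())(
          // seqOp: fold one record into the partition-local copy of the weights
          seqOp = (weights, record) => {
            var i = 0
            while (i < weights.length) { weights(i) += record(i); i += 1 }
            weights
          },
          // combOp: merge two partition-local copies element-wise
          combOp = (w1, w2) => {
            var i = 0
            while (i < w1.length) { w1(i) += w2(i); i += 1 }
            w1
          })
      }
      global
    }
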
@@ -398,7 +406,9 @@ object Word2Vec{
       size: Int,
       startingAlpha: Double,
       window: Int,
-      minCount: Int): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount).fit[S](input)
+      minCount: Int,
+      parallelism: Int = 1,
+      numIterations:Int = 1): Word2VecModel = {
+    new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
   }
 }
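
With the widened train helper, multi-iteration training can be driven end to end. A hedged usage sketch (corpus loading, tokenization, and all values are illustrative; only the train signature comes from this commit):

    import org.apache.spark.SparkContext

    val sc = new SparkContext("local[4]", "word2vec-demo")  // hypothetical local setup
    // One tokenized sentence per input line; whitespace splitting is a stand-in.
    val sentences = sc.textFile("corpus.txt").map(line => line.split(" ").toSeq)
    val model = Word2Vec.train(sentences, size = 100, startingAlpha = 0.025,
      window = 5, minCount = 5, parallelism = 4, numIterations = 3)
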
