Skip to content

Commit cc491f6

Browse files
committed
[SPARK-2864][MLLIB] fix random seed in word2vec; move model to local
It also moves the model to local in order to map `RDD[String]` to `RDD[Vector]`.

Ishiihara

Author: Xiangrui Meng <[email protected]>

Closes #1790 from mengxr/word2vec-fix and squashes the following commits:

a87146c [Xiangrui Meng] add setters and make a default constructor
e5c923b [Xiangrui Meng] fix random seed in word2vec; move model to local
1 parent 41e0a21 commit cc491f6

File tree

2 files changed

+106
-97
lines changed

2 files changed

+106
-97
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 102 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@ package org.apache.spark.mllib.feature
1919

2020
import scala.collection.mutable
2121
import scala.collection.mutable.ArrayBuffer
22-
import scala.util.Random
2322

2423
import com.github.fommil.netlib.BLAS.{getInstance => blas}
25-
import org.apache.spark.{HashPartitioner, Logging}
24+
25+
import org.apache.spark.Logging
2626
import org.apache.spark.SparkContext._
2727
import org.apache.spark.annotation.Experimental
2828
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2929
import org.apache.spark.mllib.rdd.RDDFunctions._
3030
import org.apache.spark.rdd._
31-
import org.apache.spark.storage.StorageLevel
31+
import org.apache.spark.util.Utils
32+
import org.apache.spark.util.random.XORShiftRandom
3233

3334
/**
3435
* Entry in vocabulary
@@ -58,29 +59,63 @@ private case class VocabWord(
5859
* Efficient Estimation of Word Representations in Vector Space
5960
* and
6061
* Distributed Representations of Words and Phrases and their Compositionality.
61-
* @param size vector dimension
62-
* @param startingAlpha initial learning rate
63-
* @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
64-
* @param numIterations number of iterations to run, should be smaller than or equal to parallelism
6562
*/
6663
@Experimental
67-
class Word2Vec(
68-
val size: Int,
69-
val startingAlpha: Double,
70-
val parallelism: Int,
71-
val numIterations: Int) extends Serializable with Logging {
64+
class Word2Vec extends Serializable with Logging {
65+
66+
private var vectorSize = 100
67+
private var startingAlpha = 0.025
68+
private var numPartitions = 1
69+
private var numIterations = 1
70+
private var seed = Utils.random.nextLong()
71+
72+
/**
73+
* Sets vector size (default: 100).
74+
*/
75+
def setVectorSize(vectorSize: Int): this.type = {
76+
this.vectorSize = vectorSize
77+
this
78+
}
79+
80+
/**
81+
* Sets initial learning rate (default: 0.025).
82+
*/
83+
def setLearningRate(learningRate: Double): this.type = {
84+
this.startingAlpha = learningRate
85+
this
86+
}
7287

7388
/**
74-
* Word2Vec with a single thread.
89+
* Sets number of partitions (default: 1). Use a small number for accuracy.
7590
*/
76-
def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
91+
def setNumPartitions(numPartitions: Int): this.type = {
92+
require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
93+
this.numPartitions = numPartitions
94+
this
95+
}
96+
97+
/**
98+
* Sets number of iterations (default: 1), which should be smaller than or equal to number of
99+
* partitions.
100+
*/
101+
def setNumIterations(numIterations: Int): this.type = {
102+
this.numIterations = numIterations
103+
this
104+
}
105+
106+
/**
107+
* Sets random seed (default: a random long integer).
108+
*/
109+
def setSeed(seed: Long): this.type = {
110+
this.seed = seed
111+
this
112+
}
77113

78114
private val EXP_TABLE_SIZE = 1000
79115
private val MAX_EXP = 6
80116
private val MAX_CODE_LENGTH = 40
81117
private val MAX_SENTENCE_LENGTH = 1000
82-
private val layer1Size = size
83-
private val modelPartitionNum = 100
118+
private val layer1Size = vectorSize
84119

85120
/** context words from [-window, window] */
86121
private val window = 5
@@ -94,12 +129,12 @@ class Word2Vec(
94129
private var vocabHash = mutable.HashMap.empty[String, Int]
95130
private var alpha = startingAlpha
96131

97-
private def learnVocab(words:RDD[String]): Unit = {
132+
private def learnVocab(words: RDD[String]): Unit = {
98133
vocab = words.map(w => (w, 1))
99134
.reduceByKey(_ + _)
100135
.map(x => VocabWord(
101-
x._1,
102-
x._2,
136+
x._1,
137+
x._2,
103138
new Array[Int](MAX_CODE_LENGTH),
104139
new Array[Int](MAX_CODE_LENGTH),
105140
0))
@@ -245,32 +280,32 @@ class Word2Vec(
245280
}
246281
}
247282

248-
val newSentences = sentences.repartition(parallelism).cache()
283+
val newSentences = sentences.repartition(numPartitions).cache()
284+
val initRandom = new XORShiftRandom(seed)
249285
var syn0Global =
250-
Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
286+
Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
251287
var syn1Global = new Array[Float](vocabSize * layer1Size)
252-
253-
for(iter <- 1 to numIterations) {
254-
val (aggSyn0, aggSyn1, _, _) =
255-
// TODO: broadcast temp instead of serializing it directly
256-
// or initialize the model in each executor
257-
newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
258-
seqOp = (c, v) => (c, v) match {
288+
289+
for (k <- 1 to numIterations) {
290+
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
291+
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
292+
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
259293
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
260294
var lwc = lastWordCount
261-
var wc = wordCount
295+
var wc = wordCount
262296
if (wordCount - lastWordCount > 10000) {
263297
lwc = wordCount
264-
alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
298+
// TODO: discount by iteration?
299+
alpha =
300+
startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
265301
if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
266302
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
267303
}
268304
wc += sentence.size
269305
var pos = 0
270306
while (pos < sentence.size) {
271307
val word = sentence(pos)
272-
// TODO: fix random seed
273-
val b = Random.nextInt(window)
308+
val b = random.nextInt(window)
274309
// Train Skip-gram
275310
var a = b
276311
while (a < window * 2 + 1 - b) {
@@ -280,7 +315,7 @@ class Word2Vec(
280315
val lastWord = sentence(c)
281316
val l1 = lastWord * layer1Size
282317
val neu1e = new Array[Float](layer1Size)
283-
// Hierarchical softmax
318+
// Hierarchical softmax
284319
var d = 0
285320
while (d < bcVocab.value(word).codeLen) {
286321
val l2 = bcVocab.value(word).point(d) * layer1Size
@@ -303,44 +338,44 @@ class Word2Vec(
303338
pos += 1
304339
}
305340
(syn0, syn1, lwc, wc)
306-
},
307-
combOp = (c1, c2) => (c1, c2) match {
308-
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
309-
val n = syn0_1.length
310-
val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
311-
val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
312-
blas.sscal(n, weight1, syn0_1, 1)
313-
blas.sscal(n, weight1, syn1_1, 1)
314-
blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
315-
blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
316-
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
317-
})
341+
}
342+
Iterator(model)
343+
}
344+
val (aggSyn0, aggSyn1, _, _) =
345+
partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
346+
val n = syn0_1.length
347+
val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
348+
val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
349+
blas.sscal(n, weight1, syn0_1, 1)
350+
blas.sscal(n, weight1, syn1_1, 1)
351+
blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
352+
blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
353+
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
354+
}
318355
syn0Global = aggSyn0
319356
syn1Global = aggSyn1
320357
}
321358
newSentences.unpersist()
322359

323-
val wordMap = new Array[(String, Array[Float])](vocabSize)
360+
val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
324361
var i = 0
325362
while (i < vocabSize) {
326363
val word = bcVocab.value(i).word
327364
val vector = new Array[Float](layer1Size)
328365
Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
329-
wordMap(i) = (word, vector)
366+
word2VecMap += word -> vector
330367
i += 1
331368
}
332-
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
333-
.partitionBy(new HashPartitioner(modelPartitionNum))
334-
.persist(StorageLevel.MEMORY_AND_DISK)
335-
336-
new Word2VecModel(modelRDD)
369+
370+
new Word2VecModel(word2VecMap.toMap)
337371
}
338372
}
339373

340374
/**
341375
* Word2Vec model
342-
*/
343-
class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
376+
*/
377+
class Word2VecModel private[mllib] (
378+
private val model: Map[String, Array[Float]]) extends Serializable {
344379

345380
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
346381
require(v1.length == v2.length, "Vectors should have the same length")
@@ -357,11 +392,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
357392
* @return vector representation of word
358393
*/
359394
def transform(word: String): Vector = {
360-
val result = model.lookup(word)
361-
if (result.isEmpty) {
362-
throw new IllegalStateException(s"$word not in vocabulary")
395+
model.get(word) match {
396+
case Some(vec) =>
397+
Vectors.dense(vec.map(_.toDouble))
398+
case None =>
399+
throw new IllegalStateException(s"$word not in vocabulary")
363400
}
364-
else Vectors.dense(result(0).map(_.toDouble))
365401
}
366402

367403
/**
@@ -392,33 +428,13 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
392428
*/
393429
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
394430
require(num > 0, "Number of similar words should > 0")
395-
val topK = model.map { case(w, vec) =>
396-
(cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
397-
.sortByKey(ascending = false)
398-
.take(num + 1)
399-
.map(_.swap)
400-
.tail
401-
402-
topK
403-
}
404-
}
405-
406-
object Word2Vec{
407-
/**
408-
* Train Word2Vec model
409-
* @param input RDD of words
410-
* @param size vector dimension
411-
* @param startingAlpha initial learning rate
412-
* @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
413-
* @param numIterations number of iterations, should be smaller than or equal to parallelism
414-
* @return Word2Vec model
415-
*/
416-
def train[S <: Iterable[String]](
417-
input: RDD[S],
418-
size: Int,
419-
startingAlpha: Double,
420-
parallelism: Int = 1,
421-
numIterations:Int = 1): Word2VecModel = {
422-
new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
431+
// TODO: optimize top-k
432+
val fVector = vector.toArray.map(_.toFloat)
433+
model.mapValues(vec => cosineSimilarity(fVector, vec))
434+
.toSeq
435+
.sortBy(- _._2)
436+
.take(num + 1)
437+
.tail
438+
.toArray
423439
}
424440
}

mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,29 +30,22 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
3030
val localDoc = Seq(sentence, sentence)
3131
val doc = sc.parallelize(localDoc)
3232
.map(line => line.split(" ").toSeq)
33-
val size = 10
34-
val startingAlpha = 0.025
35-
val window = 2
36-
val minCount = 2
37-
val num = 2
38-
39-
val model = Word2Vec.train(doc, size, startingAlpha)
33+
val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
4034
val syms = model.findSynonyms("a", 2)
41-
assert(syms.length == num)
35+
assert(syms.length == 2)
4236
assert(syms(0)._1 == "b")
4337
assert(syms(1)._1 == "c")
4438
}
4539

46-
4740
test("Word2VecModel") {
4841
val num = 2
49-
val localModel = Seq(
42+
val word2VecMap = Map(
5043
("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
5144
("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
5245
("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
5346
("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
5447
)
55-
val model = new Word2VecModel(sc.parallelize(localModel, 2))
48+
val model = new Word2VecModel(word2VecMap)
5649
val syms = model.findSynonyms("china", num)
5750
assert(syms.length == num)
5851
assert(syms(0)._1 == "taiwan")

0 commit comments

Comments
 (0)