@@ -1,33 +1,33 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * Add a comment to this line
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

 package org.apache.spark.mllib.feature

-import scala.util.{Random => Random}
+import scala.util.Random
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable

 import com.github.fommil.netlib.BLAS.{getInstance => blas}

-import org.apache.spark._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.Logging
 import org.apache.spark.rdd._
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.HashPartitioner

 /**
@@ -42,8 +42,27 @@ private case class VocabWord(
 )

 /**
- * Vector representation of word
+ * :: Experimental ::
+ * Word2Vec creates vector representation of words in a text corpus.
+ * The algorithm first constructs a vocabulary from the corpus
+ * and then learns vector representation of words in the vocabulary.
+ * The vector representation can be used as features in
+ * natural language processing and machine learning algorithms.
+ *
+ * We used skip-gram model in our implementation and hierarchical softmax
+ * method to train the model.
+ *
+ * For original C implementation, see https://code.google.com/p/word2vec/
+ * For research papers, see
+ * Efficient Estimation of Word Representations in Vector Space
+ * and
+ * Distributed Representations of Words and Phrases and their Compositionality
+ * @param size vector dimension
+ * @param startingAlpha initial learning rate
+ * @param window context words from [-window, window]
+ * @param minCount minimum frequency to consider a vocabulary word
  */
+@Experimental
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
@@ -64,11 +83,15 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha

-  private def learnVocab(dataset: RDD[String]) {
-    vocab = dataset.flatMap(line => line.split(" "))
-      .map(w => (w, 1))
+  private def learnVocab(words: RDD[String]) {
+    vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
-      .map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0))
+      .map(x => VocabWord(
+        x._1,
+        x._2,
+        new Array[Int](MAX_CODE_LENGTH),
+        new Array[Int](MAX_CODE_LENGTH),
+        0))
       .filter(_.cn >= minCount)
       .collect()
       .sortWith((a, b) => a.cn > b.cn)
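
Review note: the learnVocab pipeline above is the classic word-count pattern: count tokens, drop those below minCount, and sort by descending frequency. For intuition, a local-collection sketch of the same computation (illustrative only; plain Scala collections instead of RDDs):

    // Count words, keep those with count >= minCount, sort by descending count.
    val words = Seq("to", "be", "or", "not", "to", "be")
    val minCount = 2
    val counts = words.groupBy(identity)
      .map { case (w, ws) => (w, ws.size) }
      .filter { case (_, cn) => cn >= minCount }
      .toSeq
      .sortBy { case (_, cn) => -cn }
    // counts == Seq(("to", 2), ("be", 2)) -- ties may appear in either order
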
@@ -172,15 +195,16 @@ class Word2Vec(
   }

   /**
-   * Computes the vector representation of each word in
-   * vocabulary
-   * @param dataset an RDD of strings
+   * Computes the vector representation of each word in vocabulary.
+   * @param dataset an RDD of sentences, where each sentence is an iterable of words
    * @return a Word2VecModel
    */

-  def fit(dataset: RDD[String]): Word2VecModel = {
+  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

-    learnVocab(dataset)
+    val words = dataset.flatMap(x => x)
+
+    learnVocab(words)

     createBinaryTree()

@@ -190,9 +214,10 @@ class Word2Vec(
     val V = sc.broadcast(vocab)
     val VHash = sc.broadcast(vocabHash)

-    val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions {
+    val sentences = words.mapPartitions {
       iter => { new Iterator[Array[Int]] {
         def hasNext = iter.hasNext
+
         def next = {
           var sentence = new ArrayBuffer[Int]
           var sentenceLength = 0
@@ -215,7 +240,8 @@ class Word2Vec(
     val newSentences = sentences.repartition(1).cache()
     val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
     val (aggSyn0, _, _, _) =
-      // TODO: broadcast temp instead of serializing it directly or initialize the model in each executor
+      // TODO: broadcast temp instead of serializing it directly
+      // or initialize the model in each executor
       newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
         seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
           var lwc = lastWordCount
@@ -241,7 +267,7 @@ class Word2Vec(
             val lastWord = sentence(c)
             val l1 = lastWord * layer1Size
             val neu1e = new Array[Double](layer1Size)
-            // HS
+            // Hierarchical softmax
             var d = 0
             while (d < vocab(word).codeLen) {
               val l2 = vocab(word).point(d) * layer1Size
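
Review note: the body of this while loop is elided from the diff; it implements the standard word2vec hierarchical-softmax update. A minimal, self-contained sketch of that step for reference -- this is not the PR's code, and the helper below is hypothetical:

    // Hypothetical sketch of one hierarchical-softmax update for a
    // (word, contextWord) pair. code/point are the word's Huffman code bits
    // and inner-node indices; syn0 holds word vectors and syn1 inner-node
    // vectors, both flattened with stride layer1Size.
    def hsUpdate(
        syn0: Array[Double], syn1: Array[Double], layer1Size: Int,
        contextWord: Int, code: Array[Int], point: Array[Int],
        codeLen: Int, alpha: Double): Unit = {
      val l1 = contextWord * layer1Size
      val neu1e = new Array[Double](layer1Size)
      var d = 0
      while (d < codeLen) {
        val l2 = point(d) * layer1Size
        // f = sigmoid(dot(syn0[l1..], syn1[l2..]))
        var f = 0.0
        var i = 0
        while (i < layer1Size) { f += syn0(l1 + i) * syn1(l2 + i); i += 1 }
        f = 1.0 / (1.0 + math.exp(-f))
        // Gradient scale; the target label for inner node d is (1 - code(d)).
        val g = (1.0 - code(d) - f) * alpha
        i = 0
        while (i < layer1Size) {
          neu1e(i) += g * syn1(l2 + i)      // accumulate error for the word vector
          syn1(l2 + i) += g * syn0(l1 + i)  // update the inner-node vector
          i += 1
        }
        d += 1
      }
      // Apply the accumulated error to the context word's vector.
      var i = 0
      while (i < layer1Size) { syn0(l1 + i) += neu1e(i); i += 1 }
    }
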
@@ -265,11 +291,12 @@ class Word2Vec(
           }
           (syn0, syn1, lwc, wc)
         },
-        combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-          val n = syn0_1.length
-          blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
-          blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
-          (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
+        combOp = (c1, c2) => (c1, c2) match {
+          case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+            val n = syn0_1.length
+            blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
+            blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
+            (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
       })

     val wordMap = new Array[(String, Array[Double])](vocabSize)
@@ -281,19 +308,18 @@ class Word2Vec(
       wordMap(i) = (word, vector)
       i += 1
     }
-    val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
+    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
+      .partitionBy(new HashPartitioner(modelPartitionNum))
     new Word2VecModel(modelRDD)
   }
 }

 /**
  * Word2Vec model
  */
-class Word2VecModel (val _model: RDD[(String, Array[Double])]) extends Serializable {
-
-  val model = _model
+class Word2VecModel(private val model: RDD[(String, Array[Double])]) extends Serializable {

-  private def distance(v1: Array[Double], v2: Array[Double]): Double = {
+  private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
     val n = v1.length
     val norm1 = blas.dnrm2(n, v1, 1)
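
Review note: the rest of cosineSimilarity is elided from the diff; presumably it returns the normalized dot product. A sketch of the full method under that assumption, reusing the blas alias imported at the top of the file (not the PR's exact body):

    private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
      require(v1.length == v2.length, "Vectors should have the same length")
      val n = v1.length
      val norm1 = blas.dnrm2(n, v1, 1)              // ||v1||
      val norm2 = blas.dnrm2(n, v2, 1)              // ||v2||
      if (norm1 == 0.0 || norm2 == 0.0) return 0.0  // guard against zero vectors
      blas.ddot(n, v1, 1, v2, 1) / (norm1 * norm2)  // dot(v1, v2) / (||v1|| * ||v2||)
    }
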
@@ -307,20 +333,20 @@ class Word2VecModel (val _model: RDD[(String, Array[Double])]) extends Serializable {
    * @param word a word
    * @return vector representation of word
    */
-
-  def transform(word: String): Array[Double] = {
+  def transform(word: String): Vector = {
     val result = model.lookup(word)
-    if (result.isEmpty) Array[Double]()
-    else result(0)
+    if (result.isEmpty) {
+      throw new IllegalStateException(s"${word} not in vocabulary")
+    }
+    else Vectors.dense(result(0))
   }

   /**
    * Transforms an RDD to its vector representation
-   * @param dataset a an RDD of words
+   * @param dataset an RDD of words
    * @return RDD of vector representation
    */
-
-  def transform(dataset: RDD[String]): RDD[Array[Double]] = {
+  def transform(dataset: RDD[String]): RDD[Vector] = {
     dataset.map(word => transform(word))
   }

@@ -332,44 +358,44 @@ class Word2VecModel (val _model: RDD[(String, Array[Double])]) extends Serializable {
    */
   def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
     val vector = transform(word)
-    if (vector.isEmpty) Array[(String, Double)]()
-    else findSynonyms(vector, num)
+    findSynonyms(vector, num)
   }

   /**
    * Find synonyms of the vector representation of a word
    * @param vector vector representation of a word
    * @param num number of synonyms to find
-   * @return array of (word, similarity)
+   * @return array of (word, cosineSimilarity)
    */
-  def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
+  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should be > 0")
-    val topK = model.map(
-      { case (w, vec) => (distance(vector, vec), w) })
+    val topK = model.map { case (w, vec) =>
+      (cosineSimilarity(vector.toArray, vec), w) }
       .sortByKey(ascending = false)
       .take(num + 1)
-      .map({ case (dist, w) => (w, dist) }).drop(1)
+      .map(_.swap)
+      .tail

     topK
   }
 }

-object Word2Vec extends Serializable with Logging {
+object Word2Vec {
   /**
    * Train Word2Vec model
    * @param input RDD of words
-   * @param size vectoer dimension
+   * @param size vector dimension
    * @param startingAlpha initial learning rate
    * @param window context words from [-window, window]
    * @param minCount minimum frequency to consider a vocabulary word
    * @return Word2Vec model
    */
-  def train(
-    input: RDD[String],
+  def train[S <: Iterable[String]](
+    input: RDD[S],
     size: Int,
     startingAlpha: Double,
     window: Int,
     minCount: Int): Word2VecModel = {
-    new Word2Vec(size, startingAlpha, window, minCount).fit(input)
+    new Word2Vec(size, startingAlpha, window, minCount).fit[S](input)
   }
 }
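
Review note: for reviewers trying the new API, a hypothetical end-to-end usage sketch of the signatures introduced above (the path and parameter values are illustrative, not from the PR):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.feature.Word2Vec

    object Word2VecExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("Word2VecExample"))
        // Each line becomes one sentence, i.e. an Iterable[String] of tokens,
        // matching the new fit[S <: Iterable[String]](dataset: RDD[S]) signature.
        val sentences = sc.textFile("path/to/corpus.txt").map(_.split(" ").toSeq)
        val model = Word2Vec.train(sentences, size = 100, startingAlpha = 0.025,
          window = 5, minCount = 5)
        // transform now returns an mllib.linalg.Vector and throws
        // IllegalStateException for out-of-vocabulary words (see diff above).
        val vec = model.transform("spark")
        // findSynonyms returns (word, cosineSimilarity) pairs, query word excluded.
        model.findSynonyms("spark", 10).foreach { case (w, sim) =>
          println(s"$w: $sim")
        }
      }
    }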