188 changes: 102 additions & 86 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -19,16 +19,17 @@ package org.apache.spark.mllib.feature
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
-import org.apache.spark.{HashPartitioner, Logging}
+
+import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * Entry in vocabulary
@@ -58,29 +59,63 @@ private case class VocabWord(
  * Efficient Estimation of Word Representations in Vector Space
  * and
  * Distributed Representations of Words and Phrases and their Compositionality.
- * @param size vector dimension
- * @param startingAlpha initial learning rate
- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
- * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
  */
 @Experimental
-class Word2Vec(
-    val size: Int,
-    val startingAlpha: Double,
-    val parallelism: Int,
-    val numIterations: Int) extends Serializable with Logging {
+class Word2Vec extends Serializable with Logging {
 
+  private var vectorSize = 100
+  private var startingAlpha = 0.025
+  private var numPartitions = 1
+  private var numIterations = 1
+  private var seed = Utils.random.nextLong()
+
+  /**
+   * Sets vector size (default: 100).
+   */
+  def setVectorSize(vectorSize: Int): this.type = {
+    this.vectorSize = vectorSize
+    this
+  }
+
+  /**
+   * Sets initial learning rate (default: 0.025).
+   */
+  def setLearningRate(learningRate: Double): this.type = {
+    this.startingAlpha = learningRate
+    this
+  }
+
   /**
-   * Word2Vec with a single thread.
+   * Sets number of partitions (default: 1). Use a small number for accuracy.
    */
-  def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
+  def setNumPartitions(numPartitions: Int): this.type = {
+    require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
+    this.numPartitions = numPartitions
+    this
+  }
 
+  /**
+   * Sets number of iterations (default: 1), which should be smaller than or equal to number of
+   * partitions.
+   */
+  def setNumIterations(numIterations: Int): this.type = {
+    this.numIterations = numIterations
+    this
+  }
+
+  /**
+   * Sets random seed (default: a random long integer).
+   */
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
+
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
-  private val layer1Size = size
-  private val modelPartitionNum = 100
+  private val layer1Size = vectorSize
 
   /** context words from [-window, window] */
   private val window = 5
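
Note: the four constructor parameters above become chained setters. A minimal usage sketch of the new interface; the SparkContext, input file, query word, and parameter values are illustrative, not part of this change:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.Word2Vec
import org.apache.spark.rdd.RDD

// Illustrative only: assumes an existing SparkContext and a whitespace-tokenized corpus.
def example(sc: SparkContext): Unit = {
  val corpus: RDD[Seq[String]] = sc.textFile("corpus.txt").map(_.split(" ").toSeq)
  val model = new Word2Vec()
    .setVectorSize(200)     // default 100
    .setLearningRate(0.025)
    .setNumPartitions(4)    // keep small for accuracy
    .setNumIterations(2)    // should not exceed numPartitions
    .setSeed(42L)           // reproducible runs
    .fit(corpus)
  model.findSynonyms("day", 5).foreach(println)
}
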
@@ -94,12 +129,12 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words:RDD[String]): Unit = {
+  private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
-        x._1, 
-        x._2, 
+        x._1,
+        x._2,
         new Array[Int](MAX_CODE_LENGTH),
         new Array[Int](MAX_CODE_LENGTH),
         0))
@@ -245,32 +280,32 @@ class Word2Vec(
       }
     }
 
-    val newSentences = sentences.repartition(parallelism).cache()
+    val newSentences = sentences.repartition(numPartitions).cache()
+    val initRandom = new XORShiftRandom(seed)
     var syn0Global =
-      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+      Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
     var syn1Global = new Array[Float](vocabSize * layer1Size)
 
-    for(iter <- 1 to numIterations) {
-      val (aggSyn0, aggSyn1, _, _) =
-        // TODO: broadcast temp instead of serializing it directly
-        // or initialize the model in each executor
-        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
-          seqOp = (c, v) => (c, v) match {
+    for (k <- 1 to numIterations) {
+      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
+        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
            var lwc = lastWordCount
-           var wc = wordCount 
+           var wc = wordCount
            if (wordCount - lastWordCount > 10000) {
              lwc = wordCount
-             alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+             // TODO: discount by iteration?
+             alpha =
+               startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
             if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
             logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
           }
           wc += sentence.size
           var pos = 0
           while (pos < sentence.size) {
             val word = sentence(pos)
-            // TODO: fix random seed
-            val b = Random.nextInt(window)
+            val b = random.nextInt(window)
             // Train Skip-gram
             var a = b
             while (a < window * 2 + 1 - b) {
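
Note: the removed "TODO: fix random seed" is resolved by deriving one deterministic XORShiftRandom stream per (iteration, partition) pair from the user-settable seed. A small sketch of that derivation; java.util.Random stands in here because XORShiftRandom is Spark-internal:

import java.util.Random

// Mixing iteration k and partition index idx into the seed gives every task its
// own random stream while keeping the whole training run reproducible.
val seed = 42L
for (k <- 1 to 2; idx <- 0 until 4) {
  val random = new Random(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
  println(s"iteration=$k partition=$idx draw=${random.nextInt(5)}")
}
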
Expand All @@ -280,7 +315,7 @@ class Word2Vec(
val lastWord = sentence(c)
val l1 = lastWord * layer1Size
val neu1e = new Array[Float](layer1Size)
// Hierarchical softmax
// Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
val l2 = bcVocab.value(word).point(d) * layer1Size
@@ -303,44 +338,44 @@ class Word2Vec(
               pos += 1
             }
             (syn0, syn1, lwc, wc)
-          },
-        combOp = (c1, c2) => (c1, c2) match {
-          case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-            val n = syn0_1.length
-            val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-            val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-            blas.sscal(n, weight1, syn0_1, 1)
-            blas.sscal(n, weight1, syn1_1, 1)
-            blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-            blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-            (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-          })
+        }
+        Iterator(model)
+      }
+      val (aggSyn0, aggSyn1, _, _) =
+        partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+          val n = syn0_1.length
+          val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+          val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+          blas.sscal(n, weight1, syn0_1, 1)
+          blas.sscal(n, weight1, syn1_1, 1)
+          blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+          blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+        }
       syn0Global = aggSyn0
       syn1Global = aggSyn1
     }
     newSentences.unpersist()
 
-    val wordMap = new Array[(String, Array[Float])](vocabSize)
+    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
       val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
-      wordMap(i) = (word, vector)
+      word2VecMap += word -> vector
       i += 1
     }
-    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
-      .partitionBy(new HashPartitioner(modelPartitionNum))
-      .persist(StorageLevel.MEMORY_AND_DISK)
-
-    new Word2VecModel(modelRDD)
+    new Word2VecModel(word2VecMap.toMap)
   }
 }
 
 /**
  * Word2Vec model
-*/
-class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
+ */
+class Word2VecModel private[mllib] (
+    private val model: Map[String, Array[Float]]) extends Serializable {
 
   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
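
Note: each partition now folds over its sentences locally and emits one model copy; treeReduce merges copies pairwise by a word-count-weighted average, which is what the sscal/saxpy calls compute in place. A standalone sketch of the merge rule in plain Scala, with made-up inputs:

// merged = w1 * a + w2 * b, weighting each copy by the fraction of words it saw.
def merge(a: (Array[Float], Int), b: (Array[Float], Int)): (Array[Float], Int) = {
  val w1 = a._2.toFloat / (a._2 + b._2)
  val w2 = b._2.toFloat / (a._2 + b._2)
  val vec = Array.tabulate(a._1.length)(i => w1 * a._1(i) + w2 * b._1(i))
  (vec, a._2 + b._2)
}

// A copy trained on 300 words dominates one trained on 100:
// merge((Array(1f, 0f), 300), (Array(0f, 1f), 100)) == (Array(0.75f, 0.25f), 400)
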
@@ -357,11 +392,12 @@ class Word2VecModel
    * @return vector representation of word
    */
   def transform(word: String): Vector = {
-    val result = model.lookup(word)
-    if (result.isEmpty) {
-      throw new IllegalStateException(s"$word not in vocabulary")
+    model.get(word) match {
+      case Some(vec) =>
+        Vectors.dense(vec.map(_.toDouble))
+      case None =>
+        throw new IllegalStateException(s"$word not in vocabulary")
     }
-    else Vectors.dense(result(0).map(_.toDouble))
   }
 
   /**
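
Note: with the model held in an in-memory Map, transform becomes a local constant-time lookup, where the old RDD.lookup launched a Spark job per queried word. Hypothetical usage against a fitted model:

import scala.util.Try

// `model` is a Word2VecModel returned by fit (hypothetical).
val vec = model.transform("day")                  // dense vector of length vectorSize
val oov = Try(model.transform("zzzz")).toOption   // None: unknown words throw IllegalStateException
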
@@ -392,33 +428,13 @@ class Word2VecModel
    */
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
-    val topK = model.map { case(w, vec) =>
-      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
-      .sortByKey(ascending = false)
-      .take(num + 1)
-      .map(_.swap)
-      .tail
-
-    topK
-  }
-}
-
-object Word2Vec{
-  /**
-   * Train Word2Vec model
-   * @param input RDD of words
-   * @param size vector dimension
-   * @param startingAlpha initial learning rate
-   * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
-   * @param numIterations number of iterations, should be smaller than or equal to parallelism
-   * @return Word2Vec model
-   */
-  def train[S <: Iterable[String]](
-      input: RDD[S],
-      size: Int,
-      startingAlpha: Double,
-      parallelism: Int = 1,
-      numIterations:Int = 1): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
+    // TODO: optimize top-k
+    val fVector = vector.toArray.map(_.toFloat)
+    model.mapValues(vec => cosineSimilarity(fVector, vec))
+      .toSeq
+      .sortBy(- _._2)
+      .take(num + 1)
+      .tail
+      .toArray
   }
 }
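
Note: findSynonyms now scores the entire vocabulary by cosine similarity on the driver, sorts, and takes num + 1 entries before dropping the head, since a word's nearest neighbor is itself (hence the "TODO: optimize top-k"). A self-contained sketch of the same ranking with toy vectors:

// Toy re-implementation mirroring the mapValues/sortBy pipeline above.
def cosine(v1: Array[Float], v2: Array[Float]): Double = {
  val dot = (v1, v2).zipped.map(_ * _).sum
  dot / (math.sqrt(v1.map(x => x * x).sum) * math.sqrt(v2.map(x => x * x).sum))
}

val vectors = Map(
  "china"  -> Array(0.50f, 0.50f, 0.50f, 0.50f),
  "japan"  -> Array(0.40f, 0.50f, 0.50f, 0.50f),
  "taiwan" -> Array(0.60f, 0.50f, 0.50f, 0.50f))

val query = vectors("china")
val top = vectors.mapValues(v => cosine(query, v))
  .toSeq
  .sortBy(- _._2)
  .take(2 + 1)
  .tail            // drop "china" itself, which ranks first with similarity 1.0
// top == Seq(("taiwan", ~0.997), ("japan", ~0.996))
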
mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -30,29 +30,22 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
     val localDoc = Seq(sentence, sentence)
     val doc = sc.parallelize(localDoc)
       .map(line => line.split(" ").toSeq)
-    val size = 10
-    val startingAlpha = 0.025
-    val window = 2
-    val minCount = 2
-    val num = 2
 
-    val model = Word2Vec.train(doc, size, startingAlpha)
+    val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
     val syms = model.findSynonyms("a", 2)
-    assert(syms.length == num)
+    assert(syms.length == 2)
     assert(syms(0)._1 == "b")
     assert(syms(1)._1 == "c")
   }
 
-
   test("Word2VecModel") {
     val num = 2
-    val localModel = Seq(
+    val word2VecMap = Map(
       ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
       ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
       ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
       ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
     )
-    val model = new Word2VecModel(sc.parallelize(localModel, 2))
+    val model = new Word2VecModel(word2VecMap)
     val syms = model.findSynonyms("china", num)
     assert(syms.length == num)
     assert(syms(0)._1 == "taiwan")