Skip to content

Commit b9a7383

Browse files
committed
cache words RDD in fit
1 parent 89490bf commit b9a7383

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -248,7 +248,7 @@ class Word2Vec extends Serializable with Logging {
248248
*/
249249
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
250250

251-
val words = dataset.flatMap(x => x)
251+
val words = dataset.flatMap(x => x).cache()
252252

253253
learnVocab(words)
254254

@@ -281,7 +281,9 @@ class Word2Vec extends Serializable with Logging {
281281
}
282282
}
283283

284-
val newSentences = sentences.repartition(numPartitions).cache()
284+
val newSentences = sentences.repartition(numPartitions)
285+
words.unpersist()
286+
newSentences.cache()
285287
val initRandom = new XORShiftRandom(seed)
286288
val syn0Global =
287289
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)

0 commit comments

Comments
 (0)