Skip to content

Commit 54d9575

Browse files
jinntrancemengxr
authored andcommitted
[MLLIB] SPARK-4846: throw a RuntimeException and give users hints to increase the minCount
When the vocabSize\*vectorSize is larger than Int.MaxValue/8, we try to throw a RuntimeException. Because under this circumstance it would definitely throw an OOM when allocating memory to serialize the arrays syn0Global&syn1Global. syn0Global&syn1Global are float arrays. Serializing them should need a byte array of more than 8 times of syn0Global's size. Also if we catch an OOM even if vocabSize\*vectorSize is less than Int.MaxValue/8, we should give users hints to increase the minCount or decrease the vectorSize. Author: Joseph J.C. Tang <[email protected]> Closes apache#4247 from jinntrance/w2v-fix and squashes the following commits: b5eb71f [Joseph J.C. Tang] throw a RuntimeException and give users hints regarding the vectorSize&minCount
1 parent 254eaa4 commit 54d9575

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,13 @@ class Word2Vec extends Serializable with Logging {
290290

291291
val newSentences = sentences.repartition(numPartitions).cache()
292292
val initRandom = new XORShiftRandom(seed)
293+
294+
if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) {
295+
throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" +
296+
" to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " +
297+
"which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.")
298+
}
299+
293300
val syn0Global =
294301
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
295302
val syn1Global = new Array[Float](vocabSize * vectorSize)

0 commit comments

Comments
 (0)