Skip to content

Commit 3bc2cbd

Browse files
committed
change foldLeft to for loop and use blas
1 parent 5dd4ee7 commit 3bc2cbd

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.ml.param.shared._
2323
import org.apache.spark.ml.util.SchemaUtils
2424
import org.apache.spark.ml.{Estimator, Model}
2525
import org.apache.spark.mllib.feature
26+
import org.apache.spark.mllib.linalg.BLAS._
2627
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
2728
import org.apache.spark.sql.functions._
2829
import org.apache.spark.sql.types._
@@ -170,15 +171,21 @@ class Word2VecModel private[ml] (
170171
transformSchema(dataset.schema, paramMap, logging = true)
171172
val map = extractParamMap(paramMap)
172173
val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
173-
val word2Vec = udf { v: Seq[String] =>
174-
if (v.size == 0) {
175-
Vectors.zeros(map(vectorSize))
174+
val word2Vec = udf { sentence: Seq[String] =>
175+
if (sentence.size == 0) {
176+
Vectors.sparse(map(vectorSize), Array.empty[Int], Array.empty[Double])
176177
} else {
177-
Vectors.dense(
178-
v.map(bWordVectors.value.getVectors).foldLeft(Array.fill[Double](map(vectorSize))(0)) {
179-
(cum, vec) => cum.zip(vec).map(x => x._1 + x._2)
180-
}.map(_ / v.size)
181-
)
178+
val cum = Vectors.zeros(map(vectorSize))
179+
val model = bWordVectors.value.getVectors
180+
for (word <- sentence) {
181+
if (model.contains(word)) {
182+
axpy(1.0, bWordVectors.value.transform(word), cum)
183+
} else {
184+
// pass words which not belong to model
185+
}
186+
}
187+
scal(1.0 / sentence.size, cum)
188+
cum
182189
}
183190
}
184191
dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol))))

0 commit comments

Comments
 (0)