@@ -23,6 +23,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.BLAS._
 import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -170,15 +171,21 @@ class Word2VecModel private[ml] (
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = extractParamMap(paramMap)
     val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
-    val word2Vec = udf { v: Seq[String] =>
-      if (v.size == 0) {
-        Vectors.zeros(map(vectorSize))
+    val word2Vec = udf { sentence: Seq[String] =>
+      if (sentence.size == 0) {
+        Vectors.sparse(map(vectorSize), Array.empty[Int], Array.empty[Double])
       } else {
-        Vectors.dense(
-          v.map(bWordVectors.value.getVectors).foldLeft(Array.fill[Double](map(vectorSize))(0)) {
-            (cum, vec) => cum.zip(vec).map(x => x._1 + x._2)
-          }.map(_ / v.size)
-        )
+        val cum = Vectors.zeros(map(vectorSize))
+        val model = bWordVectors.value.getVectors
+        for (word <- sentence) {
+          if (model.contains(word)) {
+            axpy(1.0, bWordVectors.value.transform(word), cum)
+          } else {
+            // skip words that do not belong to the model
+          }
+        }
+        scal(1.0 / sentence.size, cum)
+        cum
       }
     }
     dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol))))
0 commit comments