
Commit c594095

add word2vec transformer
1 parent c0c0ba6 commit c594095

File tree

1 file changed: 196 additions, 0 deletions
@@ -0,0 +1,196 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

/**
 * Params for [[Word2Vec]] and [[Word2VecModel]].
 */
private[feature] trait Word2VecParams extends Params with HasInputCol {

  /** Dimension of the code (word vectors). Default: 100. */
  val vectorSize = new IntParam(this, "vectorSize", "dimension of the word vectors", Some(100))
  def getVectorSize: Int = get(vectorSize)

  /** Initial learning rate. Default: 0.025. */
  val learningRate = new DoubleParam(this, "learningRate", "initial learning rate", Some(0.025))
  def getLearningRate: Double = get(learningRate)

  /** Number of partitions used during training. Default: 1. */
  val numPartitions = new IntParam(this, "numPartitions", "number of partitions", Some(1))
  def getNumPartitions: Int = get(numPartitions)

  /** Number of training iterations. Default: 1. */
  val numIterations = new IntParam(this, "numIterations", "number of iterations", Some(1))
  def getNumIterations: Int = get(numIterations)

  /** Random seed. */
  val seed = new LongParam(this, "seed", "random seed", Some(Utils.random.nextLong()))
  def getSeed: Long = get(seed)

  /** Minimum number of times a token must appear to be included in the vocabulary. Default: 5. */
  val minCount = new IntParam(this, "minCount", "minimum token frequency", Some(5))
  def getMinCount: Int = get(minCount)

  /** Output column for synonyms. The empty string (default) disables this output. */
  val synonymsCol = new Param[String](this, "synonymsCol", "Synonyms column name", Some(""))
  def getSynonymsCol: String = get(synonymsCol)

  /** Output column for word vectors ("codes"). The empty string (default) disables this output. */
  val codeCol = new Param[String](this, "codeCol", "Code column name", Some(""))
  def getCodeCol: String = get(codeCol)

  /** Number of synonyms to find for each word. Default: 0. */
  val numSynonyms = new IntParam(this, "numSynonyms", "number of synonyms to find", Some(0))
  def getNumSynonyms: Int = get(numSynonyms)
}

/**
 * :: AlphaComponent ::
 * Word2Vec trains a model that maps each word to a vector ("code"), using the corpus of word
 * sequences in the input column. The fitted [[Word2VecModel]] can transform words into codes
 * and look up synonyms.
 */
@AlphaComponent
class Word2Vec extends Estimator[Word2VecModel] with Word2VecParams {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)
  def setVectorSize(value: Int): this.type = set(vectorSize, value)
  def setLearningRate(value: Double): this.type = set(learningRate, value)
  def setNumPartitions(value: Int): this.type = set(numPartitions, value)
  def setNumIterations(value: Int): this.type = set(numIterations, value)
  def setSeed(value: Long): this.type = set(seed, value)
  def setMinCount(value: Int): this.type = set(minCount, value)

  override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = {
    transformSchema(dataset.schema, paramMap, logging = true)
    val map = this.paramMap ++ paramMap
    // Each row of the input column is a sequence of words (e.g. a tokenized sentence).
    val input = dataset.select(map(inputCol))
      .map { case Row(sentence: Seq[_]) => sentence.asInstanceOf[Seq[String]] }
    val wordVectors = new feature.Word2Vec()
      .setLearningRate(map(learningRate))
      .setMinCount(map(minCount))
      .setNumIterations(map(numIterations))
      .setNumPartitions(map(numPartitions))
      .setSeed(map(seed))
      .setVectorSize(map(vectorSize))
      .fit(input)
    val model = new Word2VecModel(this, map, wordVectors)
    Params.inheritValues(map, this, model)
    model
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val map = this.paramMap ++ paramMap
    val inputType = schema(map(inputCol)).dataType
    require(inputType.isInstanceOf[ArrayType] &&
      inputType.asInstanceOf[ArrayType].elementType == StringType,
      s"Input column ${map(inputCol)} must be an array of strings but got $inputType.")
    schema
  }
}

/**
 * :: AlphaComponent ::
 * Model fitted by [[Word2Vec]]. It maps single words to their vector representation ("code")
 * and can optionally find the closest synonyms of each word.
 */
@AlphaComponent
class Word2VecModel private[ml] (
    override val parent: Word2Vec,
    override val fittingParamMap: ParamMap,
    wordVectors: feature.Word2VecModel)
  extends Model[Word2VecModel] with Word2VecParams {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setSynonymsCol(value: String): this.type = set(synonymsCol, value)

  /** @group setParam */
  def setCodeCol(value: String): this.type = set(codeCol, value)

  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    transformSchema(dataset.schema, paramMap, logging = true)
    val map = this.paramMap ++ paramMap

    var tmpData = dataset
    var numColsOutput = 0

    if (map(codeCol) != "") {
      // Map each word to its learned vector.
      val word2vec: String => Vector = (word) => wordVectors.transform(word)
      tmpData = tmpData.withColumn(map(codeCol),
        callUDF(word2vec, new VectorUDT, col(map(inputCol))))
      numColsOutput += 1
    }

    if (map(synonymsCol) != "" && map(numSynonyms) > 0) {
      // Find the numSynonyms closest words as (word, similarity) pairs.
      val findSynonyms = udf { (word: String) =>
        wordVectors.findSynonyms(word, map(numSynonyms)): Array[(String, Double)]
      }
      tmpData = tmpData.withColumn(map(synonymsCol), findSynonyms(col(map(inputCol))))
      numColsOutput += 1
    }

    if (numColsOutput == 0) {
      this.logWarning(s"$uid: Word2VecModel.transform() was called as NOOP" +
        " since no output columns were set.")
    }

    tmpData
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val map = this.paramMap ++ paramMap

    val inputType = schema(map(inputCol)).dataType
    require(inputType == StringType,
      s"Input column ${map(inputCol)} must be a string column but got $inputType.")

    var outputFields = schema.fields

    if (map(codeCol) != "") {
      require(!schema.fieldNames.contains(map(codeCol)),
        s"Output column ${map(codeCol)} already exists.")
      outputFields = outputFields :+ StructField(map(codeCol), new VectorUDT, nullable = false)
    }

    if (map(synonymsCol) != "") {
      require(!schema.fieldNames.contains(map(synonymsCol)),
        s"Output column ${map(synonymsCol)} already exists.")
      require(map(numSynonyms) > 0,
        "Number of synonyms should be larger than 0.")
      // The synonyms UDF returns Array[(String, Double)], which Spark SQL encodes as an array
      // of structs with fields _1 (word) and _2 (similarity).
      val synonymsType = ArrayType(StructType(Seq(
        StructField("_1", StringType, nullable = true),
        StructField("_2", DoubleType, nullable = false))))
      outputFields = outputFields :+ StructField(map(synonymsCol), synonymsType, nullable = false)
    }

    StructType(outputFields)
  }
}
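
Not part of the commit: a minimal usage sketch of the new estimator and model, assuming a Spark 1.3-era SQLContext named sqlContext; the column names "text", "code", and "synonyms" are illustrative, not defaults.

// Usage sketch (illustrative only). Assumes an existing SQLContext `sqlContext`.
import org.apache.spark.ml.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.ml.param.ParamMap

// Fit on a column of tokenized sentences (Seq[String] per row).
val sentences = sqlContext.createDataFrame(Seq(
  Tuple1(Seq("apache", "spark", "is", "fast")),
  Tuple1(Seq("spark", "ml", "adds", "pipelines")))).toDF("text")

val word2vec = new Word2Vec()
  .setInputCol("text")
  .setVectorSize(10)
  .setMinCount(1)
val model: Word2VecModel = word2vec.fit(sentences, ParamMap.empty)

// The fitted model transforms a column of single words; numSynonyms has no setter here,
// so it is supplied through the ParamMap override.
val words = sqlContext.createDataFrame(Seq(Tuple1("spark"))).toDF("text")
val coded = model
  .setInputCol("text")
  .setCodeCol("code")
  .setSynonymsCol("synonyms")
  .transform(words, ParamMap(model.numSynonyms -> 1))
coded.show()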
