1818package org .apache .spark .ml .feature
1919
2020import org .apache .spark .annotation .AlphaComponent
21+ import org .apache .spark .ml .param ._
22+ import org .apache .spark .ml .param .shared ._
23+ import org .apache .spark .ml .util .SchemaUtils
2124import org .apache .spark .ml .{Estimator , Model }
22- import org .apache .spark .ml .param .{HasInputCol , ParamMap , Params , _ }
2325import org .apache .spark .mllib .feature
24- import org .apache .spark .mllib .linalg .{Vector , VectorUDT }
25- import org .apache .spark .sql .{DataFrame , Row }
26+ import org .apache .spark .mllib .linalg .{VectorUDT , Vectors }
2627import org .apache .spark .sql .functions ._
2728import org .apache .spark .sql .types ._
29+ import org .apache .spark .sql .{DataFrame , Row }
2830import org .apache .spark .util .Utils
2931
3032/**
3133 * Params for [[Word2Vec ]] and [[Word2VecModel ]].
3234 */
private[feature] trait Word2VecBase extends Params
  with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize {

  /**
   * The dimension of the code that you want to transform from words.
   * Default: 100
   * @group param
   */
  final val vectorSize = new IntParam(
    this, "vectorSize", "the dimension of codes after transforming from words")

  setDefault(vectorSize -> 100)

  /** @group getParam */
  def getVectorSize: Int = getOrDefault(vectorSize)

  /**
   * Number of partitions for sentences of words.
   * Default: 1
   * @group param
   */
  final val numPartitions = new IntParam(
    this, "numPartitions", "number of partitions for sentences of words")

  setDefault(numPartitions -> 1)

  /** @group getParam */
  def getNumPartitions: Int = getOrDefault(numPartitions)

  /**
   * A random seed to random an initial vector.
   * Default: a value drawn from `Utils.random` when this instance is initialized.
   * @group param
   */
  final val seed = new LongParam(this, "seed", "a random seed to random an initial vector")

  setDefault(seed -> Utils.random.nextLong())

  /** @group getParam */
  def getSeed: Long = getOrDefault(seed)

  /**
   * The minimum count of words that can be kept in training set.
   * Default: 5
   * @group param
   */
  final val minCount = new IntParam(this, "minCount", "the minimum count of words to filter words")

  setDefault(minCount -> 5)

  /** @group getParam */
  def getMinCount: Int = getOrDefault(minCount)

  /**
   * Validate and transform the input schema.
   *
   * @param schema   schema of the incoming dataset
   * @param paramMap extra params; resolved together with this instance's params through
   *                 `extractParamMap`
   * @return the schema produced by `SchemaUtils.appendColumn` for the output column, typed as
   *         [[VectorUDT]]
   */
  protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val map = extractParamMap(paramMap)
    // Compare against the canonical StringType singleton (the instance column schemas carry)
    // rather than allocating a fresh type object.
    // NOTE(review): containsNull = false is kept from the original check; confirm it matches
    // the array columns callers actually pass.
    SchemaUtils.checkColumnType(
      schema, map(inputCol), new ArrayType(StringType, containsNull = false))
    SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
  }
}
9689
9790/**
@@ -100,16 +93,19 @@ private[feature] trait Word2VecParams extends Params
10093 * natural language processing or machine learning process.
10194 */
10295@ AlphaComponent
103- class Word2Vec extends Estimator [Word2VecModel ] with Word2VecParams {
96+ final class Word2Vec extends Estimator [Word2VecModel ] with Word2VecBase {
10497
10598 /** @group setParam */
def setInputCol(value: String): this.type = {
  // Store the input column name and hand back this instance for chaining.
  set(inputCol, value)
}
107100
101+ /** @group setParam */
def setOutputCol(value: String): this.type = {
  // Store the output column name and hand back this instance for chaining.
  set(outputCol, value)
}
103+
108104 /** @group setParam */
// Explicit `this.type` result, matching setInputCol/setOutputCol, so the public
// signature is not left to inference (which widens the singleton type).
def setVectorSize(value: Int): this.type = set(vectorSize, value)
110106
111107 /** @group setParam */
// Explicit `this.type` result, matching setInputCol/setOutputCol, so the public
// signature is not left to inference (which widens the singleton type).
def setStepSize(value: Double): this.type = set(stepSize, value)
113109
114110 /** @group setParam */
// Explicit `this.type` result, matching setInputCol/setOutputCol, so the public
// signature is not left to inference (which widens the singleton type).
def setNumPartitions(value: Int): this.type = set(numPartitions, value)
@@ -125,10 +121,10 @@ class Word2Vec extends Estimator[Word2VecModel] with Word2VecParams {
125121
126122 override def fit (dataset : DataFrame , paramMap : ParamMap ): Word2VecModel = {
127123 transformSchema(dataset.schema, paramMap, logging = true )
128- val map = this . paramMap ++ paramMap
124+ val map = extractParamMap( paramMap)
129125 val input = dataset.select(map(inputCol)).map { case Row (v : Seq [String ]) => v }
130126 val wordVectors = new feature.Word2Vec ()
131- .setLearningRate(map(learningRate ))
127+ .setLearningRate(map(stepSize ))
132128 .setMinCount(map(minCount))
133129 .setNumIterations(map(maxIter))
134130 .setNumPartitions(map(numPartitions))
@@ -141,11 +137,7 @@ class Word2Vec extends Estimator[Word2VecModel] with Word2VecParams {
141137 }
142138
/** Checks the input schema and derives the output schema via the shared validation helper. */
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType =
  validateAndTransformSchema(schema, paramMap)
150142}
151143
@@ -158,81 +150,30 @@ class Word2VecModel private[ml] (
158150 override val parent : Word2Vec ,
159151 override val fittingParamMap : ParamMap ,
160152 wordVectors : feature.Word2VecModel )
161- extends Model [Word2VecModel ] with Word2VecParams {
153+ extends Model [Word2VecModel ] with Word2VecBase {
162154
163155 /** @group setParam */
def setInputCol(value: String): this.type = {
  // Store the input column name and hand back this model for chaining.
  set(inputCol, value)
}
165157
166- /** @group setParam */
167- def setSynonymsCol (value : String ): this .type = set(synonymsCol, value)
168-
169- /** @group setParam */
170- def setNumSynonyms (value : Int ): this .type = set(numSynonyms, value)
171-
172- /** @group setParam */
173- def setCodeCol (value : String ): this .type = set(codeCol, value)
158+ /** @group setParam */
def setOutputCol(value: String): this.type = {
  // Store the output column name and hand back this model for chaining.
  set(outputCol, value)
}
174160
175161 /**
176- * The transforming process of `Word2Vec` model has two approaches - 1. Transform a word of
177- * `String` into a code of `Vector`; 2. Find n (given by you) synonyms of a given word.
178- *
179- * Note. Currently we only support finding synonyms for word of `String`, not `Vector`.
162+ * Transform a sentence column to a vector column to represent the whole sentence.
180163 */
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
  transformSchema(dataset.schema, paramMap, logging = true)
  val map = extractParamMap(paramMap)
  // Resolve every param here, outside the UDF, so the closure captures only plain values
  // and the broadcast handle instead of the ParamMap and this model (which holds
  // `wordVectors` and would otherwise be serialized with the closure).
  val dim = map(vectorSize)
  val inputColName = map(inputCol)
  val outputColName = map(outputCol)
  val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
  // Sum the word vectors of one sentence into a single dense vector of length `dim`.
  // NOTE(review): `transform` on the underlying model presumably fails for words that were
  // not seen during fitting -- confirm and decide whether such words should be skipped.
  val word2Vec = udf { sentence: Seq[String] =>
    val sum = new Array[Double](dim)
    sentence.foreach { word =>
      val wordVec = bWordVectors.value.transform(word).toArray
      // Accumulate in place rather than allocating a zipped array and a new vector per word.
      // If a word vector's length differs from `dim`, only the overlapping prefix is added.
      val n = math.min(dim, wordVec.length)
      var i = 0
      while (i < n) {
        sum(i) += wordVec(i)
        i += 1
      }
    }
    Vectors.dense(sum)
  }
  dataset.withColumn(outputColName, word2Vec(col(inputColName)))
}
211175
/** Checks the input schema and derives the output schema via the shared validation helper. */
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType =
  validateAndTransformSchema(schema, paramMap)
238179}
0 commit comments