
Commit a190f2c

fix code style and refine the transform function of word2vec
1 parent 02848fa commit a190f2c

3 files changed: +68 −109 lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala

Lines changed: 49 additions & 108 deletions
@@ -18,80 +18,73 @@
 package org.apache.spark.ml.feature

 import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.{HasInputCol, ParamMap, Params, _}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.Utils

 /**
  * Params for [[Word2Vec]] and [[Word2VecModel]].
  */
-private[feature] trait Word2VecParams extends Params
-  with HasInputCol with HasMaxIter with HasLearningRate {
+private[feature] trait Word2VecBase extends Params
+  with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize {

   /**
    * The dimension of the code that you want to transform from words.
    */
-  val vectorSize = new IntParam(
-    this, "vectorSize", "the dimension of codes after transforming from words", Some(100))
+  final val vectorSize = new IntParam(
+    this, "vectorSize", "the dimension of codes after transforming from words")
+
+  setDefault(vectorSize -> 100)

   /** @group getParam */
-  def getVectorSize: Int = get(vectorSize)
+  def getVectorSize: Int = getOrDefault(vectorSize)

   /**
    * Number of partitions for sentences of words.
    */
-  val numPartitions = new IntParam(
-    this, "numPartitions", "number of partitions for sentences of words", Some(1))
+  final val numPartitions = new IntParam(
+    this, "numPartitions", "number of partitions for sentences of words")
+
+  setDefault(numPartitions -> 1)

   /** @group getParam */
-  def getNumPartitions: Int = get(numPartitions)
+  def getNumPartitions: Int = getOrDefault(numPartitions)

   /**
    * A random seed to random an initial vector.
    */
-  val seed = new LongParam(
-    this, "seed", "a random seed to random an initial vector", Some(Utils.random.nextLong()))
+  final val seed = new LongParam(this, "seed", "a random seed to random an initial vector")
+
+  setDefault(seed -> Utils.random.nextLong())

   /** @group getParam */
-  def getSeed: Long = get(seed)
+  def getSeed: Long = getOrDefault(seed)

   /**
    * The minimum count of words that can be kept in training set.
    */
-  val minCount = new IntParam(
-    this, "minCount", "the minimum count of words to filter words", Some(5))
+  final val minCount = new IntParam(this, "minCount", "the minimum count of words to filter words")

-  /** @group getParam */
-  def getMinCount: Int = get(minCount)
-
-  /**
-   * The column name of the output column - synonyms.
-   */
-  val synonymsCol = new Param[String](this, "synonymsCol", "Synonyms column name")
+  setDefault(minCount -> 5)

   /** @group getParam */
-  def getSynonymsCol: String = get(synonymsCol)
+  def getMinCount: Int = getOrDefault(minCount)

   /**
-   * The column name of the output column - code.
+   * Validate and transform the input schema.
    */
-  val codeCol = new Param[String](this, "codeCol", "Code column name")
-
-  /** @group getParam */
-  def getCodeCol: String = get(codeCol)
-
-  /**
-   * The number of synonyms that you want to have.
-   */
-  val numSynonyms = new IntParam(this, "numSynonyms", "number of synonyms to find", Some(0))
-
-  /** @group getParam */
-  def getNumSynonyms: Int = get(numSynonyms)
+  protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    val map = extractParamMap(paramMap)
+    SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(new StringType, false))
+    SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
+  }
 }

 /**
@@ -100,16 +93,19 @@ private[feature] trait Word2VecParams extends Params
  * natural language processing or machine learning process.
  */
 @AlphaComponent
-class Word2Vec extends Estimator[Word2VecModel] with Word2VecParams {
+final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {

   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)

+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
   /** @group setParam */
   def setVectorSize(value: Int) = set(vectorSize, value)

   /** @group setParam */
-  def setLearningRate(value: Double) = set(learningRate, value)
+  def setStepSize(value: Double) = set(stepSize, value)

   /** @group setParam */
   def setNumPartitions(value: Int) = set(numPartitions, value)
@@ -125,10 +121,10 @@ class Word2Vec extends Estimator[Word2VecModel] with Word2VecParams {

   override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)
     val input = dataset.select(map(inputCol)).map { case Row(v: Seq[String]) => v }
     val wordVectors = new feature.Word2Vec()
-      .setLearningRate(map(learningRate))
+      .setLearningRate(map(stepSize))
       .setMinCount(map(minCount))
       .setNumIterations(map(maxIter))
       .setNumPartitions(map(numPartitions))
@@ -141,11 +137,7 @@ class Word2Vec extends Estimator[Word2VecModel] with Word2VecParams {
   }

   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
-    val inputType = schema(map(inputCol)).dataType
-    require(inputType.isInstanceOf[ArrayType],
-      s"Input column ${map(inputCol)} must be a Iterable[String] column")
-    schema
+    validateAndTransformSchema(schema, paramMap)
   }
 }

@@ -158,81 +150,30 @@ class Word2VecModel private[ml] (
     override val parent: Word2Vec,
     override val fittingParamMap: ParamMap,
     wordVectors: feature.Word2VecModel)
-  extends Model[Word2VecModel] with Word2VecParams {
+  extends Model[Word2VecModel] with Word2VecBase {

   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)

-  /** @group setParam */
-  def setSynonymsCol(value: String): this.type = set(synonymsCol, value)
-
-  /** @group setParam */
-  def setNumSynonyms(value: Int): this.type = set(numSynonyms, value)
-
-  /** @group setParam */
-  def setCodeCol(value: String): this.type = set(codeCol, value)
+  /**@group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)

   /**
-   * The transforming process of `Word2Vec` model has two approaches - 1. Transform a word of
-   * `String` into a code of `Vector`; 2. Find n (given by you) synonyms of a given word.
-   *
-   * Note. Currently we only support finding synonyms for word of `String`, not `Vector`.
+   * Transform a sentence column to a vector column to represent the whole sentence.
    */
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
-
-    var tmpData = dataset
-    var numColsOutput = 0
-
-    if (map(codeCol) != "") {
-      val word2vec: String => Vector = (word) => wordVectors.transform(word)
-      tmpData = tmpData.withColumn(map(codeCol),
-        callUDF(word2vec, new VectorUDT, col(map(inputCol))))
-      numColsOutput += 1
-    }
-
-    if (map(synonymsCol) != "" & map(numSynonyms) > 0) {
-      // TODO We will add finding synonyms for code of `Vector`.
-      val findSynonyms = udf { (word: String) =>
-        wordVectors.findSynonyms(word, map(numSynonyms)).toMap : Map[String, Double]
+    val map = extractParamMap(paramMap)
+    val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
+    val word2Vec = udf { v: Seq[String] =>
+      v.map(bWordVectors.value.transform).foldLeft(Vectors.zeros(map(vectorSize))) { (cum, vec) =>
+        Vectors.dense(cum.toArray.zip(vec.toArray).map(x => x._1 + x._2))
       }
-      tmpData = tmpData.withColumn(map(synonymsCol), findSynonyms(col(map(inputCol))))
-      numColsOutput += 1
     }
-
-    if (numColsOutput == 0) {
-      this.logWarning(s"$uid: Word2VecModel.transform() was called as NOOP" +
-        s" since no output columns were set.")
-    }
-
-    tmpData
+    dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol))))
   }

   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
-
-    val inputType = schema(map(inputCol)).dataType
-    require(inputType.isInstanceOf[StringType],
-      s"Input column ${map(inputCol)} must be a string column")
-
-    var outputFields = schema.fields
-
-    if (map(codeCol) != "") {
-      require(!schema.fieldNames.contains(map(codeCol)),
-        s"Output column ${map(codeCol)} already exists.")
-      outputFields = outputFields :+ StructField(map(codeCol), new VectorUDT, nullable = false)
-    }
-
-    if (map(synonymsCol) != "") {
-      require(!schema.fieldNames.contains(map(synonymsCol)),
-        s"Output column ${map(synonymsCol)} already exists.")
-      require(map(numSynonyms) > 0,
-        s"Number of synonyms should larger than 0")
-      outputFields = outputFields :+
-        StructField(map(synonymsCol), MapType(StringType, DoubleType), nullable = false)
-    }
-
-    StructType(outputFields)
+    validateAndTransformSchema(schema, paramMap)
   }
 }
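
For context, a minimal usage sketch of the refined API (not part of this commit). The input DataFrame, the column names "text" and "result", and the sqlContext assumed to be in scope are illustrative only; setMinCount comes from an unchanged part of the file. It shows the new single-output behaviour: the model maps an array-of-strings column to one vector column, computed by the transform above as the element-wise sum of the word vectors in each row.

import org.apache.spark.ml.feature.Word2Vec

// Illustrative input: each row is a sentence already split into words.
val documentDF = sqlContext.createDataFrame(Seq(
  "Hi I heard about Spark".split(" "),
  "I wish Java could use case classes".split(" ")
).map(Tuple1.apply)).toDF("text")

val word2Vec = new Word2Vec()
  .setInputCol("text")      // must be an ArrayType(StringType) column
  .setOutputCol("result")   // single output column, replacing codeCol/synonymsCol
  .setVectorSize(3)
  .setMinCount(0)           // setter defined in an unchanged part of the file

val model = word2Vec.fit(documentDF)
// Each "result" cell is the element-wise sum of the word vectors of that row,
// as computed by the refined Word2VecModel.transform above.
model.transform(documentDF).select("result").show()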

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,8 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
       ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
       ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"),
-      ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"))
+      ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
+      ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."))

     val code = genSharedParams(params)
     val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 17 additions & 0 deletions
@@ -310,4 +310,21 @@ trait HasTol extends Params {
   /** @group getParam */
   final def getTol: Double = getOrDefault(tol)
 }
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param stepSize.
+ */
+@DeveloperApi
+trait HasStepSize extends Params {
+
+  /**
+   * Param for Step size to be used for each iteration of optimization..
+   * @group param
+   */
+  final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization.")
+
+  /** @group getParam */
+  final def getStepSize: Double = getOrDefault(stepSize)
+}
 // scalastyle:on
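
For context, a minimal sketch of how the generated HasStepSize trait can be consumed; the trait name MyOptimizerParams and the 0.025 default are hypothetical and not part of this commit. Word2VecBase above mixes it in the same way.

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.param.shared.HasStepSize

// Hypothetical params trait: reuses the shared stepSize param instead of
// declaring its own learning-rate param, mirroring what Word2VecBase does.
private[ml] trait MyOptimizerParams extends Params with HasStepSize {
  setDefault(stepSize -> 0.025)
}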
