
Commit cdef9f4

add missing comments
1 parent b7447eb commit cdef9f4

3 files changed, +58 −22 lines changed


mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 13 additions & 9 deletions
@@ -289,24 +289,28 @@ class PythonMLLibAPI extends Serializable {
    * handle to the Java object instead of the content of the Java object.
    * Extra care needs to be taken in the Python code to ensure it gets freed on
    * exit; see the Py4J documentation.
-   * @param dataJRDD Input JavaRDD
+   * @param dataJRDD input JavaRDD
+   * @param vectorSize size of vector
+   * @param learningRate initial learning rate
+   * @param numPartitions number of partitions
+   * @param numIterations number of iterations
+   * @param seed initial seed for random generator
    * @return A handle to java Word2VecModelWrapper instance at python side
    */
   def trainWord2Vec(
       dataJRDD: JavaRDD[java.util.ArrayList[String]],
       vectorSize: Int,
-      startingAlpha: Double,
+      learningRate: Double,
       numPartitions: Int,
       numIterations: Int,
-      seed: Long
-    ): Word2VecModelWrapper = {
+      seed: Long): Word2VecModelWrapper = {
     val data = dataJRDD.rdd.cache()
     val word2vec = new Word2Vec()
-      .setVectorSize(vectorSize)
-      .setLearningRate(startingAlpha)
-      .setNumPartitions(numPartitions)
-      .setNumIterations(numIterations)
-      .setSeed(seed)
+      .setVectorSize(vectorSize)
+      .setLearningRate(learningRate)
+      .setNumPartitions(numPartitions)
+      .setNumIterations(numIterations)
+      .setSeed(seed)
     val model = word2vec.fit(data)
     new Word2VecModelWrapper(model)
   }

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 5 additions & 5 deletions
@@ -67,7 +67,7 @@ private case class VocabWord(
 class Word2Vec extends Serializable with Logging {
 
   private var vectorSize = 100
-  private var startingAlpha = 0.025
+  private var learningRate = 0.025
   private var numPartitions = 1
   private var numIterations = 1
   private var seed = Utils.random.nextLong()
@@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging {
    * Sets initial learning rate (default: 0.025).
    */
   def setLearningRate(learningRate: Double): this.type = {
-    this.startingAlpha = learningRate
+    this.learningRate = learningRate
     this
   }
 
@@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging {
     val syn0Global =
       Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
     val syn1Global = new Array[Float](vocabSize * vectorSize)
-    var alpha = startingAlpha
+    var alpha = learningRate
     for (k <- 1 to numIterations) {
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
@@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging {
             lwc = wordCount
             // TODO: discount by iteration?
             alpha =
-              startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
-            if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+              learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
+            if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
             logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
           }
           wc += sentence.size
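
For context, the rename keeps the fluent builder style of the Scala API intact. Below is a minimal, hedged usage sketch of the renamed setter; the SparkContext sc and the corpus path data/corpus.txt are illustrative assumptions, not part of this commit.

import org.apache.spark.mllib.feature.Word2Vec

// Hypothetical corpus: one whitespace-tokenized sentence per line.
val sentences = sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)

val model = new Word2Vec()
  .setVectorSize(100)       // dimensionality of the word vectors
  .setLearningRate(0.025)   // renamed from startingAlpha in this commit
  .setNumPartitions(1)
  .setNumIterations(1)
  .setSeed(42L)
  .fit(sentences)           // returns a Word2VecModel

The same defaults (vectorSize 100, learningRate 0.025, one partition, one iteration) are mirrored by the Python wrapper below.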

python/pyspark/mllib/Word2Vec.py

Lines changed: 40 additions & 8 deletions
@@ -18,10 +18,10 @@
 """
 Python package for Word2Vec in MLlib.
 """
-from numpy import random
-
 from sys import maxint
 
+from numpy import random
+
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 
 from pyspark.mllib.linalg import _convert_to_vector
@@ -46,15 +46,22 @@ def __del__(self):
 
     def transform(self, word):
         """
-        local use only
+        :param word: a word
+        :return: vector representation of word
+
+        Note: local use only
         TODO: make transform usable in RDD operations from python side
         """
         result = self._java_model.transform(word)
         return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result)))
 
     def findSynonyms(self, x, num):
         """
-        local use only
+        :param x: a word or a vector representation of word
+        :param num: number of synonyms to find
+        :return: array of (word, cosineSimilarity)
+
+        Note: local use only
         TODO: make findSynonyms usable in RDD operations from python side
         """
         jlist = self._java_model.findSynonyms(x, num)
@@ -95,45 +102,70 @@ class Word2Vec(object):
     10
     """
     def __init__(self):
+        """
+        Construct Word2Vec instance
+        """
         self.vectorSize = 100
-        self.startingAlpha = 0.025
+        self.learningRate = 0.025
         self.numPartitions = 1
         self.numIterations = 1
         self.seed = random.randint(0, high=maxint)
 
     def setVectorSize(self, vectorSize):
+        """
+        Sets vector size (default: 100).
+        """
         self.vectorSize = vectorSize
         return self
 
     def setLearningRate(self, learningRate):
-        self.startingAlpha = learningRate
+        """
+        Sets initial learning rate (default: 0.025).
+        """
+        self.learningRate = learningRate
         return self
 
     def setNumPartitions(self, numPartitions):
+        """
+        Sets number of partitions (default: 1). Use a small number for accuracy.
+        """
         self.numPartitions = numPartitions
         return self
 
     def setNumIterations(self, numIterations):
+        """
+        Sets number of iterations (default: 1), which should be smaller than or equal to number of
+        partitions.
+        """
         self.numIterations = numIterations
         return self
 
     def setSeed(self, seed):
+        """
+        Sets random seed (default: a random long integer).
+        """
         self.seed = seed
         return self
 
     def fit(self, data):
+        """
+        Computes the vector representation of each word in vocabulary.
+
+        :param data: training data.
+        :return: python Word2VecModel instance
+        """
         sc = data.context
         ser = PickleSerializer()
         vectorSize = self.vectorSize
-        startingAlpha = self.startingAlpha
+        learningRate = self.learningRate
         numPartitions = self.numPartitions
         numIterations = self.numIterations
         seed = self.seed
 
         # cached = data._reserialize(AutoBatchedSerializer(ser)).cache()
         model = sc._jvm.PythonMLLibAPI().trainWord2Vec(
             data._to_java_object_rdd(), vectorSize,
-            startingAlpha, numPartitions, numIterations, seed)
+            learningRate, numPartitions, numIterations, seed)
         return Word2VecModel(sc, model)
 
 
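
The Python transform and findSynonyms wrappers above delegate through Py4J to the same methods on the underlying Scala Word2VecModel. A short, hedged sketch of the equivalent Scala calls, reusing the model fitted in the earlier sketch; the query word "spark" is illustrative and is assumed to occur in the fitted vocabulary.

// `model` is the Word2VecModel fitted in the earlier sketch.
val vector = model.transform("spark")            // vector representation of the word
val synonyms = model.findSynonyms("spark", 5)    // Array of (word, cosineSimilarity)
synonyms.foreach { case (word, similarity) =>
  println(word + " " + similarity)
}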
