5 changes: 4 additions & 1 deletion mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

@@ -671,6 +671,7 @@ private[python] class PythonMLLibAPI extends Serializable {
    * @param numPartitions number of partitions
    * @param numIterations number of iterations
    * @param seed initial seed for random generator
+   * @param windowSize size of window
    * @return A handle to java Word2VecModelWrapper instance at python side
    */
  def trainWord2VecModel(
@@ -680,14 +681,16 @@ private[python] class PythonMLLibAPI extends Serializable {
      numPartitions: Int,
      numIterations: Int,
      seed: Long,
-     minCount: Int): Word2VecModelWrapper = {
+     minCount: Int,
+     windowSize: Int): Word2VecModelWrapper = {
    val word2vec = new Word2Vec()
      .setVectorSize(vectorSize)
      .setLearningRate(learningRate)
      .setNumPartitions(numPartitions)
      .setNumIterations(numIterations)
      .setSeed(seed)
      .setMinCount(minCount)
+     .setWindowSize(windowSize)
    try {
      val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
      new Word2VecModelWrapper(model)
28 changes: 23 additions & 5 deletions python/pyspark/ml/feature.py

@@ -2173,28 +2173,31 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
     minCount = Param(Params._dummy(), "minCount",
                      "the minimum number of times a token must appear to be included in the " +
                      "word2vec model's vocabulary", typeConverter=TypeConverters.toInt)
+    windowSize = Param(Params._dummy(), "windowSize",
+                       "the window size (context words from [-window, window]). Default value is 5",
+                       typeConverter=TypeConverters.toInt)

     @keyword_only
     def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
-                 seed=None, inputCol=None, outputCol=None):
+                 seed=None, inputCol=None, outputCol=None, windowSize=5):
         """
         __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \
-                 seed=None, inputCol=None, outputCol=None)
+                 seed=None, inputCol=None, outputCol=None, windowSize=5)
         """
         super(Word2Vec, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid)
         self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
-                         seed=None)
+                         seed=None, windowSize=5)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)

     @keyword_only
     @since("1.4.0")
     def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
-                  seed=None, inputCol=None, outputCol=None):
+                  seed=None, inputCol=None, outputCol=None, windowSize=5):
         """
         setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, \
-                  inputCol=None, outputCol=None)
+                  inputCol=None, outputCol=None, windowSize=5)
         Sets params for this Word2Vec.
         """
         kwargs = self.setParams._input_kwargs
@@ -2245,6 +2248,21 @@ def getMinCount(self):
         """
         return self.getOrDefault(self.minCount)

+    @since("2.0.0")
+    def setWindowSize(self, value):
+        """
+        Sets the value of :py:attr:`windowSize`.
+        """
+        self._set(windowSize=value)
+        return self
+
+    @since("2.0.0")
+    def getWindowSize(self):
+        """
+        Gets the value of windowSize or its default value.
+        """
+        return self.getOrDefault(self.windowSize)
+
     def _create_model(self, java_model):
         return Word2VecModel(java_model)
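For context, a minimal usage sketch of the new parameter through the DataFrame-based API (Python). This assumes an active SparkSession bound to `spark`; the toy sentences and the column names `text`/`vectors` are illustrative only and not part of this change:

from pyspark.ml.feature import Word2Vec

sentences = [("a b c".split(" "),), ("a c d e".split(" "),)]
doc = spark.createDataFrame(sentences, ["text"])

# windowSize can be passed as a keyword argument here, or set afterwards via setWindowSize()
word2vec = Word2Vec(vectorSize=3, minCount=0, windowSize=6,
                    inputCol="text", outputCol="vectors")
model = word2vec.fit(doc)
model.transform(doc).show(truncate=False)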
5 changes: 5 additions & 0 deletions python/pyspark/ml/tests.py

@@ -339,6 +339,11 @@ def test_param_property_error(self):
         params = param_store.params  # should not invoke the property 'test_property'
         self.assertEqual(len(params), 1)

+    def test_word2vec_param(self):
+        model = Word2Vec().setWindowSize(6)
+        # Check windowSize is set properly
+        self.assertEqual(model.getWindowSize(), 6)
+

 class FeatureTests(PySparkTestCase):
11 changes: 10 additions & 1 deletion python/pyspark/mllib/feature.py

@@ -606,6 +606,7 @@ def __init__(self):
         self.numIterations = 1
         self.seed = random.randint(0, sys.maxsize)
         self.minCount = 5
+        self.windowSize = 5

     @since('1.2.0')
     def setVectorSize(self, vectorSize):
@@ -658,6 +659,14 @@ def setMinCount(self, minCount):
         self.minCount = minCount
         return self

+    @since('2.0.0')
+    def setWindowSize(self, windowSize):
+        """
+        Sets window size (default: 5).
+        """
+        self.windowSize = windowSize
+        return self
+
     @since('1.2.0')
     def fit(self, data):
         """
@@ -671,7 +680,7 @@ def fit(self, data):
         jmodel = callMLlibFunc("trainWord2VecModel", data, int(self.vectorSize),
                                float(self.learningRate), int(self.numPartitions),
                                int(self.numIterations), int(self.seed),
-                               int(self.minCount))
+                               int(self.minCount), int(self.windowSize))
         return Word2VecModel(jmodel)
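A corresponding sketch for the RDD-based API. This assumes an active SparkContext bound to `sc`; the tiny corpus is illustrative, and setMinCount(1) is used only so the toy vocabulary is not filtered out:

from pyspark.mllib.feature import Word2Vec

corpus = sc.parallelize(["a b c d e", "a b b c e d"]).map(lambda line: line.split(" "))

# The new setter chains like the existing ones
word2vec = Word2Vec().setVectorSize(10).setSeed(42).setMinCount(1).setWindowSize(6)
model = word2vec.fit(corpus)
synonyms = model.findSynonyms("b", 2)  # pairs of (word, cosine similarity)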
4 changes: 3 additions & 1 deletion python/pyspark/mllib/tests.py

@@ -1026,13 +1026,15 @@ def test_word2vec_setters(self):
             .setNumPartitions(2) \
             .setNumIterations(10) \
             .setSeed(1024) \
-            .setMinCount(3)
+            .setMinCount(3) \
+            .setWindowSize(6)
         self.assertEqual(model.vectorSize, 2)
         self.assertTrue(model.learningRate < 0.02)
         self.assertEqual(model.numPartitions, 2)
         self.assertEqual(model.numIterations, 10)
         self.assertEqual(model.seed, 1024)
         self.assertEqual(model.minCount, 3)
+        self.assertEqual(model.windowSize, 6)

     def test_word2vec_get_vectors(self):
         data = [