|
| 1 | +# |
| 2 | +# Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +# contributor license agreements. See the NOTICE file distributed with |
| 4 | +# this work for additional information regarding copyright ownership. |
| 5 | +# The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | +# (the "License"); you may not use this file except in compliance with |
| 7 | +# the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
"""
Python package for feature extraction and transformation in MLlib.
"""
| 21 | +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer |
| 22 | + |
| 23 | +from pyspark.mllib.linalg import _convert_to_vector |
| 24 | + |
| 25 | +__all__ = ['Word2Vec', 'Word2VecModel'] |
| 26 | + |
| 27 | + |
class Word2VecModel(object):
    """
    Wraps a fitted Word2Vec model held on the JVM side; all queries are
    delegated to the Java model through the Py4J gateway.
    """
    def __init__(self, sc, java_model):
        """
        :param sc: Spark context
        :param java_model: Handle to Java model object
        """
        self._sc = sc
        self._java_model = java_model

    def __del__(self):
        # Release the Py4J reference so the JVM side can garbage-collect
        # the (potentially large) model object.
        self._sc._gateway.detach(self._java_model)

    def transform(self, word):
        """
        Transforms a word to its vector representation.

        Note: local use only

        :param word: a word
        :return: vector representation of word
        """
        # TODO: make transform usable in RDD operations from python side
        result = self._java_model.transform(word)
        # Round-trip through the JVM-side SerDe pickler to get a Python vector.
        return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result)))

    def findSynonyms(self, x, num):
        """
        Find synonyms of a word.

        Note: local use only

        :param x: a word or a vector representation of word
        :param num: number of synonyms to find
        :return: array of (word, cosineSimilarity)
        """
        # TODO: make findSynonyms usable in RDD operations from python side
        ser = PickleSerializer()
        if isinstance(x, str):  # isinstance, not type(x) == str (idiomatic type check)
            jlist = self._java_model.findSynonyms(x, num)
        else:
            # Pickle the vector and rebuild it on the JVM side before querying.
            # Renamed from 'bytes' to avoid shadowing the builtin.
            vec_bytes = bytearray(ser.dumps(_convert_to_vector(x)))
            vec = self._sc._jvm.SerDe.loads(vec_bytes)
            jlist = self._java_model.findSynonyms(vec, num)
        words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist)))
        return zip(words, similarity)
| 74 | + |
| 75 | + |
| 76 | +class Word2Vec(object): |
| 77 | + """ |
| 78 | + Word2Vec creates vector representation of words in a text corpus. |
| 79 | + The algorithm first constructs a vocabulary from the corpus |
| 80 | + and then learns vector representation of words in the vocabulary. |
| 81 | + The vector representation can be used as features in |
| 82 | + natural language processing and machine learning algorithms. |
| 83 | +
|
| 84 | + We used skip-gram model in our implementation and hierarchical softmax |
| 85 | + method to train the model. The variable names in the implementation |
| 86 | + matches the original C implementation. |
| 87 | + For original C implementation, see https://code.google.com/p/word2vec/ |
| 88 | + For research papers, see |
| 89 | + Efficient Estimation of Word Representations in Vector Space |
| 90 | + and |
| 91 | + Distributed Representations of Words and Phrases and their Compositionality. |
| 92 | +
|
| 93 | + >>> sentence = "a b " * 100 + "a c " * 10 |
| 94 | + >>> localDoc = [sentence, sentence] |
| 95 | + >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) |
| 96 | + >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) |
| 97 | + >>> syms = model.findSynonyms("a", 2) |
| 98 | + >>> str(syms[0][0]) |
| 99 | + 'b' |
| 100 | + >>> str(syms[1][0]) |
| 101 | + 'c' |
| 102 | + >>> len(syms) |
| 103 | + 2 |
| 104 | + >>> vec = model.transform("a") |
| 105 | + >>> len(vec) |
| 106 | + 10 |
| 107 | + >>> syms = model.findSynonyms(vec, 2) |
| 108 | + >>> str(syms[0][0]) |
| 109 | + 'b' |
| 110 | + >>> str(syms[1][0]) |
| 111 | + 'c' |
| 112 | + >>> len(syms) |
| 113 | + 2 |
| 114 | + """ |
| 115 | + def __init__(self): |
| 116 | + """ |
| 117 | + Construct Word2Vec instance |
| 118 | + """ |
| 119 | + self.vectorSize = 100 |
| 120 | + self.learningRate = 0.025 |
| 121 | + self.numPartitions = 1 |
| 122 | + self.numIterations = 1 |
| 123 | + self.seed = 42L |
| 124 | + |
| 125 | + def setVectorSize(self, vectorSize): |
| 126 | + """ |
| 127 | + Sets vector size (default: 100). |
| 128 | + """ |
| 129 | + self.vectorSize = vectorSize |
| 130 | + return self |
| 131 | + |
| 132 | + def setLearningRate(self, learningRate): |
| 133 | + """ |
| 134 | + Sets initial learning rate (default: 0.025). |
| 135 | + """ |
| 136 | + self.learningRate = learningRate |
| 137 | + return self |
| 138 | + |
| 139 | + def setNumPartitions(self, numPartitions): |
| 140 | + """ |
| 141 | + Sets number of partitions (default: 1). Use a small number for accuracy. |
| 142 | + """ |
| 143 | + self.numPartitions = numPartitions |
| 144 | + return self |
| 145 | + |
| 146 | + def setNumIterations(self, numIterations): |
| 147 | + """ |
| 148 | + Sets number of iterations (default: 1), which should be smaller than or equal to number of |
| 149 | + partitions. |
| 150 | + """ |
| 151 | + self.numIterations = numIterations |
| 152 | + return self |
| 153 | + |
| 154 | + def setSeed(self, seed): |
| 155 | + """ |
| 156 | + Sets random seed. |
| 157 | + """ |
| 158 | + self.seed = seed |
| 159 | + return self |
| 160 | + |
| 161 | + def fit(self, data): |
| 162 | + """ |
| 163 | + Computes the vector representation of each word in vocabulary. |
| 164 | +
|
| 165 | + :param data: training data. RDD of subtype of Iterable[String] |
| 166 | + :return: python Word2VecModel instance |
| 167 | + """ |
| 168 | + sc = data.context |
| 169 | + ser = PickleSerializer() |
| 170 | + vectorSize = self.vectorSize |
| 171 | + learningRate = self.learningRate |
| 172 | + numPartitions = self.numPartitions |
| 173 | + numIterations = self.numIterations |
| 174 | + seed = self.seed |
| 175 | + |
| 176 | + model = sc._jvm.PythonMLLibAPI().trainWord2Vec( |
| 177 | + data._to_java_object_rdd(), vectorSize, |
| 178 | + learningRate, numPartitions, numIterations, seed) |
| 179 | + return Word2VecModel(sc, model) |
| 180 | + |
| 181 | + |
def _test():
    """
    Run this module's doctests against a local 4-core SparkContext and
    exit with a non-zero status if any doctest fails.
    """
    import doctest
    import sys
    from pyspark import SparkContext
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        # sys.exit instead of the bare exit() helper: exit() is injected by
        # the site module for interactive use and may be absent under -S.
        sys.exit(-1)
| 191 | + |
# Run the doctest suite when this file is executed as a script.
if __name__ == "__main__":
    _test()
0 commit comments