/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

/**
 * Params for [[Word2Vec]] and [[Word2VecModel]].
 */
private[feature] trait Word2VecParams extends Params with HasInputCol {
  val vectorSize = new IntParam(this, "vectorSize",
    "the dimension of the code that each word is mapped to", Some(100))
  def getVectorSize: Int = get(vectorSize)

  val learningRate = new DoubleParam(this, "learningRate", "initial learning rate", Some(0.025))
  def getLearningRate: Double = get(learningRate)

  val numPartitions = new IntParam(this, "numPartitions",
    "number of partitions used for fitting", Some(1))
  def getNumPartitions: Int = get(numPartitions)

  val numIterations = new IntParam(this, "numIterations", "number of iterations", Some(1))
  def getNumIterations: Int = get(numIterations)

  val seed = new LongParam(this, "seed", "random seed", Some(Utils.random.nextLong()))
  def getSeed: Long = get(seed)

  val minCount = new IntParam(this, "minCount",
    "minimum number of times a token must appear to be included in the vocabulary", Some(5))
  def getMinCount: Int = get(minCount)

  /** An empty string (the default) means the synonyms column is not generated. */
  val synonymsCol = new Param[String](this, "synonymsCol", "synonyms column name", Some(""))
  def getSynonymsCol: String = get(synonymsCol)

  /** An empty string (the default) means the code column is not generated. */
  val codeCol = new Param[String](this, "codeCol", "code column name", Some(""))
  def getCodeCol: String = get(codeCol)

  val numSynonyms = new IntParam(this, "numSynonyms", "number of synonyms to find", Some(0))
  def getNumSynonyms: Int = get(numSynonyms)
}

/**
 * :: AlphaComponent ::
 * Word2Vec trains a model that maps each word to a fixed-size vector (its code), which can be
 * used to transform words for further natural language processing or machine learning.
 */
@AlphaComponent
class Word2Vec extends Estimator[Word2VecModel] with Word2VecParams {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setVectorSize(value: Int): this.type = set(vectorSize, value)

  /** @group setParam */
  def setLearningRate(value: Double): this.type = set(learningRate, value)

  /** @group setParam */
  def setNumPartitions(value: Int): this.type = set(numPartitions, value)

  /** @group setParam */
  def setNumIterations(value: Int): this.type = set(numIterations, value)

  /** @group setParam */
  def setSeed(value: Long): this.type = set(seed, value)

  /** @group setParam */
  def setMinCount(value: Int): this.type = set(minCount, value)

  override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = {
    transformSchema(dataset.schema, paramMap, logging = true)
    val map = this.paramMap ++ paramMap
    // Each row of the input column is a sentence, i.e. a sequence of words.
    val input = dataset.select(map(inputCol))
      .map { case Row(sentence: Seq[_]) => sentence.asInstanceOf[Seq[String]] }
    val wordVectors = new feature.Word2Vec()
      .setLearningRate(map(learningRate))
      .setMinCount(map(minCount))
      .setNumIterations(map(numIterations))
      .setNumPartitions(map(numPartitions))
      .setSeed(map(seed))
      .setVectorSize(map(vectorSize))
      .fit(input)
    val model = new Word2VecModel(this, map, wordVectors)
    Params.inheritValues(map, this, model)
    model
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val map = this.paramMap ++ paramMap
    val inputType = schema(map(inputCol)).dataType
    require(inputType.isInstanceOf[ArrayType] &&
      inputType.asInstanceOf[ArrayType].elementType == StringType,
      s"Input column ${map(inputCol)} must be an ArrayType(StringType) column but got $inputType.")
    schema
  }
}

/**
 * :: AlphaComponent ::
 * Model fitted by [[Word2Vec]].
 */
@AlphaComponent
class Word2VecModel private[ml] (
    override val parent: Word2Vec,
    override val fittingParamMap: ParamMap,
    wordVectors: feature.Word2VecModel)
  extends Model[Word2VecModel] with Word2VecParams {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setSynonymsCol(value: String): this.type = set(synonymsCol, value)

  /** @group setParam */
  def setNumSynonyms(value: Int): this.type = set(numSynonyms, value)

  /** @group setParam */
  def setCodeCol(value: String): this.type = set(codeCol, value)

  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    transformSchema(dataset.schema, paramMap, logging = true)
    val map = this.paramMap ++ paramMap

    var tmpData = dataset
    var numColsOutput = 0

    if (map(codeCol) != "") {
      // Map each input word to its learned vector.
      val word2vec: String => Vector = (word) => wordVectors.transform(word)
      tmpData = tmpData.withColumn(map(codeCol),
        callUDF(word2vec, new VectorUDT, col(map(inputCol))))
      numColsOutput += 1
    }

    if (map(synonymsCol) != "" && map(numSynonyms) > 0) {
      // Find the numSynonyms words closest to each input word, as (word, similarity) pairs.
      val findSynonyms = udf { (word: String) =>
        wordVectors.findSynonyms(word, map(numSynonyms)): Array[(String, Double)]
      }
      tmpData = tmpData.withColumn(map(synonymsCol), findSynonyms(col(map(inputCol))))
      numColsOutput += 1
    }

    if (numColsOutput == 0) {
      this.logWarning(s"$uid: Word2VecModel.transform() was called as NOOP" +
        " since no output columns were set.")
    }

    tmpData
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val map = this.paramMap ++ paramMap

    val inputType = schema(map(inputCol)).dataType
    require(inputType == StringType,
      s"Input column ${map(inputCol)} must be a string column but got $inputType.")

    var outputFields = schema.fields

    if (map(codeCol) != "") {
      require(!schema.fieldNames.contains(map(codeCol)),
        s"Output column ${map(codeCol)} already exists.")
      outputFields = outputFields :+ StructField(map(codeCol), new VectorUDT, nullable = false)
    }

    if (map(synonymsCol) != "") {
      require(!schema.fieldNames.contains(map(synonymsCol)),
        s"Output column ${map(synonymsCol)} already exists.")
      require(map(numSynonyms) > 0,
        s"Number of synonyms must be larger than 0 but got ${map(numSynonyms)}.")
      // Matches the (word, similarity) pairs produced by the findSynonyms udf in transform().
      val synonymsType = ArrayType(StructType(Seq(
        StructField("_1", StringType, nullable = true),
        StructField("_2", DoubleType, nullable = false))))
      outputFields = outputFields :+ StructField(map(synonymsCol), synonymsType, nullable = false)
    }

    StructType(outputFields)
  }
}
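
A minimal usage sketch of the API above (not part of the patch): it assumes a SQLContext named `sqlContext`, and the column names "text", "word", "code", "synonyms" and the toy sentences are illustrative only. The Estimator fits on a sentence (array-of-strings) column, while the fitted Model transforms a single-word string column, as required by the two transformSchema checks.

// Hypothetical usage; `sqlContext` and all column names are assumptions.
val sentences = sqlContext.createDataFrame(Seq(
  "Hi I heard about Spark".split(" ").toSeq,
  "I wish Java could use case classes".split(" ").toSeq
).map(Tuple1.apply)).toDF("text")

// Fit word vectors on the sentence column.
val model = new Word2Vec()
  .setInputCol("text")
  .setVectorSize(3)
  .setMinCount(0)
  .fit(sentences)

// Transform single words: append their vector codes and the two nearest synonyms.
val words = sqlContext.createDataFrame(Seq("Spark", "Java").map(Tuple1.apply)).toDF("word")
model.setInputCol("word")
  .setCodeCol("code")
  .setSynonymsCol("synonyms")
  .setNumSynonyms(2)
  .transform(words)
  .show()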