|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.spark.ml.feature |
| 19 | + |
| 20 | +import org.apache.hadoop.fs.Path |
| 21 | + |
| 22 | +import org.apache.spark.SparkException |
| 23 | +import org.apache.spark.annotation.{Experimental, Since} |
| 24 | +import org.apache.spark.ml.{Estimator, Model} |
| 25 | +import org.apache.spark.ml.param._ |
| 26 | +import org.apache.spark.ml.param.shared.HasInputCols |
| 27 | +import org.apache.spark.ml.util._ |
| 28 | +import org.apache.spark.sql.{DataFrame, Dataset, Row} |
| 29 | +import org.apache.spark.sql.functions._ |
| 30 | +import org.apache.spark.sql.types._ |
| 31 | + |
/**
 * Parameters shared by [[Imputer]] and [[ImputerModel]]: the imputation strategy, the
 * placeholder that marks a value as missing, and the names of the output columns.
 */
private[feature] trait ImputerParams extends Params with HasInputCols {

  /**
   * How the replacement value of each column is obtained.
   * "mean" replaces missing values with the mean of the feature; "median" replaces them with
   * the approximate median of the feature. Any other value is rejected by the validator.
   * Default: mean
   *
   * @group param
   */
  final val strategy: Param[String] = new Param(
    this,
    "strategy",
    "strategy for imputation. " +
      s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " +
      s"If ${Imputer.median}, then replace missing values using the median value of the feature.",
    ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median)))

  /** @group getParam */
  def getStrategy: String = $(strategy)

  /**
   * The value that stands in for "missing". Every occurrence of it is imputed; null values
   * are treated as missing regardless of this setting.
   * Default: Double.NaN
   *
   * @group param
   */
  final val missingValue: DoubleParam = new DoubleParam(
    this,
    "missingValue",
    "The placeholder for the missing values. All occurrences of missingValue will be imputed")

  /** @group getParam */
  def getMissingValue: Double = $(missingValue)

  /**
   * Names of the columns that receive the imputed values, paired by position with inputCols.
   * @group param
   */
  final val outputCols: StringArrayParam =
    new StringArrayParam(this, "outputCols", "output column names")

  /** @group getParam */
  final def getOutputCols: Array[String] = $(outputCols)

  /**
   * Checks the column parameters against `schema` and returns `schema` extended with one field
   * per output column.
   *
   * The checks run in this order: inputCols has no duplicates, outputCols has no duplicates,
   * both arrays have the same length (each failure is an IllegalArgumentException raised by
   * `require`), and every input column passes `SchemaUtils.checkColumnTypes` for DoubleType
   * or FloatType. Each appended field copies the data type and nullability of its input column.
   * No check is made here that an output name is absent from `schema`.
   */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    val inputNames = $(inputCols)
    require(inputNames.distinct.length == inputNames.length,
      s"inputCols contains duplicates: (${inputNames.mkString(", ")})")

    // Looked up only after the inputCols check so the order of failures stays the same.
    val outputNames = $(outputCols)
    require(outputNames.distinct.length == outputNames.length,
      s"outputCols contains duplicates: (${outputNames.mkString(", ")})")
    require(outputNames.length == inputNames.length,
      s"inputCols(${inputNames.length}) and outputCols(${outputNames.length})" +
        " should have the same length")

    val appendedFields = inputNames.zip(outputNames).map { case (sourceName, targetName) =>
      val sourceField = schema(sourceName)
      SchemaUtils.checkColumnTypes(schema, sourceName, Seq(DoubleType, FloatType))
      StructField(targetName, sourceField.dataType, sourceField.nullable)
    }
    StructType(schema ++ appendedFields)
  }
}
| 92 | + |
/**
 * :: Experimental ::
 * Estimator that fills in missing values with either the mean or the median of the column in
 * which they occur. Input columns must be of DoubleType or FloatType. Categorical features are
 * not supported (SPARK-15041); imputing one may produce incorrect values.
 *
 * The mean/median is computed after missing values have been filtered out. Null values in an
 * input column are always treated as missing and are therefore imputed as well. NaN values are
 * likewise left out of the statistic. The median is obtained through
 * DataFrameStatFunctions.approxQuantile with a relative error of 0.001.
 */
@Experimental
class Imputer @Since("2.2.0")(override val uid: String)
  extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable {

  @Since("2.2.0")
  def this() = this(Identifiable.randomUID("imputer"))

  setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN)

  /** @group setParam */
  @Since("2.2.0")
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  @Since("2.2.0")
  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

  /**
   * Imputation strategy. Available options are ["mean", "median"].
   * @group setParam
   */
  @Since("2.2.0")
  def setStrategy(value: String): this.type = set(strategy, value)

  /** @group setParam */
  @Since("2.2.0")
  def setMissingValue(value: Double): this.type = set(missingValue, value)

  /**
   * Computes one surrogate per input column and wraps them in an [[ImputerModel]].
   *
   * Each column is cast to double and stripped of nulls, NaNs and occurrences of missingValue
   * before the statistic selected by `strategy` is taken. The surrogates are stored in a
   * single-row DataFrame whose columns are named after the input columns.
   *
   * @throws SparkException if a column has no value left after filtering
   */
  override def fit(dataset: Dataset[_]): ImputerModel = {
    transformSchema(dataset.schema, logging = true)
    val session = dataset.sparkSession
    import session.implicits._

    val columnNames = $(inputCols)
    val placeholder = $(missingValue)
    val chosenStrategy = $(strategy)

    val surrogates = columnNames.map { name =>
      val column = col(name)
      val present = dataset
        .select(column.cast(DoubleType))
        .filter(column.isNotNull && column =!= placeholder && !column.isNaN)
      // Fail early: neither statistic is defined when nothing survives the filter.
      if (present.take(1).isEmpty) {
        throw new SparkException("surrogate cannot be computed. " +
          s"All the values in $name are Null, Nan or missingValue($placeholder)")
      }
      chosenStrategy match {
        case Imputer.mean => present.select(avg(name)).as[Double].first()
        case Imputer.median => present.stat.approxQuantile(name, Array(0.5), 0.001).head
      }
    }

    val surrogateSchema = StructType(
      columnNames.map(name => StructField(name, DoubleType, nullable = false)))
    val surrogateRows = session.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
    val surrogateDF = session.createDataFrame(surrogateRows, surrogateSchema)
    copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
  }

  override def transformSchema(schema: StructType): StructType =
    validateAndTransformSchema(schema)

  override def copy(extra: ParamMap): Imputer = defaultCopy(extra)
}
| 163 | + |
@Since("2.2.0")
object Imputer extends DefaultParamsReadable[Imputer] {

  /** Name of the supported strategy that imputes with the mean. */
  private[ml] val mean: String = "mean"

  /** Name of the supported strategy that imputes with the median. */
  private[ml] val median: String = "median"

  @Since("2.2.0")
  override def load(path: String): Imputer = super.load(path)
}
| 174 | + |
/**
 * :: Experimental ::
 * Model fitted by [[Imputer]].
 *
 * @param surrogateDF a DataFrame containing the inputCols and their corresponding surrogates,
 *                    which are used to replace the missing values in the input DataFrame.
 */
@Experimental
class ImputerModel private[ml](
    override val uid: String,
    val surrogateDF: DataFrame)
  extends Model[ImputerModel] with ImputerParams with MLWritable {

  import ImputerModel._

  /** @group setParam */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

  /**
   * Adds one output column per input column. A value that is null or equal to missingValue is
   * replaced by the column's surrogate; any other value is kept. The result is cast back to the
   * data type the input column has in `dataset`.
   *
   * Columns are added one after another, so each step sees the columns produced by the
   * previous ones.
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
    val placeholder = $(missingValue)
    val columnPlan = $(inputCols).zip($(outputCols)).zip(surrogates)

    columnPlan.foldLeft(dataset.toDF()) {
      case (frame, ((inputCol, outputCol), surrogate)) =>
        val source = col(inputCol)
        val imputed = when(source.isNull, surrogate)
          .when(source === placeholder, surrogate)
          .otherwise(source)
          .cast(dataset.schema(inputCol).dataType)
        frame.withColumn(outputCol, imputed)
    }
  }

  override def transformSchema(schema: StructType): StructType =
    validateAndTransformSchema(schema)

  override def copy(extra: ParamMap): ImputerModel =
    copyValues(new ImputerModel(uid, surrogateDF), extra).setParent(parent)

  @Since("2.2.0")
  override def write: MLWriter = new ImputerModelWriter(this)
}
| 226 | + |
| 227 | + |
@Since("2.2.0")
object ImputerModel extends MLReadable[ImputerModel] {

  /** Location of the surrogate DataFrame underneath a model directory. */
  private def dataPath(modelPath: String): String = new Path(modelPath, "data").toString

  /** Writes the params as metadata and the surrogates as a single-partition parquet dataset. */
  private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      instance.surrogateDF.repartition(1).write.parquet(dataPath(path))
    }
  }

  /** Rebuilds a model from the metadata and the parquet data stored under a model directory. */
  private class ImputerReader extends MLReader[ImputerModel] {

    override def load(path: String): ImputerModel = {
      val expectedClass = classOf[ImputerModel].getName
      val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClass)
      val surrogates = sqlContext.read.parquet(dataPath(path))
      val restored = new ImputerModel(metadata.uid, surrogates)
      DefaultParamsReader.getAndSetParams(restored, metadata)
      restored
    }
  }

  @Since("2.2.0")
  override def read: MLReader[ImputerModel] = new ImputerReader

  @Since("2.2.0")
  override def load(path: String): ImputerModel = super.load(path)
}
0 commit comments