Skip to content

Commit d647aae

Browse files
hhbyyhNick Pentreath
authored andcommitted
[SPARK-13568][ML] Create feature transformer to impute missing values
## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-13568 It is quite common to encounter missing values in data sets. It would be useful to implement a Transformer that can impute missing data points, similar to e.g. Imputer in scikit-learn. Initially, options for imputation could include mean, median and most frequent, but we could add various other approaches, where possible existing DataFrame code can be used (e.g. for approximate quantiles etc). Currently this PR supports imputation for Double and Vector (null and NaN in Vector). ## How was this patch tested? new unit tests and manual test Author: Yuhao Yang <[email protected]> Author: Yuhao Yang <[email protected]> Author: Yuhao <[email protected]> Closes #11601 from hhbyyh/imputer.
1 parent 1472cac commit d647aae

File tree

2 files changed

+444
-0
lines changed

2 files changed

+444
-0
lines changed
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.hadoop.fs.Path
21+
22+
import org.apache.spark.SparkException
23+
import org.apache.spark.annotation.{Experimental, Since}
24+
import org.apache.spark.ml.{Estimator, Model}
25+
import org.apache.spark.ml.param._
26+
import org.apache.spark.ml.param.shared.HasInputCols
27+
import org.apache.spark.ml.util._
28+
import org.apache.spark.sql.{DataFrame, Dataset, Row}
29+
import org.apache.spark.sql.functions._
30+
import org.apache.spark.sql.types._
31+
32+
/**
33+
* Params for [[Imputer]] and [[ImputerModel]].
34+
*/
35+
private[feature] trait ImputerParams extends Params with HasInputCols {
36+
37+
/**
38+
* The imputation strategy.
39+
* If "mean", then replace missing values using the mean value of the feature.
40+
* If "median", then replace missing values using the approximate median value of the feature.
41+
* Default: mean
42+
*
43+
* @group param
44+
*/
45+
final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " +
46+
s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " +
47+
s"If ${Imputer.median}, then replace missing values using the median value of the feature.",
48+
ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median)))
49+
50+
/** @group getParam */
51+
def getStrategy: String = $(strategy)
52+
53+
/**
54+
* The placeholder for the missing values. All occurrences of missingValue will be imputed.
55+
* Note that null values are always treated as missing.
56+
* Default: Double.NaN
57+
*
58+
* @group param
59+
*/
60+
final val missingValue: DoubleParam = new DoubleParam(this, "missingValue",
61+
"The placeholder for the missing values. All occurrences of missingValue will be imputed")
62+
63+
/** @group getParam */
64+
def getMissingValue: Double = $(missingValue)
65+
66+
/**
67+
* Param for output column names.
68+
* @group param
69+
*/
70+
final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
71+
"output column names")
72+
73+
/** @group getParam */
74+
final def getOutputCols: Array[String] = $(outputCols)
75+
76+
/** Validates and transforms the input schema. */
77+
protected def validateAndTransformSchema(schema: StructType): StructType = {
78+
require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" +
79+
s" duplicates: (${$(inputCols).mkString(", ")})")
80+
require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" +
81+
s" duplicates: (${$(outputCols).mkString(", ")})")
82+
require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" +
83+
s" and outputCols(${$(outputCols).length}) should have the same length")
84+
val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
85+
val inputField = schema(inputCol)
86+
SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType))
87+
StructField(outputCol, inputField.dataType, inputField.nullable)
88+
}
89+
StructType(schema ++ outputFields)
90+
}
91+
}
92+
93+
/**
94+
* :: Experimental ::
95+
* Imputation estimator for completing missing values, either using the mean or the median
96+
* of the column in which the missing values are located. The input column should be of
97+
* DoubleType or FloatType. Currently Imputer does not support categorical features yet
98+
* (SPARK-15041) and possibly creates incorrect values for a categorical feature.
99+
*
100+
* Note that the mean/median value is computed after filtering out missing values.
101+
* All Null values in the input column are treated as missing, and so are also imputed. For
102+
* computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
103+
*/
104+
@Experimental
105+
class Imputer @Since("2.2.0")(override val uid: String)
106+
extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable {
107+
108+
@Since("2.2.0")
109+
def this() = this(Identifiable.randomUID("imputer"))
110+
111+
/** @group setParam */
112+
@Since("2.2.0")
113+
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
114+
115+
/** @group setParam */
116+
@Since("2.2.0")
117+
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
118+
119+
/**
120+
* Imputation strategy. Available options are ["mean", "median"].
121+
* @group setParam
122+
*/
123+
@Since("2.2.0")
124+
def setStrategy(value: String): this.type = set(strategy, value)
125+
126+
/** @group setParam */
127+
@Since("2.2.0")
128+
def setMissingValue(value: Double): this.type = set(missingValue, value)
129+
130+
setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN)
131+
132+
override def fit(dataset: Dataset[_]): ImputerModel = {
133+
transformSchema(dataset.schema, logging = true)
134+
val spark = dataset.sparkSession
135+
import spark.implicits._
136+
val surrogates = $(inputCols).map { inputCol =>
137+
val ic = col(inputCol)
138+
val filtered = dataset.select(ic.cast(DoubleType))
139+
.filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
140+
if(filtered.take(1).length == 0) {
141+
throw new SparkException(s"surrogate cannot be computed. " +
142+
s"All the values in $inputCol are Null, Nan or missingValue(${$(missingValue)})")
143+
}
144+
val surrogate = $(strategy) match {
145+
case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
146+
case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
147+
}
148+
surrogate
149+
}
150+
151+
val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
152+
val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false)))
153+
val surrogateDF = spark.createDataFrame(rows, schema)
154+
copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
155+
}
156+
157+
override def transformSchema(schema: StructType): StructType = {
158+
validateAndTransformSchema(schema)
159+
}
160+
161+
override def copy(extra: ParamMap): Imputer = defaultCopy(extra)
162+
}
163+
164+
@Since("2.2.0")
165+
object Imputer extends DefaultParamsReadable[Imputer] {
166+
167+
/** strategy names that Imputer currently supports. */
168+
private[ml] val mean = "mean"
169+
private[ml] val median = "median"
170+
171+
@Since("2.2.0")
172+
override def load(path: String): Imputer = super.load(path)
173+
}
174+
175+
/**
176+
* :: Experimental ::
177+
* Model fitted by [[Imputer]].
178+
*
179+
* @param surrogateDF a DataFrame contains inputCols and their corresponding surrogates, which are
180+
* used to replace the missing values in the input DataFrame.
181+
*/
182+
@Experimental
183+
class ImputerModel private[ml](
184+
override val uid: String,
185+
val surrogateDF: DataFrame)
186+
extends Model[ImputerModel] with ImputerParams with MLWritable {
187+
188+
import ImputerModel._
189+
190+
/** @group setParam */
191+
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
192+
193+
/** @group setParam */
194+
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
195+
196+
override def transform(dataset: Dataset[_]): DataFrame = {
197+
transformSchema(dataset.schema, logging = true)
198+
var outputDF = dataset
199+
val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
200+
201+
$(inputCols).zip($(outputCols)).zip(surrogates).foreach {
202+
case ((inputCol, outputCol), surrogate) =>
203+
val inputType = dataset.schema(inputCol).dataType
204+
val ic = col(inputCol)
205+
outputDF = outputDF.withColumn(outputCol,
206+
when(ic.isNull, surrogate)
207+
.when(ic === $(missingValue), surrogate)
208+
.otherwise(ic)
209+
.cast(inputType))
210+
}
211+
outputDF.toDF()
212+
}
213+
214+
override def transformSchema(schema: StructType): StructType = {
215+
validateAndTransformSchema(schema)
216+
}
217+
218+
override def copy(extra: ParamMap): ImputerModel = {
219+
val copied = new ImputerModel(uid, surrogateDF)
220+
copyValues(copied, extra).setParent(parent)
221+
}
222+
223+
@Since("2.2.0")
224+
override def write: MLWriter = new ImputerModelWriter(this)
225+
}
226+
227+
228+
@Since("2.2.0")
229+
object ImputerModel extends MLReadable[ImputerModel] {
230+
231+
private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {
232+
233+
override protected def saveImpl(path: String): Unit = {
234+
DefaultParamsWriter.saveMetadata(instance, path, sc)
235+
val dataPath = new Path(path, "data").toString
236+
instance.surrogateDF.repartition(1).write.parquet(dataPath)
237+
}
238+
}
239+
240+
private class ImputerReader extends MLReader[ImputerModel] {
241+
242+
private val className = classOf[ImputerModel].getName
243+
244+
override def load(path: String): ImputerModel = {
245+
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
246+
val dataPath = new Path(path, "data").toString
247+
val surrogateDF = sqlContext.read.parquet(dataPath)
248+
val model = new ImputerModel(metadata.uid, surrogateDF)
249+
DefaultParamsReader.getAndSetParams(model, metadata)
250+
model
251+
}
252+
}
253+
254+
@Since("2.2.0")
255+
override def read: MLReader[ImputerModel] = new ImputerReader
256+
257+
@Since("2.2.0")
258+
override def load(path: String): ImputerModel = super.load(path)
259+
}

0 commit comments

Comments
 (0)