Commit 6d7a675

actuaryzhang and Wayne Zhang authored and committed
[SPARK-20604][ML] Allow imputer to handle numeric types
## What changes were proposed in this pull request?

Imputer currently requires the input column to be Double or Float, but the logic should work on any numeric data type. Many practical problems have integer data types, and it can get very tedious to manually cast them to Double before calling the imputer. This transformer could be extended to handle all numeric types.

## How was this patch tested?

new test

Closes #17864 from actuaryzhang/imputer.

Lead-authored-by: actuaryzhang <[email protected]>
Co-authored-by: Wayne Zhang <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
1 parent 660423d commit 6d7a675
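
As a quick illustration of what the change enables (not part of the commit; the column names and values are made up, and a running SparkSession named `spark` is assumed), an Imputer can now be fit directly on an IntegerType column without a manual cast to Double:

import org.apache.spark.ml.feature.Imputer

// Integer column with one missing (null) entry; the second column just documents
// the expected result for the reader.
val df = spark.createDataFrame(Seq[(Integer, Integer)](
  (1, 1),
  (2, 2),
  (4, 4),
  (null, 2)   // mean of (1, 2, 4) is 2.33..., truncated to 2 because the column is IntegerType
)).toDF("age", "expected_age")

// Before this patch, the "age" column had to be cast to DoubleType first.
val model = new Imputer()
  .setInputCols(Array("age"))
  .setOutputCols(Array("age_imputed"))
  .setStrategy("mean")
  .fit(df)

model.transform(df).show()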

2 files changed: +64, -3 lines

mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala

Lines changed: 7 additions & 3 deletions
@@ -73,7 +73,7 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu
       s" and outputCols(${$(outputCols).length}) should have the same length")
     val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
       val inputField = schema(inputCol)
-      SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType))
+      SchemaUtils.checkNumericType(schema, inputCol)
       StructField(outputCol, inputField.dataType, inputField.nullable)
     }
     StructType(schema ++ outputFields)
@@ -84,9 +84,13 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu
  * :: Experimental ::
  * Imputation estimator for completing missing values, either using the mean or the median
  * of the columns in which the missing values are located. The input columns should be of
- * DoubleType or FloatType. Currently Imputer does not support categorical features
+ * numeric type. Currently Imputer does not support categorical features
  * (SPARK-15041) and possibly creates incorrect values for a categorical feature.
  *
+ * Note when an input column is integer, the imputed value is casted (truncated) to an integer type.
+ * For example, if the input column is IntegerType (1, 2, 4, null),
+ * the output will be IntegerType (1, 2, 4, 2) after mean imputation.
+ *
  * Note that the mean/median value is computed after filtering out missing values.
  * All Null values in the input columns are treated as missing, and so are also imputed. For
  * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
@@ -218,7 +222,7 @@ class ImputerModel private[ml] (
     val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map {
       case ((inputCol, outputCol), surrogate) =>
         val inputType = dataset.schema(inputCol).dataType
-        val ic = col(inputCol)
+        val ic = col(inputCol).cast(DoubleType)
         when(ic.isNull, surrogate)
           .when(ic === $(missingValue), surrogate)
           .otherwise(ic)
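
For readers skimming the transform change above, here is a hedged standalone sketch (not the committed code; it assumes `spark.implicits._` is in scope and uses illustrative values) of the column-expression pattern it relies on: the input column is cast to DoubleType so that nulls and the Double-valued missingValue param can be replaced by the surrogate uniformly, whatever the original numeric type was.

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import spark.implicits._

val surrogate = 5.0        // e.g. the mean computed over the non-missing values
val missingValue = -1.0    // the Imputer's missingValue param is a Double

val df = Seq(1, 11, 3, -1).toDF("value")   // IntegerType column; -1 marks "missing"

// Cast to Double, replace missing entries with the surrogate, keep the rest.
val ic = col("value").cast(DoubleType)
val imputed = when(ic.isNull, surrogate)
  .when(ic === missingValue, surrogate)
  .otherwise(ic)

df.withColumn("value_imputed", imputed).show()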

mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala

Lines changed: 57 additions & 0 deletions
@@ -20,6 +20,8 @@ import org.apache.spark.SparkException
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._

 class ImputerSuite extends MLTest with DefaultReadWriteTest {

@@ -176,6 +178,48 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect())
   }

+  test("Imputer for IntegerType with default missing value null") {
+
+    val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
+      (1, 1, 1),
+      (11, 11, 11),
+      (3, 3, 3),
+      (null, 5, 3)
+    )).toDF("value1", "expected_mean_value1", "expected_median_value1")
+
+    val imputer = new Imputer()
+      .setInputCols(Array("value1"))
+      .setOutputCols(Array("out1"))
+
+    val types = Seq(IntegerType, LongType)
+    for (mType <- types) {
+      // cast all columns to desired data type for testing
+      val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*)
+      ImputerSuite.iterateStrategyTest(imputer, df2)
+    }
+  }
+
+  test("Imputer for IntegerType with missing value -1") {
+
+    val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
+      (1, 1, 1),
+      (11, 11, 11),
+      (3, 3, 3),
+      (-1, 5, 3)
+    )).toDF("value1", "expected_mean_value1", "expected_median_value1")
+
+    val imputer = new Imputer()
+      .setInputCols(Array("value1"))
+      .setOutputCols(Array("out1"))
+      .setMissingValue(-1.0)
+
+    val types = Seq(IntegerType, LongType)
+    for (mType <- types) {
+      // cast all columns to desired data type for testing
+      val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*)
+      ImputerSuite.iterateStrategyTest(imputer, df2)
+    }
+  }
 }

 object ImputerSuite {
@@ -190,13 +234,26 @@ object ImputerSuite {
     val model = imputer.fit(df)
     val resultDF = model.transform(df)
     imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) =>
+
+      // check dataType is consistent between input and output
+      val inputType = resultDF.schema(inputCol).dataType
+      val outputType = resultDF.schema(outputCol).dataType
+      assert(inputType == outputType, "Output type is not the same as input type.")
+
+      // check value
       resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach {
         case Row(exp: Float, out: Float) =>
           assert((exp.isNaN && out.isNaN) || (exp == out),
             s"Imputed values differ. Expected: $exp, actual: $out")
         case Row(exp: Double, out: Double) =>
           assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
             s"Imputed values differ. Expected: $exp, actual: $out")
+        case Row(exp: Integer, out: Integer) =>
+          assert(exp == out,
+            s"Imputed values differ. Expected: $exp, actual: $out")
+        case Row(exp: Long, out: Long) =>
+          assert(exp == out,
+            s"Imputed values differ. Expected: $exp, actual: $out")
       }
     }
   }
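
One small idiom worth calling out in the new tests is how the whole fixture is retargeted at each numeric type before being handed to ImputerSuite.iterateStrategyTest. A standalone sketch of the same pattern (illustrative names, assuming `spark.implicits._` is in scope):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DataType, LongType}
import spark.implicits._

// Cast every column of a DataFrame to the given type in one select.
def castAllColumns(df: DataFrame, dt: DataType): DataFrame =
  df.select(df.columns.map(c => col(c).cast(dt)): _*)

val base = Seq((1, 1, 1), (11, 11, 11), (3, 3, 3))
  .toDF("value1", "expected_mean_value1", "expected_median_value1")

val asLong = castAllColumns(base, LongType)   // same data, every column now LongType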
