Skip to content

Commit e9ab39c

Browse files
author
Wayne Zhang
committed
allow imputer to handle numeric types
1 parent ba76662 commit e9ab39c

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ private[feature] trait ImputerParams extends Params with HasInputCols {
8383
s" and outputCols(${$(outputCols).length}) should have the same length")
8484
val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
8585
val inputField = schema(inputCol)
86-
SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType))
86+
SchemaUtils.checkNumericType(schema, inputCol)
8787
StructField(outputCol, inputField.dataType, inputField.nullable)
8888
}
8989
StructType(schema ++ outputFields)

mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
2121
import org.apache.spark.mllib.util.MLlibTestSparkContext
2222
import org.apache.spark.mllib.util.TestingUtils._
2323
import org.apache.spark.sql.{DataFrame, Row}
24+
import org.apache.spark.sql.functions._
25+
import org.apache.spark.sql.types._
2426

2527
class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
2628

@@ -155,6 +157,27 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default
155157
assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect())
156158
}
157159

160+
test("Imputer for Numeric with default missing Value NaN") {
161+
val df = spark.createDataFrame( Seq(
162+
(0, 1.0, 1.0, 1.0),
163+
(1, 11.0, 11.0, 11.0),
164+
(2, 1.5, 1.5, 1.5),
165+
(3, Double.NaN, 4.5, 1.5)
166+
)).toDF("id", "value1", "expected_mean_value1", "expected_median_value1")
167+
val imputer = new Imputer()
168+
.setInputCols(Array("value1"))
169+
.setOutputCols(Array("out1"))
170+
171+
val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType,
172+
ByteType, DecimalType(10, 0))
173+
for (mType <- types) {
174+
val df2 = df.withColumn("value1", col("value1").cast(mType))
175+
.withColumn("value1", when(col("value1").equalTo(0), null).otherwise(col("value1")))
176+
.withColumn("expected_mean_value1", col("expected_mean_value1").cast(mType))
177+
.withColumn("expected_median_value1", col("expected_median_value1").cast(mType))
178+
ImputerSuite.iterateStrategyTest(imputer, df2)
179+
}
180+
}
158181
}
159182

160183
object ImputerSuite {
@@ -178,6 +201,9 @@ object ImputerSuite {
178201
case Row(exp: Double, out: Double) =>
179202
assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
180203
s"Imputed values differ. Expected: $exp, actual: $out")
204+
case Row(exp, out) =>
205+
assert(exp == out,
206+
s"Imputed values differ. Expected: $exp, actual: $out")
181207
}
182208
}
183209
}

0 commit comments

Comments
 (0)