@@ -21,6 +21,8 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
2121import org .apache .spark .mllib .util .MLlibTestSparkContext
2222import org .apache .spark .mllib .util .TestingUtils ._
2323import org .apache .spark .sql .{DataFrame , Row }
24+ import org .apache .spark .sql .functions ._
25+ import org .apache .spark .sql .types ._
2426
2527class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
2628
@@ -155,6 +157,27 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default
155157 assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect())
156158 }
157159
160+ test(" Imputer for Numeric with default missing Value NaN" ) {
161+ val df = spark.createDataFrame( Seq (
162+ (0 , 1.0 , 1.0 , 1.0 ),
163+ (1 , 11.0 , 11.0 , 11.0 ),
164+ (2 , 1.5 , 1.5 , 1.5 ),
165+ (3 , Double .NaN , 4.5 , 1.5 )
166+ )).toDF(" id" , " value1" , " expected_mean_value1" , " expected_median_value1" )
167+ val imputer = new Imputer ()
168+ .setInputCols(Array (" value1" ))
169+ .setOutputCols(Array (" out1" ))
170+
171+ val types = Seq (ShortType , IntegerType , LongType , FloatType , DoubleType ,
172+ ByteType , DecimalType (10 , 0 ))
173+ for (mType <- types) {
174+ val df2 = df.withColumn(" value1" , col(" value1" ).cast(mType))
175+ .withColumn(" value1" , when(col(" value1" ).equalTo(0 ), null ).otherwise(col(" value1" )))
176+ .withColumn(" expected_mean_value1" , col(" expected_mean_value1" ).cast(mType))
177+ .withColumn(" expected_median_value1" , col(" expected_median_value1" ).cast(mType))
178+ ImputerSuite .iterateStrategyTest(imputer, df2)
179+ }
180+ }
158181}
159182
160183object ImputerSuite {
@@ -178,6 +201,9 @@ object ImputerSuite {
178201 case Row (exp : Double , out : Double ) =>
179202 assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5 ),
180203 s " Imputed values differ. Expected: $exp, actual: $out" )
204+ case Row (exp, out) =>
205+ assert(exp == out,
206+ s " Imputed values differ. Expected: $exp, actual: $out" )
181207 }
182208 }
183209 }
0 commit comments