@@ -53,11 +53,11 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
5353
5454 test(" Imputer for Float with missing Value -1.0" ) {
5555 val df = spark.createDataFrame( Seq (
56- (0 , 1.0F , 1.0 , 1.0 ),
57- (1 , 3.0F , 3.0 , 3.0 ),
58- (2 , 10.0F , 10.0 , 10.0 ),
59- (3 , 10.0F , 10.0 , 10.0 ),
60- (4 , - 1.0F , 6.0 , 3.0 )
56+ (0 , 1.0F , 1.0F , 1.0F ),
57+ (1 , 3.0F , 3.0F , 3.0F ),
58+ (2 , 10.0F , 10.0F , 10.0F ),
59+ (3 , 10.0F , 10.0F , 10.0F ),
60+ (4 , - 1.0F , 6.0F , 3.0F )
6161 )).toDF(" id" , " value" , " expected_mean_value" , " expected_median_value" )
6262 val imputer = new Imputer ().setInputCols(Array (" value" )).setOutputCols(Array (" out" ))
6363 .setMissingValue(- 1 )
@@ -238,6 +238,9 @@ object ImputerSuite {
238238 val resultDF = model.transform(df)
239239 imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) =>
240240 resultDF.select(s " expected_ ${strategy}_ $inputCol" , outputCol).collect().foreach {
241+ case Row (exp : Float , out : Float ) =>
242+ assert((exp.isNaN && out.isNaN) || (exp == out),
243+ s " Imputed values differ. Expected: $exp, actual: $out" )
241244 case Row (exp : Double , out : Double ) =>
242245 assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5 ),
243246 s " Imputed values differ. Expected: $exp, actual: $out" )
0 commit comments