Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
val keyExpr = df.col(col.name).expr
def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
val branches = replacementMap.flatMap { case (source, target) =>
Seq(buildExpr(source), buildExpr(target))
Seq(Literal(source), buildExpr(target))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix relies on the type coercion rule to do the casting in the another side. It could cause the difference of query results. For example,

 def createNaNDF(): DataFrame = {
   Seq[(java.lang.Integer, java.lang.Long, java.lang.Short,
     java.lang.Byte, java.lang.Float, java.lang.Double)](
     (2, 2L, 2.toShort, 2.toByte, 2.0f, 2.0)
   ).toDF("int", "long", "short", "byte", "float", "double")
 }

 test("replace float with double") {
   createNaNDF().na.replace("*", Map(
     2.3 -> 9.0
   )).show()

   createNaNDF().na.replace("*", Map(
     2.3 -> 9.0
   )).explain(true)
 }

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this PR,

+---+----+-----+----+-----+------+
|int|long|short|byte|float|double|
+---+----+-----+----+-----+------+
|  9|   9|    9|   9|  2.0|   2.0|
+---+----+-----+----+-----+------+

== Parsed Logical Plan ==
Project [CASE WHEN (int#99 = cast(2.3 as int)) THEN cast(9.0 as int) ELSE int#99 END AS int#117, CASE WHEN (long#100L = cast(2.3 as bigint)) THEN cast(9.0 as bigint) ELSE long#100L END AS long#118L, CASE WHEN (short#101 = cast(2.3 as smallint)) THEN cast(9.0 as smallint) ELSE short#101 END AS short#119, CASE WHEN (byte#102 = cast(2.3 as tinyint)) THEN cast(9.0 as tinyint) ELSE byte#102 END AS byte#120, CASE WHEN (float#103 = cast(2.3 as float)) THEN cast(9.0 as float) ELSE float#103 END AS float#121, CASE WHEN (double#104 = cast(2.3 as double)) THEN cast(9.0 as double) ELSE double#104 END AS double#122]
+- Project [_1#86 AS int#99, _2#87L AS long#100L, _3#88 AS short#101, _4#89 AS byte#102, _5#90 AS float#103, _6#91 AS double#104]
   +- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Analyzed Logical Plan ==
int: int, long: bigint, short: smallint, byte: tinyint, float: float, double: double
Project [CASE WHEN (int#99 = cast(2.3 as int)) THEN cast(9.0 as int) ELSE int#99 END AS int#117, CASE WHEN (long#100L = cast(2.3 as bigint)) THEN cast(9.0 as bigint) ELSE long#100L END AS long#118L, CASE WHEN (short#101 = cast(2.3 as smallint)) THEN cast(9.0 as smallint) ELSE short#101 END AS short#119, CASE WHEN (byte#102 = cast(2.3 as tinyint)) THEN cast(9.0 as tinyint) ELSE byte#102 END AS byte#120, CASE WHEN (float#103 = cast(2.3 as float)) THEN cast(9.0 as float) ELSE float#103 END AS float#121, CASE WHEN (double#104 = cast(2.3 as double)) THEN cast(9.0 as double) ELSE double#104 END AS double#122]
+- Project [_1#86 AS int#99, _2#87L AS long#100L, _3#88 AS short#101, _4#89 AS byte#102, _5#90 AS float#103, _6#91 AS double#104]
   +- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Optimized Logical Plan ==
Project [CASE WHEN (_1#86 = 2) THEN 9 ELSE _1#86 END AS int#117, CASE WHEN (_2#87L = 2) THEN 9 ELSE _2#87L END AS long#118L, CASE WHEN (_3#88 = 2) THEN 9 ELSE _3#88 END AS short#119, CASE WHEN (_4#89 = 2) THEN 9 ELSE _4#89 END AS byte#120, CASE WHEN (_5#90 = 2.3) THEN 9.0 ELSE _5#90 END AS float#121, CASE WHEN (_6#91 = 2.3) THEN 9.0 ELSE _6#91 END AS double#122]
+- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Physical Plan ==
*(1) Project [CASE WHEN (_1#86 = 2) THEN 9 ELSE _1#86 END AS int#117, CASE WHEN (_2#87L = 2) THEN 9 ELSE _2#87L END AS long#118L, CASE WHEN (_3#88 = 2) THEN 9 ELSE _3#88 END AS short#119, CASE WHEN (_4#89 = 2) THEN 9 ELSE _4#89 END AS byte#120, CASE WHEN (_5#90 = 2.3) THEN 9.0 ELSE _5#90 END AS float#121, CASE WHEN (_6#91 = 2.3) THEN 9.0 ELSE _6#91 END AS double#122]
+- *(1) LocalTableScan [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

After this PR,

+---+----+-----+----+-----+------+
|int|long|short|byte|float|double|
+---+----+-----+----+-----+------+
|  2|   2|    2|   2|  2.0|   2.0|
+---+----+-----+----+-----+------+

== Parsed Logical Plan ==
'Project [CASE WHEN (int#99 = 2.3) THEN cast(9.0 as int) ELSE int#99 END AS int#117, CASE WHEN (long#100L = 2.3) THEN cast(9.0 as bigint) ELSE long#100L END AS long#118, CASE WHEN (short#101 = 2.3) THEN cast(9.0 as smallint) ELSE short#101 END AS short#119, CASE WHEN (byte#102 = 2.3) THEN cast(9.0 as tinyint) ELSE byte#102 END AS byte#120, CASE WHEN (float#103 = 2.3) THEN cast(9.0 as float) ELSE float#103 END AS float#121, CASE WHEN (double#104 = 2.3) THEN cast(9.0 as double) ELSE double#104 END AS double#122]
+- Project [_1#86 AS int#99, _2#87L AS long#100L, _3#88 AS short#101, _4#89 AS byte#102, _5#90 AS float#103, _6#91 AS double#104]
   +- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Analyzed Logical Plan ==
int: int, long: bigint, short: smallint, byte: tinyint, float: float, double: double
Project [CASE WHEN (cast(int#99 as double) = 2.3) THEN cast(9.0 as int) ELSE int#99 END AS int#117, CASE WHEN (cast(long#100L as double) = 2.3) THEN cast(9.0 as bigint) ELSE long#100L END AS long#118L, CASE WHEN (cast(short#101 as double) = 2.3) THEN cast(9.0 as smallint) ELSE short#101 END AS short#119, CASE WHEN (cast(byte#102 as double) = 2.3) THEN cast(9.0 as tinyint) ELSE byte#102 END AS byte#120, CASE WHEN (cast(float#103 as double) = 2.3) THEN cast(9.0 as float) ELSE float#103 END AS float#121, CASE WHEN (double#104 = 2.3) THEN cast(9.0 as double) ELSE double#104 END AS double#122]
+- Project [_1#86 AS int#99, _2#87L AS long#100L, _3#88 AS short#101, _4#89 AS byte#102, _5#90 AS float#103, _6#91 AS double#104]
   +- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Optimized Logical Plan ==
Project [CASE WHEN (cast(_1#86 as double) = 2.3) THEN 9 ELSE _1#86 END AS int#117, CASE WHEN (cast(_2#87L as double) = 2.3) THEN 9 ELSE _2#87L END AS long#118L, CASE WHEN (cast(_3#88 as double) = 2.3) THEN 9 ELSE _3#88 END AS short#119, CASE WHEN (cast(_4#89 as double) = 2.3) THEN 9 ELSE _4#89 END AS byte#120, CASE WHEN (cast(_5#90 as double) = 2.3) THEN 9.0 ELSE _5#90 END AS float#121, CASE WHEN (_6#91 = 2.3) THEN 9.0 ELSE _6#91 END AS double#122]
+- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Physical Plan ==
*(1) Project [CASE WHEN (cast(_1#86 as double) = 2.3) THEN 9 ELSE _1#86 END AS int#117, CASE WHEN (cast(_2#87L as double) = 2.3) THEN 9 ELSE _2#87L END AS long#118L, CASE WHEN (cast(_3#88 as double) = 2.3) THEN 9 ELSE _3#88 END AS short#119, CASE WHEN (cast(_4#89 as double) = 2.3) THEN 9 ELSE _4#89 END AS byte#120, CASE WHEN (cast(_5#90 as double) = 2.3) THEN 9.0 ELSE _5#90 END AS float#121, CASE WHEN (_6#91 = 2.3) THEN 9.0 ELSE _6#91 END AS double#122]
+- *(1) LocalTableScan [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new behavior makes more sense, but I agree that the PR description needs update to reflect all the changes. cc @johnhany97

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch there @gatorsmile. I've updated the PR description. Should I also update the PR title? Let me know if you'd like me to add in more details into the PR description.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we also need to update the PR title.

}.toSeq
new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
).toDF("name", "age", "height")
}

def createNaNDF(): DataFrame = {
Seq[(java.lang.Integer, java.lang.Long, java.lang.Short,
java.lang.Byte, java.lang.Float, java.lang.Double)](
(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0),
(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN)
).toDF("int", "long", "short", "byte", "float", "double")
}

test("drop") {
val input = createDF()
val rows = input.collect()
Expand Down Expand Up @@ -404,4 +412,40 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
df.na.drop("any"),
Row("5", "6", "6") :: Nil)
}

test("replace nan with float") {
checkAnswer(
createNaNDF().na.replace("*", Map(
Float.NaN -> 10.0f
)),
Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) ::
Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil)
}

test("replace nan with double") {
checkAnswer(
createNaNDF().na.replace("*", Map(
Double.NaN -> 10.0
)),
Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) ::
Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil)
}

test("replace float with nan") {
checkAnswer(
createNaNDF().na.replace("*", Map(
1.0f -> Float.NaN
)),
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
}

test("replace double with nan") {
checkAnswer(
createNaNDF().na.replace("*", Map(
1.0 -> Double.NaN
)),
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
}
}