
Conversation

@johnhany97 (Contributor) commented Dec 2, 2019

### What changes were proposed in this pull request?

Rely on type coercion when building the replace query. This solves an edge case where, when trying to replace `NaN`s, `0`s would get replaced too.
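In the old code (quoted in the review below), both sides of each generated CASE WHEN branch were cast to the column's type, so for an integer column the source side became `CAST(NaN AS INT)`, i.e. `0`; the fix keeps the source as an uncast literal and lets analysis-time type coercion widen the column side instead. A quick, hedged way to see the root cause from `spark-shell` (assumes an active `spark` session):

```
// Casting NaN to an integral type yields 0, which is why the old
// replace query silently matched 0s when asked to replace NaN.
spark.sql("SELECT CAST(CAST('NaN' AS DOUBLE) AS INT) AS nan_as_int").show()
// +----------+
// |nan_as_int|
// +----------+
// |         0|
// +----------+
```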

### Why are the changes needed?

This Scala code snippet:

```
println(Double.NaN.toLong)
```

returns `0`, which is problematic: if you run the following PySpark snippet, `0`s get replaced as well:

```
>>> df = spark.createDataFrame([(1.0, 0), (0.0, 3), (float('nan'), 0)], ("index", "value"))
>>> df.show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  NaN|    0|
+-----+-----+
>>> df.replace(float('nan'), 2).show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    2|
|  0.0|    3|
|  2.0|    2|
+-----+-----+
```
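The same zero-collapse happens for every integral width, not just `Long`; for illustration, in plain Scala:

```
// JVM narrowing conversions map NaN to 0 for all integral types.
println(Double.NaN.toInt)   // 0
println(Double.NaN.toShort) // 0
println(Double.NaN.toByte)  // 0
println(Float.NaN.toLong)   // 0
```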

### Does this PR introduce any user-facing change?

Yes. After this PR, running the same snippet returns the expected results:

```
>>> df = spark.createDataFrame([(1.0, 0), (0.0, 3), (float('nan'), 0)], ("index", "value"))
>>> df.show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  NaN|    0|
+-----+-----+

>>> df.replace(float('nan'), 2).show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  2.0|    0|
+-----+-----+
```

Additionally, other query results can change, because the replace query now relies on the analyzer's type coercion rules (the review discussion below shows an example).

### How was this patch tested?

Added unit tests to verify replacing NaN only affects columns of type Float and Double.
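As a rough sketch of such a test (hedged: a standalone approximation using the public DataFrame API, not the exact suite code added in this PR):

```
import org.apache.spark.sql.SparkSession

object ReplaceNaNSmokeTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]").appName("replace-nan").getOrCreate()
    import spark.implicits._

    val df = Seq((1.0, 0), (0.0, 3), (Double.NaN, 0)).toDF("index", "value")
    val replaced = df.na.replace("*", Map(Double.NaN -> 2.0))

    // Only the Double column should change; the Int column's 0s stay intact.
    assert(replaced.collect().map(_.getInt(1)).toSeq == Seq(0, 3, 0))
    assert(replaced.collect().map(_.getDouble(0)).toSeq == Seq(1.0, 0.0, 2.0))
    spark.stop()
  }
}
```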

@johnhany97 johnhany97 changed the title [SPARK-30082] Do not replace Zeros when replacing NaNs [WIP][SPARK-30082] Do not replace Zeros when replacing NaNs Dec 2, 2019
@mccheah (Contributor) commented Dec 2, 2019

Ok to test

@johnhany97 johnhany97 changed the title [WIP][SPARK-30082] Do not replace Zeros when replacing NaNs [SPARK-30082] Do not replace Zeros when replacing NaNs Dec 2, 2019
@mccheah (Contributor) commented Dec 2, 2019

@dongjoon-hyun can you take a look?

@dongjoon-hyun dongjoon-hyun added SQL and removed PYSPARK labels Dec 3, 2019
@dongjoon-hyun dongjoon-hyun changed the title [SPARK-30082] Do not replace Zeros when replacing NaNs [SPARK-30082][SQL] Do not replace Zeros when replacing NaNs Dec 3, 2019
@dongjoon-hyun (Member) commented:

Thank you for pinging me, @mccheah . Sure.

@dongjoon-hyun (Member) commented:

ok to test

@SparkQA commented Dec 3, 2019

Test build #114753 has finished for PR 26738 at commit ed6f08d.

  • This patch fails due to an unknown error code, -9.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Dec 3, 2019

Test build #114769 has finished for PR 26738 at commit 279a9fd.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Dec 3, 2019

Test build #114770 has finished for PR 26738 at commit 6b5d26d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Dec 3, 2019

Test build #114771 has finished for PR 26738 at commit 10c91d6.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Dec 3, 2019

Test build #114772 has finished for PR 26738 at commit 1744b28.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@cloud-fan (Contributor) commented:

thanks, merging to master!

@cloud-fan cloud-fan closed this in 8c2849a Dec 3, 2019
@cloud-fan (Contributor) commented:

@johnhany97 do you mind opening another PR for 2.4? thanks!

@johnhany97 (Contributor, Author) commented:

@cloud-fan sure thing

johnhany97 added a commit to johnhany97/spark that referenced this pull request Dec 3, 2019
Do not cast `NaN` to an `Integer`, `Long`, `Short` or `Byte`, because casting `NaN` to those types results in a `0` that erroneously replaces real `0`s when only `NaN`s should be replaced.
Closes apache#26738 from johnhany97/SPARK-30082.

Lead-authored-by: John Ayad <[email protected]>
Co-authored-by: John Ayad <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
@dongjoon-hyun (Member) commented:

Thank you, @johnhany97 , @mccheah , @cloud-fan .

@SparkQA commented Dec 3, 2019

Test build #114776 has finished for PR 26738 at commit 1295633.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Dec 3, 2019

Test build #114781 has finished for PR 26738 at commit b3709a1.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

attilapiros pushed a commit to attilapiros/spark that referenced this pull request Dec 6, 2019
@gatorsmile (Member) left a comment:

This change's impact is much bigger than what the PR description documents.

```
 def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
 val branches = replacementMap.flatMap { case (source, target) =>
-  Seq(buildExpr(source), buildExpr(target))
+  Seq(Literal(source), buildExpr(target))
```
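For context, the surrounding per-column builder in `DataFrameNaFunctions` has roughly this shape (a paraphrase reconstructed from the snippet above; the `CaseKeyWhen` fallback wiring is an assumption about the unquoted surrounding code):

```
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{CaseKeyWhen, Cast, Literal}
import org.apache.spark.sql.types.StructField

// NOTE: paraphrase of Spark internals, not verbatim source. Builds
// WHEN source THEN target pairs for one column and falls back to the
// original column value when no replacement source matches.
def replaceCol(df: DataFrame, col: StructField, replacementMap: Map[_, _]): Column = {
  val keyExpr = df.col(col.name).expr
  def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
  val branches = replacementMap.flatMap { case (source, target) =>
    Seq(Literal(source), buildExpr(target))
  }.toSeq
  new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
}
```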
Review comment (Member):

This fix relies on the type coercion rule to do the casting on the other side, which can change query results. For example:

```
def createNaNDF(): DataFrame = {
  Seq[(java.lang.Integer, java.lang.Long, java.lang.Short,
    java.lang.Byte, java.lang.Float, java.lang.Double)](
    (2, 2L, 2.toShort, 2.toByte, 2.0f, 2.0)
  ).toDF("int", "long", "short", "byte", "float", "double")
}

test("replace float with double") {
  createNaNDF().na.replace("*", Map(
    2.3 -> 9.0
  )).show()

  createNaNDF().na.replace("*", Map(
    2.3 -> 9.0
  )).explain(true)
}
```

Review comment (Member):

Before this PR:

```
+---+----+-----+----+-----+------+
|int|long|short|byte|float|double|
+---+----+-----+----+-----+------+
|  9|   9|    9|   9|  2.0|   2.0|
+---+----+-----+----+-----+------+

== Parsed Logical Plan ==
Project [CASE WHEN (int#99 = cast(2.3 as int)) THEN cast(9.0 as int) ELSE int#99 END AS int#117, CASE WHEN (long#100L = cast(2.3 as bigint)) THEN cast(9.0 as bigint) ELSE long#100L END AS long#118L, CASE WHEN (short#101 = cast(2.3 as smallint)) THEN cast(9.0 as smallint) ELSE short#101 END AS short#119, CASE WHEN (byte#102 = cast(2.3 as tinyint)) THEN cast(9.0 as tinyint) ELSE byte#102 END AS byte#120, CASE WHEN (float#103 = cast(2.3 as float)) THEN cast(9.0 as float) ELSE float#103 END AS float#121, CASE WHEN (double#104 = cast(2.3 as double)) THEN cast(9.0 as double) ELSE double#104 END AS double#122]
+- Project [_1#86 AS int#99, _2#87L AS long#100L, _3#88 AS short#101, _4#89 AS byte#102, _5#90 AS float#103, _6#91 AS double#104]
   +- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Analyzed Logical Plan ==
int: int, long: bigint, short: smallint, byte: tinyint, float: float, double: double
Project [CASE WHEN (int#99 = cast(2.3 as int)) THEN cast(9.0 as int) ELSE int#99 END AS int#117, CASE WHEN (long#100L = cast(2.3 as bigint)) THEN cast(9.0 as bigint) ELSE long#100L END AS long#118L, CASE WHEN (short#101 = cast(2.3 as smallint)) THEN cast(9.0 as smallint) ELSE short#101 END AS short#119, CASE WHEN (byte#102 = cast(2.3 as tinyint)) THEN cast(9.0 as tinyint) ELSE byte#102 END AS byte#120, CASE WHEN (float#103 = cast(2.3 as float)) THEN cast(9.0 as float) ELSE float#103 END AS float#121, CASE WHEN (double#104 = cast(2.3 as double)) THEN cast(9.0 as double) ELSE double#104 END AS double#122]
+- Project [_1#86 AS int#99, _2#87L AS long#100L, _3#88 AS short#101, _4#89 AS byte#102, _5#90 AS float#103, _6#91 AS double#104]
   +- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Optimized Logical Plan ==
Project [CASE WHEN (_1#86 = 2) THEN 9 ELSE _1#86 END AS int#117, CASE WHEN (_2#87L = 2) THEN 9 ELSE _2#87L END AS long#118L, CASE WHEN (_3#88 = 2) THEN 9 ELSE _3#88 END AS short#119, CASE WHEN (_4#89 = 2) THEN 9 ELSE _4#89 END AS byte#120, CASE WHEN (_5#90 = 2.3) THEN 9.0 ELSE _5#90 END AS float#121, CASE WHEN (_6#91 = 2.3) THEN 9.0 ELSE _6#91 END AS double#122]
+- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Physical Plan ==
*(1) Project [CASE WHEN (_1#86 = 2) THEN 9 ELSE _1#86 END AS int#117, CASE WHEN (_2#87L = 2) THEN 9 ELSE _2#87L END AS long#118L, CASE WHEN (_3#88 = 2) THEN 9 ELSE _3#88 END AS short#119, CASE WHEN (_4#89 = 2) THEN 9 ELSE _4#89 END AS byte#120, CASE WHEN (_5#90 = 2.3) THEN 9.0 ELSE _5#90 END AS float#121, CASE WHEN (_6#91 = 2.3) THEN 9.0 ELSE _6#91 END AS double#122]
+- *(1) LocalTableScan [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]
```

After this PR:

```
+---+----+-----+----+-----+------+
|int|long|short|byte|float|double|
+---+----+-----+----+-----+------+
|  2|   2|    2|   2|  2.0|   2.0|
+---+----+-----+----+-----+------+

== Parsed Logical Plan ==
'Project [CASE WHEN (int#99 = 2.3) THEN cast(9.0 as int) ELSE int#99 END AS int#117, CASE WHEN (long#100L = 2.3) THEN cast(9.0 as bigint) ELSE long#100L END AS long#118, CASE WHEN (short#101 = 2.3) THEN cast(9.0 as smallint) ELSE short#101 END AS short#119, CASE WHEN (byte#102 = 2.3) THEN cast(9.0 as tinyint) ELSE byte#102 END AS byte#120, CASE WHEN (float#103 = 2.3) THEN cast(9.0 as float) ELSE float#103 END AS float#121, CASE WHEN (double#104 = 2.3) THEN cast(9.0 as double) ELSE double#104 END AS double#122]
+- Project [_1#86 AS int#99, _2#87L AS long#100L, _3#88 AS short#101, _4#89 AS byte#102, _5#90 AS float#103, _6#91 AS double#104]
   +- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Analyzed Logical Plan ==
int: int, long: bigint, short: smallint, byte: tinyint, float: float, double: double
Project [CASE WHEN (cast(int#99 as double) = 2.3) THEN cast(9.0 as int) ELSE int#99 END AS int#117, CASE WHEN (cast(long#100L as double) = 2.3) THEN cast(9.0 as bigint) ELSE long#100L END AS long#118L, CASE WHEN (cast(short#101 as double) = 2.3) THEN cast(9.0 as smallint) ELSE short#101 END AS short#119, CASE WHEN (cast(byte#102 as double) = 2.3) THEN cast(9.0 as tinyint) ELSE byte#102 END AS byte#120, CASE WHEN (cast(float#103 as double) = 2.3) THEN cast(9.0 as float) ELSE float#103 END AS float#121, CASE WHEN (double#104 = 2.3) THEN cast(9.0 as double) ELSE double#104 END AS double#122]
+- Project [_1#86 AS int#99, _2#87L AS long#100L, _3#88 AS short#101, _4#89 AS byte#102, _5#90 AS float#103, _6#91 AS double#104]
   +- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Optimized Logical Plan ==
Project [CASE WHEN (cast(_1#86 as double) = 2.3) THEN 9 ELSE _1#86 END AS int#117, CASE WHEN (cast(_2#87L as double) = 2.3) THEN 9 ELSE _2#87L END AS long#118L, CASE WHEN (cast(_3#88 as double) = 2.3) THEN 9 ELSE _3#88 END AS short#119, CASE WHEN (cast(_4#89 as double) = 2.3) THEN 9 ELSE _4#89 END AS byte#120, CASE WHEN (cast(_5#90 as double) = 2.3) THEN 9.0 ELSE _5#90 END AS float#121, CASE WHEN (_6#91 = 2.3) THEN 9.0 ELSE _6#91 END AS double#122]
+- LocalRelation [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]

== Physical Plan ==
*(1) Project [CASE WHEN (cast(_1#86 as double) = 2.3) THEN 9 ELSE _1#86 END AS int#117, CASE WHEN (cast(_2#87L as double) = 2.3) THEN 9 ELSE _2#87L END AS long#118L, CASE WHEN (cast(_3#88 as double) = 2.3) THEN 9 ELSE _3#88 END AS short#119, CASE WHEN (cast(_4#89 as double) = 2.3) THEN 9 ELSE _4#89 END AS byte#120, CASE WHEN (cast(_5#90 as double) = 2.3) THEN 9.0 ELSE _5#90 END AS float#121, CASE WHEN (_6#91 = 2.3) THEN 9.0 ELSE _6#91 END AS double#122]
+- *(1) LocalTableScan [_1#86, _2#87L, _3#88, _4#89, _5#90, _6#91]
```

Review comment (Contributor):

The new behavior makes more sense, but I agree the PR description needs an update to reflect all the changes. cc @johnhany97

@johnhany97 (Contributor, Author) commented:

Nice catch there @gatorsmile. I've updated the PR description. Should I also update the PR title? Let me know if you'd like me to add more details to the PR description.

Review comment (Member):

Yes, we also need to update the PR title.

@johnhany97 johnhany97 changed the title [SPARK-30082][SQL] Do not replace Zeros when replacing NaNs [SPARK-30082][SQL] Depend on Scala type coercion when building replace query Jan 4, 2020
bulldozer-bot bot pushed a commit to palantir/spark that referenced this pull request Jan 15, 2020
…e query (#628)

apache#26738
apache#26749

