From 7d5ff06969cd1fe8d05adcf8c79d6b56d0865f44 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 5 Dec 2018 23:05:39 +0800
Subject: [PATCH 1/3] only deal with NaN and -0.0 in UnsafeWriter

---
 .../org/apache/spark/unsafe/Platform.java          | 10 ------
 .../expressions/codegen/UnsafeWriter.java          | 34 +++++++++++++++++++
 2 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index 4563efcfcf474..076b693f81c88 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -174,11 +174,6 @@ public static float getFloat(Object object, long offset) {
   }
 
   public static void putFloat(Object object, long offset, float value) {
-    if (Float.isNaN(value)) {
-      value = Float.NaN;
-    } else if (value == -0.0f) {
-      value = 0.0f;
-    }
     _UNSAFE.putFloat(object, offset, value);
   }
 
@@ -187,11 +182,6 @@ public static double getDouble(Object object, long offset) {
   }
 
   public static void putDouble(Object object, long offset, double value) {
-    if (Double.isNaN(value)) {
-      value = Double.NaN;
-    } else if (value == -0.0d) {
-      value = 0.0d;
-    }
     _UNSAFE.putDouble(object, offset, value);
   }
 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 95263a0da95a8..df4e9373d508a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -198,11 +198,45 @@ protected final void writeLong(long offset, long value) {
     Platform.putLong(getBuffer(), offset, value);
   }
 
+  // We need to take care of NaN and -0.0 in several places:
+  // 1. When comparing values, different NaNs should be treated as the same, and `-0.0` and `0.0`
+  //    should be treated as the same.
+  // 2. In the range partitioner, different NaNs should belong to the same partition, and -0.0
+  //    and 0.0 should belong to the same partition.
+  // 3. In GROUP BY, different NaNs should belong to the same group, and -0.0 and 0.0 should
+  //    belong to the same group.
+  // 4. As join keys, different NaNs should be treated as the same, and `-0.0` and `0.0` should
+  //    be treated as the same.
+  //
+  // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
+  // recursively compare the fields/elements, so it's also fine.
+  //
+  // Case 2 is problematic, as the sorter of the range partitioner uses a prefix comparator,
+  // which considers 0.0 > -0.0.
+  //
+  // Cases 3 and 4 are problematic, as they compare `UnsafeRow` binaries directly, and different
+  // NaNs have different binary representations; the same is true for -0.0 and 0.0.
+  //
+  // Here we normalize NaN and -0.0, so that we don't need to care about NaN and -0.0 for
+  // `UnsafeRow`s created by `UnsafeProjection`. This fixes case 2, because the inputs to the
+  // sorter of the range partitioner are `UnsafeRow`s created by `UnsafeProjection`. It also
+  // fixes cases 3 and 4, because Spark uses `UnsafeProjection` to extract join/grouping keys.
   protected final void writeFloat(long offset, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    } else if (value == -0.0f) {
+      value = 0.0f;
+    }
     Platform.putFloat(getBuffer(), offset, value);
   }
 
+  // See comments for `writeFloat`.
   protected final void writeDouble(long offset, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    } else if (value == -0.0d) {
+      value = 0.0d;
+    }
     Platform.putDouble(getBuffer(), offset, value);
   }
 }

From c9dfe67ee2ff6a2ac5478e072c2709c1b0d0e5f4 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 7 Dec 2018 00:53:16 +0800
Subject: [PATCH 2/3] address comments

---
 .../spark/unsafe/PlatformUtilSuite.java            | 18 -------------
 .../expressions/codegen/UnsafeWriter.java          | 25 ++++++++++---------
 2 files changed, 13 insertions(+), 30 deletions(-)

diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index 2474081dad5c9..3ad9ac7b4de9c 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -157,22 +157,4 @@ public void heapMemoryReuse() {
     Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
     Assert.assertEquals(obj3, onheap4.getBaseObject());
   }
-
-  @Test
-  // SPARK-26021
-  public void writeMinusZeroIsReplacedWithZero() {
-    byte[] doubleBytes = new byte[Double.BYTES];
-    byte[] floatBytes = new byte[Float.BYTES];
-    Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
-    Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);
-
-    byte[] doubleBytes2 = new byte[Double.BYTES];
-    byte[] floatBytes2 = new byte[Float.BYTES];
-    Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, 0.0d);
-    Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, 0.0f);
-
-    // Make sure the bytes we write from 0.0 and -0.0 are same.
-    Assert.assertArrayEquals(doubleBytes, doubleBytes2);
-    Assert.assertArrayEquals(floatBytes, floatBytes2);
-  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index df4e9373d508a..7553ab8cf7000 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -201,26 +201,27 @@ protected final void writeLong(long offset, long value) {
   // We need to take care of NaN and -0.0 in several places:
   // 1. When comparing values, different NaNs should be treated as the same, and `-0.0` and `0.0`
   //    should be treated as the same.
-  // 2. In the range partitioner, different NaNs should belong to the same partition, and -0.0
-  //    and 0.0 should belong to the same partition.
-  // 3. In GROUP BY, different NaNs should belong to the same group, and -0.0 and 0.0 should
+  // 2. In GROUP BY, different NaNs should belong to the same group, and -0.0 and 0.0 should
   //    belong to the same group.
-  // 4. As join keys, different NaNs should be treated as the same, and `-0.0` and `0.0` should
+  // 3. As join keys, different NaNs should be treated as the same, and `-0.0` and `0.0` should
   //    be treated as the same.
+  // 4. As window partition keys, different NaNs should be treated as the same, and `-0.0` and
+  //    `0.0` should be treated as the same.
   //
   // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
   // recursively compare the fields/elements, so it's also fine.
   //
-  // Case 2 is problematic, as the sorter of the range partitioner uses a prefix comparator,
-  // which considers 0.0 > -0.0.
+  // Cases 2, 3 and 4 are problematic, as they compare `UnsafeRow` binaries directly, and
+  // different NaNs have different binary representations; the same is true for -0.0 and 0.0.
   //
-  // Cases 3 and 4 are problematic, as they compare `UnsafeRow` binaries directly, and different
-  // NaNs have different binary representations; the same is true for -0.0 and 0.0.
+  // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing
+  // float/double columns and nested fields to `UnsafeRow`.
   //
-  // Here we normalize NaN and -0.0, so that we don't need to care about NaN and -0.0 for
-  // `UnsafeRow`s created by `UnsafeProjection`. This fixes case 2, because the inputs to the
-  // sorter of the range partitioner are `UnsafeRow`s created by `UnsafeProjection`. It also
-  // fixes cases 3 and 4, because Spark uses `UnsafeProjection` to extract join/grouping keys.
+  // Note that we must do this for all `UnsafeProjection`s, not only the ones that extract
+  // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for
+  // complex types, so nested float/double values may not be normalized otherwise. We need to
+  // make sure that all the unsafe data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) has
+  // float/double values normalized during creation.
   protected final void writeFloat(long offset, float value) {
     if (Float.isNaN(value)) {
       value = Float.NaN;

From b7a54979dbbb905bcda2e073697bb4ec82c3f93d Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 7 Dec 2018 16:44:01 +0800
Subject: [PATCH 3/3] add more tests

---
 .../codegen/UnsafeRowWriterSuite.scala             | 20 +++++++++++++++++++
 .../apache/spark/sql/DataFrameJoinSuite.scala      | 12 +++++++++++
 .../sql/DataFrameWindowFunctionsSuite.scala        | 14 +++++++++++++
 3 files changed, 46 insertions(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
index fb651b76fc16d..22e1fa6dfed4f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
@@ -50,4 +50,24 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
     assert(res1 == res2)
   }
 
+  test("SPARK-26021: normalize float/double NaN and -0.0") {
+    val unsafeRowWriter1 = new UnsafeRowWriter(4)
+    unsafeRowWriter1.resetRowWriter()
+    unsafeRowWriter1.write(0, Float.NaN)
+    unsafeRowWriter1.write(1, Double.NaN)
+    unsafeRowWriter1.write(2, 0.0f)
+    unsafeRowWriter1.write(3, 0.0)
+    val res1 = unsafeRowWriter1.getRow
+
+    val unsafeRowWriter2 = new UnsafeRowWriter(4)
+    unsafeRowWriter2.resetRowWriter()
+    unsafeRowWriter2.write(0, 0.0f/0.0f)
+    unsafeRowWriter2.write(1, 0.0/0.0)
+    unsafeRowWriter2.write(2, -0.0f)
+    unsafeRowWriter2.write(3, -0.0)
+    val res2 = unsafeRowWriter2.getRow
+
+    // The two rows should be equal
+    assert(res1 == res2)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index e6b30f9956daf..c9f41ab1c0179 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -295,4 +295,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
df("id")).queryExecution.optimizedPlan } } + + test("NaN and -0.0 in join keys") { + val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") + val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") + val joined = df1.join(df2, Seq("f", "d")) + checkAnswer(joined, Seq( + Row(Float.NaN, Double.NaN), + Row(0.0f, 0.0), + Row(0.0f, 0.0), + Row(0.0f, 0.0), + Row(0.0f, 0.0))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 78277d7dcf757..9a5d5a9966ab7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -681,4 +681,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row("S2", "P2", 300, 300, 500))) } + + test("NaN and -0.0 in window partition keys") { + val df = Seq( + (Float.NaN, Double.NaN, 1), + (0.0f/0.0f, 0.0/0.0, 1), + (0.0f, 0.0, 1), + (-0.0f, -0.0, 1)).toDF("f", "d", "i") + val result = df.select($"f", count("i").over(Window.partitionBy("f", "d"))) + checkAnswer(result, Seq( + Row(Float.NaN, 2), + Row(Float.NaN, 2), + Row(0.0f, 2), + Row(0.0f, 2))) + } }