Commit fa1abe2

cloud-fan authored and dongjoon-hyun committed

Revert "[SPARK-26021][SQL] replace minus zero with zero in Platform.putDouble/Float"

This PR reverts #23043 and its follow-up #23265 from branch 2.4, because they introduce behavior changes.

Tested with existing tests.

Closes #23389 from cloud-fan/revert.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>

1 parent c2bff77 commit fa1abe2
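
For context: the reverted change normalized NaN and -0.0 when writing float/double values into UnsafeRow. The root issue is that IEEE 754 equality and raw bit patterns disagree about -0.0, and Spark's UnsafeRow-based paths compare raw bytes. A minimal plain-Scala sketch (no Spark required) of that mismatch:

// The two zeros are equal as IEEE 754 values...
println(0.0d == -0.0d)  // true

// ...but they are different 64-bit patterns, which is what a byte-wise
// UnsafeRow comparison sees:
println(java.lang.Double.doubleToRawLongBits(0.0d).toHexString)   // 0
println(java.lang.Double.doubleToRawLongBits(-0.0d).toHexString)  // 8000000000000000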

File tree: 7 files changed, +7 −93 lines

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 6 additions & 0 deletions
@@ -224,6 +224,9 @@ public void setLong(int ordinal, long value) {
   public void setDouble(int ordinal, double value) {
     assertIndexIsValid(ordinal);
     setNotNullAt(ordinal);
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    }
     Platform.putDouble(baseObject, getFieldOffset(ordinal), value);
   }

@@ -252,6 +255,9 @@ public void setByte(int ordinal, byte value) {
   public void setFloat(int ordinal, float value) {
     assertIndexIsValid(ordinal);
     setNotNullAt(ordinal);
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    }
     Platform.putFloat(baseObject, getFieldOffset(ordinal), value);
   }
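
The restored check canonicalizes every NaN to the single Double.NaN / Float.NaN bit pattern before the value is written. A short plain-Scala sketch of why that matters for byte-wise equality; 0x7fc00001 here is just an arbitrary non-canonical NaN payload chosen for illustration:

import java.lang.Float.{floatToRawIntBits, intBitsToFloat}

val canonical = Float.NaN                   // bits 0x7fc00000 per the Java spec
val exotic    = intBitsToFloat(0x7fc00001)  // also a NaN, but a different payload

println(exotic.isNaN)                                               // true
println(floatToRawIntBits(canonical) == floatToRawIntBits(exotic))  // false on typical JVMs

// The setFloat check above maps every NaN to the canonical pattern:
val normalized = if (exotic.isNaN) Float.NaN else exotic
println(floatToRawIntBits(normalized) == floatToRawIntBits(canonical))  // true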

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java

Lines changed: 0 additions & 29 deletions
@@ -198,45 +198,16 @@ protected final void writeLong(long offset, long value) {
     Platform.putLong(getBuffer(), offset, value);
   }

-  // We need to take care of NaN and -0.0 in several places:
-  // 1. When comparing values, different NaNs should be treated as same, `-0.0` and `0.0` should be
-  //    treated as same.
-  // 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong
-  //    to the same group.
-  // 3. As join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be
-  //    treated as same.
-  // 4. As window partition keys, different NaNs should be treated as same, `-0.0` and `0.0`
-  //    should be treated as same.
-  //
-  // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
-  // recursively compare the fields/elements, so it's also fine.
-  //
-  // Cases 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different
-  // NaNs have different binary representations, and the same thing happens for -0.0 and 0.0.
-  //
-  // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing
-  // float/double columns and nested fields to `UnsafeRow`.
-  //
-  // Note that we must do this for all the `UnsafeProjection`s, not only the ones that extract
-  // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex
-  // types, so nested float/double may not be normalized. We need to make sure that all the unsafe
-  // data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have float/double normalized during
-  // creation.
   protected final void writeFloat(long offset, float value) {
     if (Float.isNaN(value)) {
       value = Float.NaN;
-    } else if (value == -0.0f) {
-      value = 0.0f;
     }
     Platform.putFloat(getBuffer(), offset, value);
   }

-  // See comments for `writeFloat`.
   protected final void writeDouble(long offset, double value) {
     if (Double.isNaN(value)) {
       value = Double.NaN;
-    } else if (value == -0.0d) {
-      value = 0.0d;
     }
     Platform.putDouble(getBuffer(), offset, value);
   }
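
After this revert, branch 2.4 keeps the NaN normalization but no longer rewrites -0.0 to 0.0. A minimal sketch of the resulting behavior, reusing the internal UnsafeRowWriter API exactly as the suite removed below did (an internal API, so this is illustrative only, not a supported usage):

import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter

// NaN is still canonicalized, so rows built from different NaNs stay binary-equal:
val w1 = new UnsafeRowWriter(1); w1.resetRowWriter(); w1.write(0, Float.NaN)
val w2 = new UnsafeRowWriter(1); w2.resetRowWriter(); w2.write(0, 0.0f / 0.0f)
assert(w1.getRow == w2.getRow)

// -0.0 is no longer rewritten, so these rows now differ byte-wise,
// even though 0.0f == -0.0f holds for the primitives:
val w3 = new UnsafeRowWriter(1); w3.resetRowWriter(); w3.write(0, -0.0f)
val w4 = new UnsafeRowWriter(1); w4.resetRowWriter(); w4.write(0, 0.0f)
assert(w3.getRow != w4.getRow)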

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala

Lines changed: 0 additions & 20 deletions
@@ -50,24 +50,4 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
     assert(res1 == res2)
   }

-  test("SPARK-26021: normalize float/double NaN and -0.0") {
-    val unsafeRowWriter1 = new UnsafeRowWriter(4)
-    unsafeRowWriter1.resetRowWriter()
-    unsafeRowWriter1.write(0, Float.NaN)
-    unsafeRowWriter1.write(1, Double.NaN)
-    unsafeRowWriter1.write(2, 0.0f)
-    unsafeRowWriter1.write(3, 0.0)
-    val res1 = unsafeRowWriter1.getRow
-
-    val unsafeRowWriter2 = new UnsafeRowWriter(4)
-    unsafeRowWriter2.resetRowWriter()
-    unsafeRowWriter2.write(0, 0.0f/0.0f)
-    unsafeRowWriter2.write(1, 0.0/0.0)
-    unsafeRowWriter2.write(2, -0.0f)
-    unsafeRowWriter2.write(3, -0.0)
-    val res2 = unsafeRowWriter2.getRow
-
-    // The two rows should be equal
-    assert(res1 == res2)
-  }
 }

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 0 additions & 14 deletions
@@ -727,18 +727,4 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       "grouping expressions: [current_date(None)], value: [key: int, value: string], " +
         "type: GroupBy]"))
   }
-
-  test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") {
-    val colName = "i"
-    val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect()
-    val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect()
-
-    assert(doubles.length == 1)
-    assert(floats.length == 1)
-    // using compare since 0.0 == -0.0 is true
-    assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0)
-    assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0)
-    assert(doubles(0).getLong(1) == 3)
-    assert(floats(0).getLong(1) == 3)
-  }
 }
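
The inline comment in the removed test ("using compare since 0.0 == -0.0 is true") points at a useful distinction: primitive == follows IEEE 754, while the boxed compare methods use a total order that separates the two zeros. A quick plain-Scala refresher:

println(0.0d == -0.0d)                          // true: IEEE 754 equality
println(java.lang.Double.compare(0.0d, -0.0d))  // 1: total order puts -0.0 below 0.0
println(java.lang.Float.compare(-0.0f, 0.0f))   // -1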

sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala

Lines changed: 0 additions & 12 deletions
@@ -295,16 +295,4 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
     }
   }
-
-  test("NaN and -0.0 in join keys") {
-    val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
-    val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
-    val joined = df1.join(df2, Seq("f", "d"))
-    checkAnswer(joined, Seq(
-      Row(Float.NaN, Double.NaN),
-      Row(0.0f, 0.0),
-      Row(0.0f, 0.0),
-      Row(0.0f, 0.0),
-      Row(0.0f, 0.0)))
-  }
 }
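
The five expected rows in the removed test follow from key multiplicities under normalization: each side effectively holds the keys {NaN, 0.0, 0.0} once -0.0 folds into 0.0, giving one NaN match plus 2 × 2 zero matches. A plain-Scala sketch of that counting, using join-key semantics where NaN matches NaN:

val left  = Seq(Float.NaN, 0.0f, 0.0f)  // -0.0f already folded into 0.0f
val right = Seq(Float.NaN, 0.0f, 0.0f)

val matches = for {
  l <- left
  r <- right
  if (l.isNaN && r.isNaN) || l == r  // join keys treat NaN as equal to NaN
} yield (l, r)

println(matches.size)  // 5 = 1 NaN pair + 4 zero pairs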

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala

Lines changed: 0 additions & 14 deletions
@@ -658,18 +658,4 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
        |GROUP BY a
        |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
   }
-
-  test("NaN and -0.0 in window partition keys") {
-    val df = Seq(
-      (Float.NaN, Double.NaN, 1),
-      (0.0f/0.0f, 0.0/0.0, 1),
-      (0.0f, 0.0, 1),
-      (-0.0f, -0.0, 1)).toDF("f", "d", "i")
-    val result = df.select($"f", count("i").over(Window.partitionBy("f", "d")))
-    checkAnswer(result, Seq(
-      Row(Float.NaN, 2),
-      Row(Float.NaN, 2),
-      Row(0.0f, 2),
-      Row(0.0f, 2)))
-  }
 }

sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 1 addition & 4 deletions
@@ -289,7 +289,7 @@ object QueryTest {
   def prepareRow(row: Row): Row = {
     Row.fromSeq(row.toSeq.map {
       case null => null
-      case bd: java.math.BigDecimal => BigDecimal(bd)
+      case d: java.math.BigDecimal => BigDecimal(d)
       // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+
       case seq: Seq[_] => seq.map {
         case b: java.lang.Byte => b.byteValue
@@ -303,9 +303,6 @@ object QueryTest {
       // Convert array to Seq for easy equality check.
       case b: Array[_] => b.toSeq
       case r: Row => prepareRow(r)
-      // spark treats -0.0 as 0.0
-      case d: Double if d == -0.0d => 0.0d
-      case f: Float if f == -0.0f => 0.0f
       case o => o
     })
   }
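
With the two -0.0 cases gone, prepareRow passes negative zeros through untouched. The zeros are == as primitives but still distinguishable, for example by their string form, which is presumably what the removed normalization smoothed over when answers were compared:

println(-0.0d == 0.0d)     // true
println((-0.0d).toString)  // -0.0
println((0.0d).toString)   // 0.0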
