From 1f6b2594ebe8b50e4cb2fcde15181cfa9a17f48c Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 14 Aug 2018 12:04:40 +0800 Subject: [PATCH 1/4] fix --- .../spark/sql/execution/RecordBinaryComparator.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index bb77b5bf6de2a..5148b726a32ef 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -27,7 +27,7 @@ public final class RecordBinaryComparator extends RecordComparator { public int compare( Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { int i = 0; - int res = 0; + long res = 0; // If the arrays have different length, the longer one is larger. if (leftLen != rightLen) { @@ -42,16 +42,16 @@ public int compare( while ((leftOff + i) % 8 != 0 && i < leftLen) { res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + if (res != 0) return (int) res; i += 1; } } // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { while (i <= leftLen - 8) { - res = (int) ((Platform.getLong(leftObj, leftOff + i) - - Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); - if (res != 0) return res; + res = Platform.getLong(leftObj, leftOff + i) - + Platform.getLong(rightObj, rightOff + i); + if (res != 0) return res > 0 ? 1 : -1; i += 8; } } @@ -60,7 +60,7 @@ public int compare( while (i < leftLen) { res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + if (res != 0) return (int) res; i += 1; } From ddcfea311a1e8ae4eac580280b5e89f5b5f48832 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 14 Aug 2018 19:19:40 +0800 Subject: [PATCH 2/4] update --- .../sql/execution/RecordBinaryComparator.java | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index 5148b726a32ef..334ed0775b0ea 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -27,7 +27,6 @@ public final class RecordBinaryComparator extends RecordComparator { public int compare( Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { int i = 0; - long res = 0; // If the arrays have different length, the longer one is larger. if (leftLen != rightLen) { @@ -40,27 +39,33 @@ public int compare( // check if stars align and we can get both offsets to be aligned if ((leftOff % 8) == (rightOff % 8)) { while ((leftOff + i) % 8 != 0 && i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return (int) res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } } // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { while (i <= leftLen - 8) { - res = Platform.getLong(leftObj, leftOff + i) - - Platform.getLong(rightObj, rightOff + i); - if (res != 0) return res > 0 ? 1 : -1; + final long v1 = Platform.getLong(leftObj, leftOff + i); + final long v2 = Platform.getLong(rightObj, rightOff + i); + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 8; } } // this will finish off the unaligned comparisons, or do the entire aligned comparison // whichever is needed. while (i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return (int) res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } From 9c1f4860ea794e60a127ae3ec2d1b518a182edf5 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 15 Aug 2018 22:40:33 +0800 Subject: [PATCH 3/4] add test cases --- .../sql/execution/RecordBinaryComparator.java | 1 - .../sort/RecordBinaryComparatorSuite.java | 40 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index 334ed0775b0ea..40c2cc806e87a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -22,7 +22,6 @@ public final class RecordBinaryComparator extends RecordComparator { - // TODO(jiangxb) Add test suite for this. @Override public int compare( Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java index a19ddbdbadba2..5b53f6dae41f8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -253,4 +253,44 @@ public void testBinaryComparatorForNullColumns() throws Exception { assert(compare(0, 0) == 0); assert(compare(0, 1) > 0); } + + @Test + public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 11L + Integer.MAX_VALUE); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, Long.MIN_VALUE); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } } From ecb26fc902995f866ee837f48c656cfb2174f18d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 20 Aug 2018 17:03:33 +0800 Subject: [PATCH 4/4] update --- .../sort/RecordBinaryComparatorSuite.java | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java index 5b53f6dae41f8..97f3dc588ecc5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -293,4 +293,30 @@ public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exc assert(compare(0, 1) < 0); } + + @Test + public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, 0); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } }