Skip to content

Commit 74d974a

Browse files
ulysses-yousrowen
authored andcommitted
[SPARK-37037][SQL] Improve byte array sort by unify compareTo function of UTF8String and ByteArray
### What changes were proposed in this pull request? Unify the compare function of `UTF8String` and `ByteArray`. ### Why are the changes needed? `BinaryType` use `TypeUtils.compareBinary` to compare two byte array, however it's slow since it compares byte array using unsigned int comparison byte by bye. We can compare them using `Platform.getLong` with unsigned long comparison if they have more than 8 bytes. And here is some histroy about this `TODO` https://github.com/apache/spark/pull/6755/files#r32197461 The benchmark result should be same with `UTF8String`, can be found in #19180 (#19180 (comment)) ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Move test from `TypeUtilsSuite` to `ByteArraySuite` Closes #34310 from ulysses-you/SPARK-37037. Authored-by: ulysses-you <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 5c28b6e commit 74d974a

File tree

10 files changed

+178
-55
lines changed

10 files changed

+178
-55
lines changed

common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,42 @@ static long getPrefix(Object base, long offset, int numBytes) {
7575
return (IS_LITTLE_ENDIAN ? java.lang.Long.reverseBytes(p) : p) & ~mask;
7676
}
7777

78+
public static int compareBinary(byte[] leftBase, byte[] rightBase) {
79+
return compareBinary(leftBase, Platform.BYTE_ARRAY_OFFSET, leftBase.length,
80+
rightBase, Platform.BYTE_ARRAY_OFFSET, rightBase.length);
81+
}
82+
83+
static int compareBinary(
84+
Object leftBase,
85+
long leftOffset,
86+
int leftNumBytes,
87+
Object rightBase,
88+
long rightOffset,
89+
int rightNumBytes) {
90+
int len = Math.min(leftNumBytes, rightNumBytes);
91+
int wordMax = (len / 8) * 8;
92+
for (int i = 0; i < wordMax; i += 8) {
93+
long left = Platform.getLong(leftBase, leftOffset + i);
94+
long right = Platform.getLong(rightBase, rightOffset + i);
95+
if (left != right) {
96+
if (IS_LITTLE_ENDIAN) {
97+
return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right));
98+
} else {
99+
return Long.compareUnsigned(left, right);
100+
}
101+
}
102+
}
103+
for (int i = wordMax; i < len; i++) {
104+
// Both UTF-8 and byte array should be compared as unsigned int.
105+
int res = (Platform.getByte(leftBase, leftOffset + i) & 0xFF) -
106+
(Platform.getByte(rightBase, rightOffset + i) & 0xFF);
107+
if (res != 0) {
108+
return res;
109+
}
110+
}
111+
return leftNumBytes - rightNumBytes;
112+
}
113+
78114
public static byte[] subStringSQL(byte[] bytes, int pos, int len) {
79115
// This pos calculation is according to UTF8String#subStringSQL
80116
if (pos > bytes.length) {

common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,29 +1353,8 @@ public UTF8String copy() {
13531353

13541354
@Override
13551355
public int compareTo(@Nonnull final UTF8String other) {
1356-
int len = Math.min(numBytes, other.numBytes);
1357-
int wordMax = (len / 8) * 8;
1358-
long roffset = other.offset;
1359-
Object rbase = other.base;
1360-
for (int i = 0; i < wordMax; i += 8) {
1361-
long left = getLong(base, offset + i);
1362-
long right = getLong(rbase, roffset + i);
1363-
if (left != right) {
1364-
if (IS_LITTLE_ENDIAN) {
1365-
return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right));
1366-
} else {
1367-
return Long.compareUnsigned(left, right);
1368-
}
1369-
}
1370-
}
1371-
for (int i = wordMax; i < len; i++) {
1372-
// In UTF-8, the byte should be unsigned, so we should compare them as unsigned int.
1373-
int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF);
1374-
if (res != 0) {
1375-
return res;
1376-
}
1377-
}
1378-
return numBytes - other.numBytes;
1356+
return ByteArray.compareBinary(
1357+
base, offset, numBytes, other.base, other.offset, other.numBytes);
13791358
}
13801359

13811360
public int compare(final UTF8String other) {

common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,23 @@ public void testGetPrefix() {
4848
Assert.assertEquals(result, expected);
4949
}
5050
}
51+
52+
@Test
53+
public void testCompareBinary() {
54+
byte[] x1 = new byte[0];
55+
byte[] y1 = new byte[]{(byte) 1, (byte) 2, (byte) 3};
56+
assert(ByteArray.compareBinary(x1, y1) < 0);
57+
58+
byte[] x2 = new byte[]{(byte) 200, (byte) 100};
59+
byte[] y2 = new byte[]{(byte) 100, (byte) 100};
60+
assert(ByteArray.compareBinary(x2, y2) > 0);
61+
62+
byte[] x3 = new byte[]{(byte) 100, (byte) 200, (byte) 12};
63+
byte[] y3 = new byte[]{(byte) 100, (byte) 200};
64+
assert(ByteArray.compareBinary(x3, y3) > 0);
65+
66+
byte[] x4 = new byte[]{(byte) 100, (byte) 200};
67+
byte[] y4 = new byte[]{(byte) 100, (byte) 200};
68+
assert(ByteArray.compareBinary(x4, y4) == 0);
69+
}
5170
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ class CodegenContext extends Logging {
650650
s"$clsName.compareFloats($c1, $c2)"
651651
// use c1 - c2 may overflow
652652
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
653-
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
653+
case BinaryType => s"org.apache.spark.unsafe.types.ByteArray.compareBinary($c1, $c2)"
654654
case NullType => "0"
655655
case array: ArrayType =>
656656
val elementType = array.elementType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,6 @@ object TypeUtils {
8787
}
8888
}
8989

90-
def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
91-
val limit = if (x.length <= y.length) x.length else y.length
92-
var i = 0
93-
while (i < limit) {
94-
val res = (x(i) & 0xff) - (y(i) & 0xff)
95-
if (res != 0) return res
96-
i += 1
97-
}
98-
x.length - y.length
99-
}
100-
10190
/**
10291
* Returns true if the equals method of the elements of the data type is implemented properly.
10392
* This also means that they can be safely used in collections relying on the equals method,

sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
2020
import scala.reflect.runtime.universe.typeTag
2121

2222
import org.apache.spark.annotation.Stable
23-
import org.apache.spark.sql.catalyst.util.TypeUtils
23+
import org.apache.spark.unsafe.types.ByteArray
2424

2525
/**
2626
* The data type representing `Array[Byte]` values.
@@ -37,7 +37,7 @@ class BinaryType private() extends AtomicType {
3737
@transient private[sql] lazy val tag = typeTag[InternalType]
3838

3939
private[sql] val ordering =
40-
(x: Array[Byte], y: Array[Byte]) => TypeUtils.compareBinary(x, y)
40+
(x: Array[Byte], y: Array[Byte]) => ByteArray.compareBinary(x, y)
4141

4242
/**
4343
* The default size of a value of the BinaryType is 100 bytes.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,4 @@ class TypeUtilsSuite extends SparkFunSuite {
4343
typeCheckPass(ArrayType(StringType, containsNull = true) ::
4444
ArrayType(StringType, containsNull = false) :: Nil)
4545
}
46-
47-
test("compareBinary") {
48-
val x1 = Array[Byte]()
49-
val y1 = Array(1, 2, 3).map(_.toByte)
50-
assert(TypeUtils.compareBinary(x1, y1) < 0)
51-
52-
val x2 = Array(200, 100).map(_.toByte)
53-
val y2 = Array(100, 100).map(_.toByte)
54-
assert(TypeUtils.compareBinary(x2, y2) > 0)
55-
56-
val x3 = Array(100, 200, 12).map(_.toByte)
57-
val y3 = Array(100, 200).map(_.toByte)
58-
assert(TypeUtils.compareBinary(x3, y3) > 0)
59-
60-
val x4 = Array(100, 200).map(_.toByte)
61-
val y4 = Array(100, 200).map(_.toByte)
62-
assert(TypeUtils.compareBinary(x4, y4) == 0)
63-
}
6446
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
================================================================================================
2+
byte array comparisons
3+
================================================================================================
4+
5+
OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.8.0-1042-azure
6+
Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
7+
Byte Array compareTo: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
8+
------------------------------------------------------------------------------------------------------------------------
9+
2-7 byte 501 514 14 130.9 7.6 1.0X
10+
8-16 byte 976 993 10 67.1 14.9 0.5X
11+
16-32 byte 985 995 6 66.5 15.0 0.5X
12+
512-1024 byte 1260 1282 13 52.0 19.2 0.4X
13+
512 byte slow 3114 3193 46 21.0 47.5 0.2X
14+
2-7 byte 572 578 7 114.5 8.7 0.9X
15+
16+
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
================================================================================================
2+
byte array comparisons
3+
================================================================================================
4+
5+
OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.8.0-1042-azure
6+
Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
7+
Byte Array compareTo: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
8+
------------------------------------------------------------------------------------------------------------------------
9+
2-7 byte 407 418 9 161.1 6.2 1.0X
10+
8-16 byte 867 919 30 75.6 13.2 0.5X
11+
16-32 byte 882 916 23 74.3 13.5 0.5X
12+
512-1024 byte 1123 1167 31 58.4 17.1 0.4X
13+
512 byte slow 4054 4611 506 16.2 61.9 0.1X
14+
2-7 byte 430 450 16 152.4 6.6 0.9X
15+
16+
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.benchmark
19+
20+
import scala.util.Random
21+
22+
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
23+
import org.apache.spark.unsafe.types.ByteArray
24+
25+
/**
26+
* Benchmark to measure performance for byte array comparisons.
27+
* {{{
28+
* To run this benchmark:
29+
* 1. without sbt:
30+
* bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar>
31+
* 2. build/sbt "sql/test:runMain <this class>"
32+
* 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
33+
* Results will be written to "benchmarks/<this class>-results.txt".
34+
* }}}
35+
*/
36+
object ByteArrayBenchmark extends BenchmarkBase {
37+
38+
def byteArrayComparisons(iters: Long): Unit = {
39+
val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
40+
val random = new Random(0)
41+
def randomBytes(min: Int, max: Int): Array[Byte] = {
42+
val len = random.nextInt(max - min) + min
43+
val bytes = new Array[Byte](len)
44+
var i = 0
45+
while (i < len) {
46+
bytes(i) = chars.charAt(random.nextInt(chars.length())).toByte
47+
i += 1
48+
}
49+
bytes
50+
}
51+
52+
val count = 16 * 1000
53+
val dataTiny = Seq.fill(count)(randomBytes(2, 7)).toArray
54+
val dataSmall = Seq.fill(count)(randomBytes(8, 16)).toArray
55+
val dataMedium = Seq.fill(count)(randomBytes(16, 32)).toArray
56+
val dataLarge = Seq.fill(count)(randomBytes(512, 1024)).toArray
57+
val dataLargeSlow = Seq.fill(count)(
58+
Array.tabulate(512) {i => if (i < 511) 0.toByte else 1.toByte}).toArray
59+
60+
def compareBinary(data: Array[Array[Byte]]) = { _: Int =>
61+
var sum = 0L
62+
for (_ <- 0L until iters) {
63+
var i = 0
64+
while (i < count) {
65+
sum += ByteArray.compareBinary(data(i), data((i + 1) % count))
66+
i += 1
67+
}
68+
}
69+
}
70+
71+
val benchmark = new Benchmark("Byte Array compareTo", count * iters, 25, output = output)
72+
benchmark.addCase("2-7 byte")(compareBinary(dataTiny))
73+
benchmark.addCase("8-16 byte")(compareBinary(dataSmall))
74+
benchmark.addCase("16-32 byte")(compareBinary(dataMedium))
75+
benchmark.addCase("512-1024 byte")(compareBinary(dataLarge))
76+
benchmark.addCase("512 byte slow")(compareBinary(dataLargeSlow))
77+
benchmark.addCase("2-7 byte")(compareBinary(dataTiny))
78+
benchmark.run()
79+
}
80+
81+
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
82+
runBenchmark("byte array comparisons") {
83+
byteArrayComparisons(1024 * 4)
84+
}
85+
}
86+
}

0 commit comments

Comments
 (0)