diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 39442c3dd2aa..2ea7afdfc41a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -75,6 +75,42 @@ static long getPrefix(Object base, long offset, int numBytes) { return (IS_LITTLE_ENDIAN ? java.lang.Long.reverseBytes(p) : p) & ~mask; } + public static int compareBinary(byte[] leftBase, byte[] rightBase) { + return compareBinary(leftBase, Platform.BYTE_ARRAY_OFFSET, leftBase.length, + rightBase, Platform.BYTE_ARRAY_OFFSET, rightBase.length); + } + + static int compareBinary( + Object leftBase, + long leftOffset, + int leftNumBytes, + Object rightBase, + long rightOffset, + int rightNumBytes) { + int len = Math.min(leftNumBytes, rightNumBytes); + int wordMax = (len / 8) * 8; + for (int i = 0; i < wordMax; i += 8) { + long left = Platform.getLong(leftBase, leftOffset + i); + long right = Platform.getLong(rightBase, rightOffset + i); + if (left != right) { + if (IS_LITTLE_ENDIAN) { + return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); + } else { + return Long.compareUnsigned(left, right); + } + } + } + for (int i = wordMax; i < len; i++) { + // Both UTF-8 and byte array should be compared as unsigned int. + int res = (Platform.getByte(leftBase, leftOffset + i) & 0xFF) - + (Platform.getByte(rightBase, rightOffset + i) & 0xFF); + if (res != 0) { + return res; + } + } + return leftNumBytes - rightNumBytes; + } + public static byte[] subStringSQL(byte[] bytes, int pos, int len) { // This pos calculation is according to UTF8String#subStringSQL if (pos > bytes.length) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 6c3adf2c798c..c47b90d4be6a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1353,29 +1353,8 @@ public UTF8String copy() { @Override public int compareTo(@Nonnull final UTF8String other) { - int len = Math.min(numBytes, other.numBytes); - int wordMax = (len / 8) * 8; - long roffset = other.offset; - Object rbase = other.base; - for (int i = 0; i < wordMax; i += 8) { - long left = getLong(base, offset + i); - long right = getLong(rbase, roffset + i); - if (left != right) { - if (IS_LITTLE_ENDIAN) { - return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); - } else { - return Long.compareUnsigned(left, right); - } - } - } - for (int i = wordMax; i < len; i++) { - // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); - if (res != 0) { - return res; - } - } - return numBytes - other.numBytes; + return ByteArray.compareBinary( + base, offset, numBytes, other.base, other.offset, other.numBytes); } public int compare(final UTF8String other) { diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java index 703610dfde44..67de4359875c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java @@ -48,4 +48,23 @@ public void testGetPrefix() { Assert.assertEquals(result, expected); } } + + @Test + public void testCompareBinary() { + byte[] x1 = new byte[0]; + byte[] y1 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; + assert(ByteArray.compareBinary(x1, y1) < 0); + + byte[] x2 = new byte[]{(byte) 200, (byte) 100}; + byte[] y2 = new byte[]{(byte) 100, (byte) 100}; + assert(ByteArray.compareBinary(x2, y2) > 0); + + byte[] x3 = new byte[]{(byte) 100, (byte) 200, (byte) 12}; + byte[] y3 = new byte[]{(byte) 100, (byte) 200}; + assert(ByteArray.compareBinary(x3, y3) > 0); + + byte[] x4 = new byte[]{(byte) 100, (byte) 200}; + byte[] y4 = new byte[]{(byte) 100, (byte) 200}; + assert(ByteArray.compareBinary(x4, y4) == 0); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7f2c1c652dc8..026696aaec59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -650,7 +650,7 @@ class CodegenContext extends Logging { s"$clsName.compareFloats($c1, $c2)" // use c1 - c2 may overflow case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" - case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" + case BinaryType => s"org.apache.spark.unsafe.types.ByteArray.compareBinary($c1, $c2)" case NullType => "0" case array: ArrayType => val elementType = array.elementType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 1a8de4c36e0b..cba3a9a9763e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -87,17 +87,6 @@ object TypeUtils { } } - def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { - val limit = if (x.length <= y.length) x.length else y.length - var i = 0 - while (i < limit) { - val res = (x(i) & 0xff) - (y(i) & 0xff) - if (res != 0) return res - i += 1 - } - x.length - y.length - } - /** * Returns true if the equals method of the elements of the data type is implemented properly. * This also means that they can be safely used in collections relying on the equals method, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index dddf874b9c6c..c3fa54c1767d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.unsafe.types.ByteArray /** * The data type representing `Array[Byte]` values. @@ -37,7 +37,7 @@ class BinaryType private() extends AtomicType { @transient private[sql] lazy val tag = typeTag[InternalType] private[sql] val ordering = - (x: Array[Byte], y: Array[Byte]) => TypeUtils.compareBinary(x, y) + (x: Array[Byte], y: Array[Byte]) => ByteArray.compareBinary(x, y) /** * The default size of a value of the BinaryType is 100 bytes. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala index d6d1e418d74e..bc6852ca7e1f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala @@ -43,22 +43,4 @@ class TypeUtilsSuite extends SparkFunSuite { typeCheckPass(ArrayType(StringType, containsNull = true) :: ArrayType(StringType, containsNull = false) :: Nil) } - - test("compareBinary") { - val x1 = Array[Byte]() - val y1 = Array(1, 2, 3).map(_.toByte) - assert(TypeUtils.compareBinary(x1, y1) < 0) - - val x2 = Array(200, 100).map(_.toByte) - val y2 = Array(100, 100).map(_.toByte) - assert(TypeUtils.compareBinary(x2, y2) > 0) - - val x3 = Array(100, 200, 12).map(_.toByte) - val y3 = Array(100, 200).map(_.toByte) - assert(TypeUtils.compareBinary(x3, y3) > 0) - - val x4 = Array(100, 200).map(_.toByte) - val y4 = Array(100, 200).map(_.toByte) - assert(TypeUtils.compareBinary(x4, y4) == 0) - } } diff --git a/sql/core/benchmarks/ByteArrayBenchmark-jdk11-results.txt b/sql/core/benchmarks/ByteArrayBenchmark-jdk11-results.txt new file mode 100644 index 000000000000..c1a03a5e25c4 --- /dev/null +++ b/sql/core/benchmarks/ByteArrayBenchmark-jdk11-results.txt @@ -0,0 +1,16 @@ +================================================================================================ +byte array comparisons +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.8.0-1042-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +Byte Array compareTo: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +2-7 byte 501 514 14 130.9 7.6 1.0X +8-16 byte 976 993 10 67.1 14.9 0.5X +16-32 byte 985 995 6 66.5 15.0 0.5X +512-1024 byte 1260 1282 13 52.0 19.2 0.4X +512 byte slow 3114 3193 46 21.0 47.5 0.2X +2-7 byte 572 578 7 114.5 8.7 0.9X + + diff --git a/sql/core/benchmarks/ByteArrayBenchmark-results.txt b/sql/core/benchmarks/ByteArrayBenchmark-results.txt new file mode 100644 index 000000000000..5c71a782ad43 --- /dev/null +++ b/sql/core/benchmarks/ByteArrayBenchmark-results.txt @@ -0,0 +1,16 @@ +================================================================================================ +byte array comparisons +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.8.0-1042-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +Byte Array compareTo: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +2-7 byte 407 418 9 161.1 6.2 1.0X +8-16 byte 867 919 30 75.6 13.2 0.5X +16-32 byte 882 916 23 74.3 13.5 0.5X +512-1024 byte 1123 1167 31 58.4 17.1 0.4X +512 byte slow 4054 4611 506 16.2 61.9 0.1X +2-7 byte 430 450 16 152.4 6.6 0.9X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala new file mode 100644 index 000000000000..f8b1e27d3095 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.util.Random + +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.unsafe.types.ByteArray + +/** + * Benchmark to measure performance for byte array comparisons. + * {{{ + * To run this benchmark: + * 1. without sbt: + * bin/spark-submit --class --jars + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt". + * }}} + */ +object ByteArrayBenchmark extends BenchmarkBase { + + def byteArrayComparisons(iters: Long): Unit = { + val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" + val random = new Random(0) + def randomBytes(min: Int, max: Int): Array[Byte] = { + val len = random.nextInt(max - min) + min + val bytes = new Array[Byte](len) + var i = 0 + while (i < len) { + bytes(i) = chars.charAt(random.nextInt(chars.length())).toByte + i += 1 + } + bytes + } + + val count = 16 * 1000 + val dataTiny = Seq.fill(count)(randomBytes(2, 7)).toArray + val dataSmall = Seq.fill(count)(randomBytes(8, 16)).toArray + val dataMedium = Seq.fill(count)(randomBytes(16, 32)).toArray + val dataLarge = Seq.fill(count)(randomBytes(512, 1024)).toArray + val dataLargeSlow = Seq.fill(count)( + Array.tabulate(512) {i => if (i < 511) 0.toByte else 1.toByte}).toArray + + def compareBinary(data: Array[Array[Byte]]) = { _: Int => + var sum = 0L + for (_ <- 0L until iters) { + var i = 0 + while (i < count) { + sum += ByteArray.compareBinary(data(i), data((i + 1) % count)) + i += 1 + } + } + } + + val benchmark = new Benchmark("Byte Array compareTo", count * iters, 25, output = output) + benchmark.addCase("2-7 byte")(compareBinary(dataTiny)) + benchmark.addCase("8-16 byte")(compareBinary(dataSmall)) + benchmark.addCase("16-32 byte")(compareBinary(dataMedium)) + benchmark.addCase("512-1024 byte")(compareBinary(dataLarge)) + benchmark.addCase("512 byte slow")(compareBinary(dataLargeSlow)) + benchmark.addCase("2-7 byte")(compareBinary(dataTiny)) + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("byte array comparisons") { + byteArrayComparisons(1024 * 4) + } + } +}