Skip to content

Commit 5f9a7fe

Browse files
AngersZhuuuu authored and cloud-fan committed
[SPARK-33428][SQL] Conv UDF use BigInt to avoid Long value overflow
### What changes were proposed in this pull request? Using a Long value to store the encoded value will overflow and return an unexpected result; use BigInt to replace the Long value and make the logic simpler. ### Why are the changes needed? Fix a value overflow issue ### Does this PR introduce _any_ user-facing change? People can use the `conv` function to convert values bigger than LONG.MAX_VALUE ### How was this patch tested? Added UT #### BenchMark ``` /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package org.apache.spark.sql.execution.benchmark import scala.util.Random import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.functions._ object ConvFuncBenchMark extends SqlBasedBenchmark { val charset = Array[String]("0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z") def constructString(from: Int, length: Int): String = { val chars = charset.slice(0, from) (0 to length).map(x => { val v = Random.nextInt(from) chars(v) }).mkString("") } private def doBenchmark(cardinality: Long, length: Int, from: Int, toBase: Int): Unit = { spark.range(cardinality) .withColumn("str", lit(constructString(from, length))) .select(conv(col("str"), from, toBase)) .noop() } /** * Main process of the whole benchmark. * Implementations of this method are supposed to use the wrapper method `runBenchmark` * for each benchmark scenario. */ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { val N = 1000000L val benchmark = new Benchmark("conv", N, output = output) benchmark.addCase("length 10 from 2 to 16") { _ => doBenchmark(N, 10, 2, 16) } benchmark.addCase("length 10 from 2 to 10") { _ => doBenchmark(N, 10, 2, 10) } benchmark.addCase("length 10 from 10 to 16") { _ => doBenchmark(N, 10, 10, 16) } benchmark.addCase("length 10 from 10 to 36") { _ => doBenchmark(N, 10, 10, 36) } benchmark.addCase("length 10 from 16 to 10") { _ => doBenchmark(N, 10, 10, 10) } benchmark.addCase("length 10 from 16 to 36") { _ => doBenchmark(N, 10, 16, 36) } benchmark.addCase("length 10 from 36 to 10") { _ => doBenchmark(N, 10, 36, 10) } benchmark.addCase("length 10 from 36 to 16") { _ => doBenchmark(N, 10, 36, 16) } // benchmark.addCase("length 20 from 10 to 16") { _ => doBenchmark(N, 20, 10, 16) } benchmark.addCase("length 20 from 10 to 36") { _ => doBenchmark(N, 20, 10, 36) } benchmark.addCase("length 30 from 10 to 16") { _ => doBenchmark(N, 30, 
10, 16) } benchmark.addCase("length 30 from 10 to 36") { _ => doBenchmark(N, 30, 10, 36) } // benchmark.addCase("length 20 from 16 to 10") { _ => doBenchmark(N, 20, 16, 10) } benchmark.addCase("length 20 from 16 to 36") { _ => doBenchmark(N, 20, 16, 36) } benchmark.addCase("length 30 from 16 to 10") { _ => doBenchmark(N, 30, 16, 10) } benchmark.addCase("length 30 from 16 to 36") { _ => doBenchmark(N, 30, 16, 36) } benchmark.run() } } ``` Result with patch : ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_191-b12 on Mac OS X 10.14.6 Intel(R) Core(TM) i5-8259U CPU 2.30GHz conv: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ length 10 from 2 to 16 54 73 18 18.7 53.6 1.0X length 10 from 2 to 10 43 47 5 23.5 42.5 1.3X length 10 from 10 to 16 39 47 12 25.5 39.2 1.4X length 10 from 10 to 36 38 42 3 26.5 37.7 1.4X length 10 from 16 to 10 39 41 3 25.7 38.9 1.4X length 10 from 16 to 36 36 41 4 27.6 36.3 1.5X length 10 from 36 to 10 38 40 2 26.3 38.0 1.4X length 10 from 36 to 16 37 39 2 26.8 37.2 1.4X length 20 from 10 to 16 36 39 2 27.4 36.5 1.5X length 20 from 10 to 36 37 39 2 27.2 36.8 1.5X length 30 from 10 to 16 37 39 2 27.0 37.0 1.4X length 30 from 10 to 36 36 38 2 27.5 36.3 1.5X length 20 from 16 to 10 35 38 2 28.3 35.4 1.5X length 20 from 16 to 36 34 38 3 29.2 34.3 1.6X length 30 from 16 to 10 38 40 2 26.3 38.1 1.4X length 30 from 16 to 36 37 38 1 27.2 36.8 1.5X ``` Result without patch: ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_191-b12 on Mac OS X 10.14.6 Intel(R) Core(TM) i5-8259U CPU 2.30GHz conv: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ length 10 from 2 to 16 66 101 29 15.1 66.1 1.0X length 10 from 2 to 10 50 55 5 20.2 49.5 1.3X length 10 from 10 to 16 46 51 5 21.8 45.9 1.4X length 
10 from 10 to 36 43 48 4 23.4 42.7 1.5X length 10 from 16 to 10 44 47 4 22.9 43.7 1.5X length 10 from 16 to 36 40 44 2 24.7 40.5 1.6X length 10 from 36 to 10 40 44 4 25.0 40.1 1.6X length 10 from 36 to 16 41 43 2 24.3 41.2 1.6X length 20 from 10 to 16 39 41 2 25.7 38.9 1.7X length 20 from 10 to 36 40 42 2 24.9 40.2 1.6X length 30 from 10 to 16 39 40 1 25.9 38.6 1.7X length 30 from 10 to 36 40 41 1 25.0 40.0 1.7X length 20 from 16 to 10 40 41 1 25.1 39.8 1.7X length 20 from 16 to 36 40 42 2 25.2 39.7 1.7X length 30 from 16 to 10 39 42 2 25.6 39.0 1.7X length 30 from 16 to 36 39 40 2 25.7 38.8 1.7X ``` Closes apache#30350 from AngersZhuuuu/SPARK-33428. Authored-by: angerszhu <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent bf2c88c commit 5f9a7fe

File tree

5 files changed

+23
-57
lines changed

5 files changed

+23
-57
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -21,64 +21,37 @@ import org.apache.spark.unsafe.types.UTF8String
2121

2222
object NumberConverter {
2323

24-
/**
25-
* Divide x by m as if x is an unsigned 64-bit integer. Examples:
26-
* unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2
27-
* unsignedLongDiv(0, 5) == 0
28-
*
29-
* @param x is treated as unsigned
30-
* @param m is treated as signed
31-
*/
32-
private def unsignedLongDiv(x: Long, m: Int): Long = {
33-
if (x >= 0) {
34-
x / m
35-
} else {
36-
// Let uval be the value of the unsigned long with the same bits as x
37-
// Two's complement => x = uval - 2*MAX - 2
38-
// => uval = x + 2*MAX + 2
39-
// Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c
40-
x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m
41-
}
42-
}
43-
4424
/**
4525
* Decode v into value[].
4626
*
47-
* @param v is treated as an unsigned 64-bit integer
27+
* @param v is treated as an BigInt
4828
* @param radix must be between MIN_RADIX and MAX_RADIX
4929
*/
50-
private def decode(v: Long, radix: Int, value: Array[Byte]): Unit = {
30+
private def decode(v: BigInt, radix: Int, value: Array[Byte]): Unit = {
5131
var tmpV = v
5232
java.util.Arrays.fill(value, 0.asInstanceOf[Byte])
5333
var i = value.length - 1
5434
while (tmpV != 0) {
55-
val q = unsignedLongDiv(tmpV, radix)
56-
value(i) = (tmpV - q * radix).asInstanceOf[Byte]
35+
val q = tmpV / radix
36+
value(i) = (tmpV - q * radix).byteValue
5737
tmpV = q
5838
i -= 1
5939
}
6040
}
6141

6242
/**
63-
* Convert value[] into a long. On overflow, return -1 (as mySQL does). If a
64-
* negative digit is found, ignore the suffix starting there.
43+
* Convert value[] into a BigInt. If a negative digit is found,
44+
* ignore the suffix starting there.
6545
*
6646
* @param radix must be between MIN_RADIX and MAX_RADIX
6747
* @param fromPos is the first element that should be considered
6848
* @return the result should be treated as an unsigned 64-bit integer.
6949
*/
70-
private def encode(radix: Int, fromPos: Int, value: Array[Byte]): Long = {
71-
var v: Long = 0L
72-
val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once
50+
private def encode(radix: Int, fromPos: Int, value: Array[Byte]): BigInt = {
51+
var v: BigInt = BigInt(0)
7352
var i = fromPos
7453
while (i < value.length && value(i) >= 0) {
75-
if (v >= bound) {
76-
// Check for overflow
77-
if (unsignedLongDiv(-1 - value(i), radix) < v) {
78-
return -1
79-
}
80-
}
81-
v = v * radix + value(i)
54+
v = (v * radix) + BigInt(value(i))
8255
i += 1
8356
}
8457
v
@@ -129,7 +102,7 @@ object NumberConverter {
129102
return null
130103
}
131104

132-
var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0)
105+
val (negative, first) = if (n(0) == '-') (true, 1) else (false, 0)
133106

134107
// Copy the digits in the right side of the array
135108
val temp = new Array[Byte](64)
@@ -140,19 +113,8 @@ object NumberConverter {
140113
}
141114
char2byte(fromBase, temp.length - n.length + first, temp)
142115

143-
// Do the conversion by going through a 64 bit integer
144-
var v = encode(fromBase, temp.length - n.length + first, temp)
145-
if (negative && toBase > 0) {
146-
if (v < 0) {
147-
v = -1
148-
} else {
149-
v = -v
150-
}
151-
}
152-
if (toBase < 0 && v < 0) {
153-
v = -v
154-
negative = true
155-
}
116+
// Do the conversion by going through a BigInt
117+
val v: BigInt = encode(fromBase, temp.length - n.length + first, temp)
156118
decode(v, Math.abs(toBase), temp)
157119

158120
// Find the first non-zero digit or the last digits if all are zero.
@@ -163,7 +125,7 @@ object NumberConverter {
163125
byte2char(Math.abs(toBase), firstNonZeroPos, temp)
164126

165127
var resultStartPos = firstNonZeroPos
166-
if (negative && toBase < 0) {
128+
if (negative) {
167129
resultStartPos = firstNonZeroPos - 1
168130
temp(resultStartPos) = '-'
169131
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
158158
test("conv") {
159159
checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
160160
checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
161-
checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
161+
checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "-F")
162162
checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
163163
checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
164164
checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
@@ -168,10 +168,12 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
168168
checkEvaluation(
169169
Conv(Literal(""), Literal(10), Literal(16)), null)
170170
checkEvaluation(
171-
Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
171+
Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "12DDAC15F246BAF8C0D551AC7")
172172
// If there is an invalid digit in the number, the longest valid prefix should be converted.
173173
checkEvaluation(
174174
Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
175+
checkEvaluation(Conv(Literal("c8dcdfb41711fc9a1f17928001d7fd61"), Literal(16), Literal(10)),
176+
"266992441711411603393340504520074460513")
175177
}
176178

177179
test("e") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ class NumberConverterSuite extends SparkFunSuite {
3434
test("convert") {
3535
checkConv("3", 10, 2, "11")
3636
checkConv("-15", 10, -16, "-F")
37-
checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
37+
checkConv("-15", 10, 16, "-F")
3838
checkConv("big", 36, 16, "3A48")
39-
checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
39+
checkConv("9223372036854775807", 36, 16, "12DDAC15F246BAF8C0D551AC7")
4040
checkConv("11abc", 10, 16, "B")
4141
}
4242

sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
200200
checkAnswer(df.selectExpr("""conv("100", 2, 10)"""), Row("4"))
201201
checkAnswer(df.selectExpr("""conv("-10", 16, -10)"""), Row("-16"))
202202
checkAnswer(
203-
df.selectExpr("""conv("9223372036854775807", 36, -16)"""), Row("-1")) // for overflow
203+
df.selectExpr("""conv("9223372036854775807", 36, -16)"""), Row("12DDAC15F246BAF8C0D551AC7"))
204204
}
205205

206206
test("floor") {

sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
525525
"udf_xpath_short",
526526
"udf_xpath_string",
527527

528+
// [SPARK-33428][SQL] CONV UDF use BigInt to avoid Long value overflow
529+
"udf_conv",
530+
528531
// These tests DROP TABLE that don't exist (but do not specify IF EXISTS)
529532
"alter_rename_partition1",
530533
"date_1",
@@ -1003,7 +1006,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
10031006
"udf_concat_insert1",
10041007
"udf_concat_insert2",
10051008
"udf_concat_ws",
1006-
"udf_conv",
10071009
"udf_cos",
10081010
"udf_count",
10091011
"udf_date_add",

0 commit comments

Comments
 (0)