Commit f3baf08

thepinetree authored and dongjoon-hyun committed
[SPARK-43393][SQL][3.5] Address sequence expression overflow bug
### What changes were proposed in this pull request?

Spark has a (long-standing) overflow bug in the `sequence` expression.

Consider the following operations:
```
spark.sql("CREATE TABLE foo (l LONG);")
spark.sql(s"INSERT INTO foo VALUES (${Long.MaxValue});")
spark.sql("SELECT sequence(0, l) FROM foo;").collect()
```

The result of these operations will be:
```
Array[org.apache.spark.sql.Row] = Array([WrappedArray()])
```
an unintended consequence of overflow.

The sequence is applied to values `0` and `Long.MaxValue` with a step size of `1` which uses a length computation defined [here](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3451). In this calculation, with `start = 0`, `stop = Long.MaxValue`, and `step = 1`, the calculated `len` overflows to `Long.MinValue`. The computation, in binary looks like:

```
  0111111111111111111111111111111111111111111111111111111111111111
- 0000000000000000000000000000000000000000000000000000000000000000
------------------------------------------------------------------
  0111111111111111111111111111111111111111111111111111111111111111
/ 0000000000000000000000000000000000000000000000000000000000000001
------------------------------------------------------------------
  0111111111111111111111111111111111111111111111111111111111111111
+ 0000000000000000000000000000000000000000000000000000000000000001
------------------------------------------------------------------
  1000000000000000000000000000000000000000000000000000000000000000
```

The following [check](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3454) passes as the negative `Long.MinValue` is still `<= MAX_ROUNDED_ARRAY_LENGTH`. The following cast to `toInt` uses this representation and [truncates the upper bits](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3457) resulting in an empty length of `0`.

Other overflows are similarly problematic.

This PR addresses the issue by checking numeric operations in the length computation for overflow.

### Why are the changes needed?

There is a correctness bug from overflow in the `sequence` expression.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Tests added in `CollectionExpressionsSuite.scala`.

Closes #43820 from thepinetree/spark-sequence-overflow-3.5.

Authored-by: Deepayan Patra <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
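The overflow described in the commit message can be reproduced outside Spark with plain `Long` arithmetic. The following is a minimal sketch, not Spark's code: `OverflowDemo`, `uncheckedLen`, and the local `MaxRoundedArrayLength` constant (assumed here to be `Int.MaxValue - 15`, standing in for `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`) are illustrative names.

```scala
object OverflowDemo {
  // Assumption: stand-in for ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.
  val MaxRoundedArrayLength: Long = Int.MaxValue - 15

  // Same shape as the pre-fix length computation: unchecked Long arithmetic.
  def uncheckedLen(start: Long, stop: Long, step: Long): Long =
    if (start == stop) 1L else 1L + (stop - start) / step

  def main(args: Array[String]): Unit = {
    val len = uncheckedLen(0L, Long.MaxValue, 1L)
    println(len == Long.MinValue)         // true: 1 + Long.MaxValue wraps around
    println(len <= MaxRoundedArrayLength) // true: the upper-bound check passes
    println(len.toInt)                    // 0: the low 32 bits are all zero
  }
}
```

Because the wrapped value is negative, an upper-bound check alone cannot reject it, which is why the fix checks each arithmetic step for overflow.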
1 parent 9e492b7 commit f3baf08

2 files changed: +71 -20 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 32 additions & 15 deletions
@@ -22,6 +22,7 @@ import java.util.Comparator
 import scala.collection.mutable
 import scala.reflect.ClassTag

+import org.apache.spark.SparkException.internalError
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -40,7 +41,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}

 /**
@@ -3080,6 +3080,34 @@ case class Sequence(
 }

 object Sequence {
+  private def prettyName: String = "sequence"
+
+  def sequenceLength(start: Long, stop: Long, step: Long): Int = {
+    try {
+      val delta = Math.subtractExact(stop, start)
+      if (delta == Long.MinValue && step == -1L) {
+        // We must special-case division of Long.MinValue by -1 to catch potential unchecked
+        // overflow in next operation. Division does not have a builtin overflow check. We
+        // previously special-case div-by-zero.
+        throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
+      }
+      val len = if (stop == start) 1L else Math.addExact(1L, (delta / step))
+      if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(len)
+      }
+      len.toInt
+    } catch {
+      // We handle overflows in the previous try block by raising an appropriate exception.
+      case _: ArithmeticException =>
+        val safeLen =
+          BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
+        if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+          throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(safeLen)
+        }
+        throw internalError("Unreachable code reached.")
+      case e: Exception => throw e
+    }
+  }

   private type LessThanOrEqualFn = (Any, Any) => Boolean

@@ -3451,13 +3479,7 @@ object Sequence {
         || (estimatedStep == num.zero && start == stop),
       s"Illegal sequence boundaries: $start to $stop by $step")

-    val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
-
-    require(
-      len <= MAX_ROUNDED_ARRAY_LENGTH,
-      s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
-
-    len.toInt
+    sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong)
   }

   private def genSequenceLengthCode(
private def genSequenceLengthCode(
@@ -3467,20 +3489,15 @@ object Sequence {
       step: String,
       estimatedStep: String,
       len: String): String = {
-    val longLen = ctx.freshName("longLen")
+    val calcFn = classOf[Sequence].getName + ".sequenceLength"
     s"""
        |if (!(($estimatedStep > 0 && $start <= $stop) ||
        |  ($estimatedStep < 0 && $start >= $stop) ||
        |  ($estimatedStep == 0 && $start == $stop))) {
        |  throw new IllegalArgumentException(
        |    "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step);
        |}
-       |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
-       |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
-       |  throw new IllegalArgumentException(
-       |    "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH");
-       |}
-       |int $len = (int) $longLen;
+       |int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep);
      """.stripMargin
   }
 }
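The checked-arithmetic pattern used by `sequenceLength` can be tried in isolation. The sketch below is a standalone approximation, not the Spark implementation: `CheckedLengthSketch`, `checkedLength`, `TooManyElements`, and `MaxLen` (assumed to be `Int.MaxValue - 15`) are hypothetical names, and a plain exception replaces `QueryExecutionErrors.createArrayWithElementsExceedLimitError`. It also differs from the diff in one respect: when the arbitrary-precision recomputation fits within the limit, it returns that value instead of treating the branch as unreachable.

```scala
object CheckedLengthSketch {
  // Assumption: stand-in for ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.
  val MaxLen: Long = Int.MaxValue - 15

  // Hypothetical error type; Spark raises its own error through QueryExecutionErrors.
  final class TooManyElements(val count: BigInt)
    extends RuntimeException(s"Too many elements: $count")

  def checkedLength(start: Long, stop: Long, step: Long): Int = {
    require(step != 0L || start == stop, "a zero step is only valid when start == stop")
    if (start == stop) 1
    else {
      try {
        val delta = Math.subtractExact(stop, start) // throws on Long overflow
        if (delta == Long.MinValue && step == -1L) {
          // Long.MinValue / -1 wraps silently on the JVM, so reject it by hand.
          throw new ArithmeticException("Long.MinValue / -1")
        }
        val len = Math.addExact(1L, delta / step) // throws on Long overflow
        if (len > MaxLen) throw new TooManyElements(BigInt(len))
        len.toInt
      } catch {
        case _: ArithmeticException =>
          // Recompute in arbitrary precision to obtain the exact element count.
          val exact = BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
          if (exact > BigInt(MaxLen)) throw new TooManyElements(exact)
          exact.toInt
      }
    }
  }
}
```

For example, `CheckedLengthSketch.checkedLength(0L, 10L, 2L)` yields `6`, while `checkedLength(0L, Long.MaxValue, 1L)` throws `TooManyElements` carrying the exact count `BigInt(Long.MaxValue) + 1`.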

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 39 additions & 5 deletions
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingZoneIds,
 import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String

 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -769,10 +769,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper

     // test sequence boundaries checking

-    checkExceptionInExpression[IllegalArgumentException](
-      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
-      EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
-
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0")
     checkExceptionInExpression[IllegalArgumentException](
@@ -782,6 +778,44 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1")

+    // SPARK-43393: test Sequence overflow checking
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> ((0 - BigInt(Long.MinValue)) + 1).toString(),
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+
     // test sequence with one element (zero step or equal start and stop)

     checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1))
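The expected `count` parameters in the new tests are exact element counts computed with `BigInt`, which cannot overflow. As a rough illustration of that arithmetic (the `ExpectedCounts` object and `exactCount` helper are hypothetical and not part of the test suite):

```scala
object ExpectedCounts {
  // Exact element count of an inclusive sequence, computed in arbitrary precision.
  def exactCount(start: Long, stop: Long, step: Long): BigInt =
    (BigInt(stop) - BigInt(start)) / BigInt(step) + 1

  def main(args: Array[String]): Unit = {
    println(exactCount(Int.MinValue, Int.MaxValue, 1L)) // 4294967296
    println(exactCount(0L, Long.MaxValue, 1L))          // 9223372036854775808
    println(exactCount(0L, Long.MinValue, -1L))         // 9223372036854775809
  }
}
```

The first value matches the `4294967296` that appeared in the removed "Too long sequence" assertion; the other two exceed `Long.MaxValue` and so can only be represented in arbitrary precision.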
