Skip to content

Commit 59ddb99

Browse files
committed
Code review
1 parent e17bb46 commit 59ddb99

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,31 +1464,27 @@ case class ArrayContains(left: Expression, right: Expression)
14641464
nullSafeCodeGen(ctx, ev, (arr, value) => {
14651465
val i = ctx.freshName("i")
14661466
val getValue = CodeGenerator.getValue(arr, right.dataType, i)
1467-
def checkAndSetIsNullCode(body: String) = if (nullable) {
1467+
val loopBodyCode = if (nullable) {
14681468
s"""
14691469
|if ($arr.isNullAt($i)) {
1470-
| ${ev.isNull} = true;
1471-
|} else {
1472-
| $body
1470+
| ${ev.isNull} = true;
1471+
|} else if (${ctx.genEqual(right.dataType, value, getValue)}) {
1472+
| ${ev.isNull} = false;
1473+
| ${ev.value} = true;
1474+
| break;
14731475
|}
14741476
""".stripMargin
14751477
} else {
1476-
body
1477-
}
1478-
val unsetIsNullCode = if (nullable) s"${ev.isNull} = false;" else ""
1479-
val code = checkAndSetIsNullCode(
14801478
s"""
14811479
|if (${ctx.genEqual(right.dataType, value, getValue)}) {
1482-
| $unsetIsNullCode
14831480
| ${ev.value} = true;
14841481
| break;
14851482
|}
14861483
""".stripMargin
1487-
)
1488-
1484+
}
14891485
s"""
14901486
|for (int $i = 0; $i < $arr.numElements(); $i ++) {
1491-
| $code
1487+
| $loopBodyCode
14921488
|}
14931489
""".stripMargin
14941490
})

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
384384
val a4 = Literal.create(Seq(create_row(1)), ArrayType(StructType(Seq(
385385
StructField("a", IntegerType, true)))))
386386
// Explicitly mark the array type not nullable (spark-25308)
387-
val a5 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, false))
387+
val a5 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
388388

389389
checkEvaluation(ArrayContains(a0, Literal(1)), true)
390390
checkEvaluation(ArrayContains(a0, Literal(0)), false)

0 commit comments

Comments
 (0)