
Commit ab01ba7

bogdanrdc authored and gatorsmile committed
[SPARK-23316][SQL] AnalysisException after max iteration reached for IN query
## What changes were proposed in this pull request?

Added a flag `ignoreNullability` to `DataType.equalsStructurally`. The previous semantics correspond to `ignoreNullability = false`. When `ignoreNullability = true`, `equalsStructurally` ignores the nullability of contained types (map key types, value types, array element types, struct field types).

`In.checkInputDataTypes` calls `equalsStructurally` to check whether the children's types match. They should match regardless of nullability (which is just a hint), so it is now called with `ignoreNullability = true`.

## How was this patch tested?

New test in `SubquerySuite`.

Author: Bogdan Raducanu <[email protected]>

Closes #20548 from bogdanrdc/SPARK-23316.

(cherry picked from commit 05d0512)
Signed-off-by: gatorsmile <[email protected]>
1 parent: dbb1b39

3 files changed: +19 -7 lines changed

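For context, here is a minimal sketch of the nullability mismatch behind the bug (schemas approximated from the failing query; exact coercion details may differ). The value side of `(id, id) IN (...)` is a struct of non-nullable bigint columns, while the subquery `select id, null from range(3)` projects a nullable second column, so the structural check failed on every analysis pass and the `In` expression never resolved:

```scala
import org.apache.spark.sql.types._

// Value side of "(id, id) in (...)": range() columns are non-nullable.
val valueType = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("id", LongType, nullable = false)))

// Subquery side of "select id, null from range(3)": the NULL column ends
// up nullable after type coercion. Field names don't matter to this check.
val queryType = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("col2", LongType, nullable = true)))

// Before the patch: false, purely because of the nullability difference.
DataType.equalsStructurally(valueType, queryType)

// With the patch, In.checkInputDataTypes calls it like this and accepts it.
DataType.equalsStructurally(valueType, queryType, ignoreNullability = true)
```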

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 2 additions & 1 deletion

```diff
@@ -157,7 +157,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
   require(list != null, "list should not be null")
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType))
+    val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType,
+      ignoreNullability = true))
     if (mismatchOpt.isDefined) {
       list match {
         case ListQuery(_, _, _, childOutputs) :: Nil =>
```
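The relaxation applies only to nullability; an IN list whose element type genuinely differs from the value type still fails this check. A quick sketch of both cases, assuming a build where `DataType.equalsStructurally` is reachable (e.g. from a Spark test or a spark-shell with catalyst on the classpath):

```scala
import org.apache.spark.sql.types._

// Same shape, different field names and nullability: no longer a mismatch.
DataType.equalsStructurally(
  StructType(Seq(StructField("a", LongType, nullable = false))),
  StructType(Seq(StructField("b", LongType, nullable = true))),
  ignoreNullability = true)  // true

// Structurally different types are still reported as a mismatch.
DataType.equalsStructurally(LongType, StringType, ignoreNullability = true)  // false
```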

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 12 additions & 6 deletions

```diff
@@ -295,25 +295,31 @@ object DataType {
   }
 
   /**
-   * Returns true if the two data types share the same "shape", i.e. the types (including
-   * nullability) are the same, but the field names don't need to be the same.
+   * Returns true if the two data types share the same "shape", i.e. the types
+   * are the same, but the field names don't need to be the same.
+   *
+   * @param ignoreNullability whether to ignore nullability when comparing the types
    */
-  def equalsStructurally(from: DataType, to: DataType): Boolean = {
+  def equalsStructurally(
+      from: DataType,
+      to: DataType,
+      ignoreNullability: Boolean = false): Boolean = {
     (from, to) match {
       case (left: ArrayType, right: ArrayType) =>
         equalsStructurally(left.elementType, right.elementType) &&
-          left.containsNull == right.containsNull
+          (ignoreNullability || left.containsNull == right.containsNull)
 
       case (left: MapType, right: MapType) =>
         equalsStructurally(left.keyType, right.keyType) &&
           equalsStructurally(left.valueType, right.valueType) &&
-          left.valueContainsNull == right.valueContainsNull
+          (ignoreNullability || left.valueContainsNull == right.valueContainsNull)
 
       case (StructType(fromFields), StructType(toFields)) =>
         fromFields.length == toFields.length &&
           fromFields.zip(toFields)
             .forall { case (l, r) =>
-              equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable
+              equalsStructurally(l.dataType, r.dataType) &&
+                (ignoreNullability || l.nullable == r.nullable)
             }
 
       case (fromDataType, toDataType) => fromDataType == toDataType
```
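One detail worth noting in this version of the patch: the recursive `equalsStructurally` calls (for array element, map key/value, and struct field types) do not forward `ignoreNullability`, so the flag only relaxes nullability at the level where it is applied; a nullability difference nested one level deeper is still compared with the default. A sketch of both behaviors:

```scala
import org.apache.spark.sql.types._

val a = ArrayType(LongType, containsNull = false)
val b = ArrayType(LongType, containsNull = true)

DataType.equalsStructurally(a, b)                            // false
DataType.equalsStructurally(a, b, ignoreNullability = true)  // true

// The flag is shallow here: a nullability difference inside the element
// type is still detected, because the recursive call above uses the
// default ignoreNullability = false.
val c = ArrayType(StructType(Seq(StructField("f", LongType, nullable = false))))
val d = ArrayType(StructType(Seq(StructField("f", LongType, nullable = true))))
DataType.equalsStructurally(c, d, ignoreNullability = true)  // false
```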

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 5 additions & 0 deletions

```diff
@@ -950,4 +950,9 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
     assert(join.duplicateResolved)
     assert(optimizedPlan.resolved)
   }
+
+  test("SPARK-23316: AnalysisException after max iteration reached for IN query") {
+    // before the fix this would throw AnalysisException
+    spark.range(10).where("(id,id) in (select id, null from range(3))").count
+  }
 }
```
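The same query works as a spark-shell repro outside the test suite; on a build without this patch, analysis keeps re-running the resolution rules until it hits the iteration cap and surfaces an AnalysisException:

```scala
// Repro for SPARK-23316 in the spark-shell. Without the fix, this fails
// during analysis once the maximum number of resolution iterations is
// reached; with the fix it plans and executes normally.
spark.range(10).where("(id, id) in (select id, null from range(3))").count()
```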
