
Commit 183d4cb

viirya authored and cloud-fan committed
[SPARK-21759][SQL] In.checkInputDataTypes should not wrongly report unresolved plans for IN correlated subquery
## What changes were proposed in this pull request?

With the check for structural integrity proposed in SPARK-21726, it is found that the optimization rule `PullupCorrelatedPredicates` can produce unresolved plans.

For a correlated IN query that looks like:

```sql
SELECT t1.a FROM t1
WHERE t1.a IN (SELECT t2.c FROM t2 WHERE t1.b < t2.d);
```

the query plan might look like:

```
Project [a#0]
+- Filter a#0 IN (list#4 [b#1])
   :  +- Project [c#2]
   :     +- Filter (outer(b#1) < d#3)
   :        +- LocalRelation <empty>, [c#2, d#3]
   +- LocalRelation <empty>, [a#0, b#1]
```

After `PullupCorrelatedPredicates`, it produces a query plan like:

```
'Project [a#0]
+- 'Filter a#0 IN (list#4 [(b#1 < d#3)])
   :  +- Project [c#2, d#3]
   :     +- LocalRelation <empty>, [c#2, d#3]
   +- LocalRelation <empty>, [a#0, b#1]
```

Because the correlated predicate involves another attribute `d#3` in the subquery, that attribute has been pulled out and added to the `Project` on top of the subquery.

When `list` in `In` contains just one `ListQuery`, `In.checkInputDataTypes` checks whether the number of `value` expressions matches the output size of the subquery. In the example above there is only one `value` expression while the subquery output has two attributes, `c#2` and `d#3`, so the check fails and `In.resolved` returns `false`. We should not let `In.checkInputDataTypes` wrongly report unresolved plans and thereby fail the structural integrity check.

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <[email protected]>

Closes #18968 from viirya/SPARK-21759.
1 parent 9e33954 commit 183d4cb
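The failure mode can be reduced to a few lines. Below is a minimal, self-contained Scala sketch — not Spark's classes; `TinyListQuery`, `arityOkOld`, and `arityOkNew` are hypothetical stand-ins — of why an arity check against the rewritten plan's output breaks, and why checking against the output captured at resolution time (the new `childOutputs`) stays stable:

```scala
object PullupDemo extends App {
  // Hypothetical model: a list-subquery carrying the plan's current output and
  // the output recorded when the subquery was resolved.
  case class TinyListQuery(planOutput: Seq[String], childOutputs: Seq[String])

  // Old behaviour: compare the IN value's arity with the (possibly rewritten) plan output.
  def arityOkOld(valueArity: Int, q: TinyListQuery): Boolean =
    valueArity == q.planOutput.length

  // New behaviour: compare with the output captured at resolution time.
  def arityOkNew(valueArity: Int, q: TinyListQuery): Boolean =
    valueArity == q.childOutputs.length

  // Resolved subquery SELECT t2.c ... originally outputs one column.
  val resolved = TinyListQuery(planOutput = Seq("c#2"), childOutputs = Seq("c#2"))
  // PullupCorrelatedPredicates additionally projects d#3 for the pulled-up predicate.
  val pulledUp = resolved.copy(planOutput = Seq("c#2", "d#3"))

  assert(arityOkOld(1, resolved))  // fine before the rewrite
  assert(!arityOkOld(1, pulledUp)) // spurious failure after the rewrite
  assert(arityOkNew(1, pulledUp))  // stable once childOutputs is recorded
  println("IN arity check is stable across predicate pull-up")
}
```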

File tree

7 files changed: +106 -51 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 4 additions & 2 deletions
@@ -1286,8 +1286,10 @@ class Analyzer(
         resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
       case e @ Exists(sub, _, exprId) if !sub.resolved =>
         resolveSubQuery(e, plans)(Exists(_, _, exprId))
-      case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved =>
-        val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId))
+      case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved =>
+        val expr = resolveSubQuery(l, plans)((plan, exprs) => {
+          ListQuery(plan, exprs, exprId, plan.output)
+        })
         In(value, Seq(expr))
     }
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 3 additions & 2 deletions
@@ -402,7 +402,7 @@ object TypeCoercion {
 
     // Handle type casting required between value expression and subquery output
     // in IN subquery.
-    case i @ In(a, Seq(ListQuery(sub, children, exprId)))
+    case i @ In(a, Seq(ListQuery(sub, children, exprId, _)))
       if !i.resolved && flattenExpr(a).length == sub.output.length =>
       // LHS is the value expression of IN subquery.
       val lhs = flattenExpr(a)
@@ -434,7 +434,8 @@
           case _ => CreateStruct(castedLhs)
         }
 
-        In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
+        val newSub = Project(castedRhs, sub)
+        In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
       } else {
         i
       }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 30 additions & 35 deletions
@@ -138,32 +138,33 @@ case class Not(child: Expression)
 case class In(value: Expression, list: Seq[Expression]) extends Predicate {
 
   require(list != null, "list should not be null")
+
   override def checkInputDataTypes(): TypeCheckResult = {
-    list match {
-      case ListQuery(sub, _, _) :: Nil =>
-        val valExprs = value match {
-          case cns: CreateNamedStruct => cns.valExprs
-          case expr => Seq(expr)
-        }
-        if (valExprs.length != sub.output.length) {
-          TypeCheckResult.TypeCheckFailure(
-            s"""
-               |The number of columns in the left hand side of an IN subquery does not match the
-               |number of columns in the output of subquery.
-               |#columns in left hand side: ${valExprs.length}.
-               |#columns in right hand side: ${sub.output.length}.
-               |Left side columns:
-               |[${valExprs.map(_.sql).mkString(", ")}].
-               |Right side columns:
-               |[${sub.output.map(_.sql).mkString(", ")}].
-             """.stripMargin)
-        } else {
-          val mismatchedColumns = valExprs.zip(sub.output).flatMap {
-            case (l, r) if l.dataType != r.dataType =>
-              s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
-            case _ => None
+    val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType))
+    if (mismatchOpt.isDefined) {
+      list match {
+        case ListQuery(_, _, _, childOutputs) :: Nil =>
+          val valExprs = value match {
+            case cns: CreateNamedStruct => cns.valExprs
+            case expr => Seq(expr)
           }
-          if (mismatchedColumns.nonEmpty) {
+          if (valExprs.length != childOutputs.length) {
+            TypeCheckResult.TypeCheckFailure(
+              s"""
+                 |The number of columns in the left hand side of an IN subquery does not match the
+                 |number of columns in the output of subquery.
+                 |#columns in left hand side: ${valExprs.length}.
+                 |#columns in right hand side: ${childOutputs.length}.
+                 |Left side columns:
+                 |[${valExprs.map(_.sql).mkString(", ")}].
+                 |Right side columns:
+                 |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
+          } else {
+            val mismatchedColumns = valExprs.zip(childOutputs).flatMap {
+              case (l, r) if l.dataType != r.dataType =>
+                s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
+              case _ => None
+            }
             TypeCheckResult.TypeCheckFailure(
               s"""
                  |The data type of one or more elements in the left hand side of an IN subquery
@@ -173,20 +174,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
                  |Left side:
                  |[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
                  |Right side:
-                 |[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
-               """.stripMargin)
-          } else {
-            TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
+                 |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
           }
-        }
-      case _ =>
-        val mismatchOpt = list.find(l => l.dataType != value.dataType)
-        if (mismatchOpt.isDefined) {
+        case _ =>
           TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
             s"${value.dataType} != ${mismatchOpt.get.dataType}")
-        } else {
-          TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
-        }
+      }
+    } else {
+      TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
     }
   }
 
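Two things are worth noting in this restructuring: the check now short-circuits through `DataType.equalsStructurally` — if `value`'s type structurally matches every list element, it goes straight to the ordering check — and the detailed column-count and per-column messages are computed against the recorded `childOutputs` rather than `sub.output`, so extra attributes projected by `PullupCorrelatedPredicates` no longer trip the arity test.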

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 10 additions & 3 deletions
@@ -274,17 +274,24 @@ object ScalarSubquery {
 case class ListQuery(
     plan: LogicalPlan,
     children: Seq[Expression] = Seq.empty,
-    exprId: ExprId = NamedExpression.newExprId)
+    exprId: ExprId = NamedExpression.newExprId,
+    childOutputs: Seq[Attribute] = Seq.empty)
   extends SubqueryExpression(plan, children, exprId) with Unevaluable {
-  override def dataType: DataType = plan.schema.fields.head.dataType
+  override def dataType: DataType = if (childOutputs.length > 1) {
+    childOutputs.toStructType
+  } else {
+    childOutputs.head.dataType
+  }
+  override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
   override def toString: String = s"list#${exprId.id} $conditionString"
   override lazy val canonicalized: Expression = {
     ListQuery(
       plan.canonicalized,
       children.map(_.canonicalized),
-      ExprId(0))
+      ExprId(0),
+      childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
   }
 }
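The `dataType` and `resolved` changes above are the heart of the fix: `ListQuery` now derives its type from the recorded `childOutputs` (a struct when more than one column was captured) and only counts as resolved once `childOutputs` is non-empty — which is why the Analyzer change passes `plan.output` when resolving the subquery. A hedged sketch of just the type rule, in plain Scala with a hypothetical `Col` standing in for `Attribute`:

```scala
object DataTypeDemo extends App {
  // Hypothetical stand-in for Attribute: just a name and a type string.
  case class Col(name: String, dataType: String)

  // Mirrors the new rule: one captured column exposes its own type; several
  // are exposed as a struct over all of them. Like the real code, this
  // assumes childOutputs is non-empty once the subquery is resolved.
  def listQueryDataType(childOutputs: Seq[Col]): String =
    if (childOutputs.length > 1) {
      childOutputs.map(c => s"${c.name}: ${c.dataType}").mkString("struct<", ", ", ">")
    } else {
      childOutputs.head.dataType
    }

  println(listQueryDataType(Seq(Col("c", "int"))))                     // int
  println(listQueryDataType(Seq(Col("c", "int"), Col("d", "double")))) // struct<c: int, d: double>
}
```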

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 5 additions & 5 deletions
@@ -68,11 +68,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       case (p, Not(Exists(sub, conditions, _))) =>
         val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
         Join(outerPlan, sub, LeftAnti, joinCond)
-      case (p, In(value, Seq(ListQuery(sub, conditions, _)))) =>
+      case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
         val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
         val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
         Join(outerPlan, sub, LeftSemi, joinCond)
-      case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) =>
+      case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
         // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
         // Construct the condition. A NULL in one of the conditions is regarded as a positive
         // result; such a row will be filtered out by the Anti-Join operator.
@@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         val exists = AttributeReference("exists", BooleanType, nullable = false)()
         newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
         exists
-      case In(value, Seq(ListQuery(sub, conditions, _))) =>
+      case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
         val exists = AttributeReference("exists", BooleanType, nullable = false)()
         val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
         val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
@@ -227,9 +227,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
       case Exists(sub, children, exprId) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
         Exists(newPlan, newCond, exprId)
-      case ListQuery(sub, _, exprId) =>
+      case ListQuery(sub, _, exprId, childOutputs) =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
-        ListQuery(newPlan, newCond, exprId)
+        ListQuery(newPlan, newCond, exprId, childOutputs)
     }
   }
 

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class PullupCorrelatedPredicatesSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("PullupCorrelatedPredicates", Once,
+        PullupCorrelatedPredicates) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.double)
+  val testRelation2 = LocalRelation('c.int, 'd.double)
+
+  test("PullupCorrelatedPredicates should not produce unresolved plan") {
+    val correlatedSubquery =
+      testRelation2
+        .where('b < 'd)
+        .select('c)
+    val outerQuery =
+      testRelation
+        .where(In('a, Seq(ListQuery(correlatedSubquery))))
+        .select('a).analyze
+    assert(outerQuery.resolved)
+
+    val optimized = Optimize.execute(outerQuery)
+    assert(optimized.resolved)
+  }
+}

sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out

Lines changed: 2 additions & 4 deletions
@@ -80,8 +80,7 @@ number of columns in the output of subquery.
 Left side columns:
 [t1.`t1a`].
 Right side columns:
-[t2.`t2a`, t2.`t2b`].
-;
+[t2.`t2a`, t2.`t2b`].;
 
 
 -- !query 6
@@ -102,5 +101,4 @@ number of columns in the output of subquery.
 Left side columns:
 [t1.`t1a`, t1.`t1b`].
 Right side columns:
-[t2.`t2a`].
-;
+[t2.`t2a`].;