Skip to content

Commit 4604a08

Browse files
committed
PullupCorrelatedPredicates should not produce unresolved plans.
1 parent 07549b2 commit 4604a08

File tree

2 files changed

+98
-13
lines changed

2 files changed

+98
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,25 +140,58 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
140140
require(list != null, "list should not be null")
141141
override def checkInputDataTypes(): TypeCheckResult = {
142142
list match {
143-
case ListQuery(sub, _, _) :: Nil =>
143+
case (l @ ListQuery(sub, children, _)) :: Nil =>
144144
val valExprs = value match {
145145
case cns: CreateNamedStruct => cns.valExprs
146146
case expr => Seq(expr)
147147
}
148-
if (valExprs.length != sub.output.length) {
149-
TypeCheckResult.TypeCheckFailure(
148+
149+
// SPARK-21759:
150+
// It is possibly that the subquery plan has more output than value expressions, because
151+
// the condition expressions in `ListQuery` might use part of subquery plan's output.
152+
// For example, in the following query plan, the condition of `ListQuery` uses value#207
153+
// from the subquery query. For now the size of output of subquery is 2, the size of value
154+
// is 1.
155+
//
156+
// Filter key#201 IN (list#200 [(value#207 = min(value)#204)])
157+
// : +- Project [key#206, value#207]
158+
// : +- Filter (value#207 > val_9)
159+
160+
// Take the subset of output that is not going to match with value expressions and is also
161+
// not used in condition expressions, if any.
162+
val subqueryOutputNotInCondition = sub.output.drop(valExprs.length).filter { attr =>
163+
l.children.forall { c =>
164+
!c.references.contains(attr)
165+
}
166+
}
167+
168+
val basicErrorMessage =
169+
s"""
170+
|The number of columns in the left hand side of an IN subquery does not match the
171+
|number of columns in the output of subquery.
172+
|#columns in left hand side: ${valExprs.length}.
173+
|#columns in right hand side: ${sub.output.length}.
174+
|Left side columns:
175+
|[${valExprs.map(_.sql).mkString(", ")}].
176+
|Right side columns:
177+
|[${sub.output.map(_.sql).mkString(", ")}].
178+
""".stripMargin
179+
180+
if (valExprs.length > sub.output.length) {
181+
TypeCheckResult.TypeCheckFailure(basicErrorMessage)
182+
} else if (subqueryOutputNotInCondition.nonEmpty) {
183+
val finalErrorMessage = basicErrorMessage +
150184
s"""
151-
|The number of columns in the left hand side of an IN subquery does not match the
152-
|number of columns in the output of subquery.
153-
|#columns in left hand side: ${valExprs.length}.
154-
|#columns in right hand side: ${sub.output.length}.
155-
|Left side columns:
156-
|[${valExprs.map(_.sql).mkString(", ")}].
157-
|Right side columns:
158-
|[${sub.output.map(_.sql).mkString(", ")}].
159-
""".stripMargin)
185+
| The additional output in subquery aren't used in the condition of subquery.
186+
| Additional output:
187+
| [${subqueryOutputNotInCondition.map(_.sql).mkString(", ")}].
188+
| Condition:
189+
| [${children.map(_.sql).mkString(", ")}].
190+
| ${l.references}
191+
""".stripMargin
192+
TypeCheckResult.TypeCheckFailure(finalErrorMessage)
160193
} else {
161-
val mismatchedColumns = valExprs.zip(sub.output).flatMap {
194+
val mismatchedColumns = valExprs.zip(sub.output.take(valExprs.length)).flatMap {
162195
case (l, r) if l.dataType != r.dataType =>
163196
s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
164197
case _ => None
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.dsl.expressions._
21+
import org.apache.spark.sql.catalyst.dsl.plans._
22+
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
23+
import org.apache.spark.sql.catalyst.plans.PlanTest
24+
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
25+
import org.apache.spark.sql.catalyst.rules.RuleExecutor
26+
27+
class PullupCorrelatedPredicatesSuite extends PlanTest {
28+
29+
object Optimize extends RuleExecutor[LogicalPlan] {
30+
val batches =
31+
Batch("PullupCorrelatedPredicates", Once,
32+
PullupCorrelatedPredicates) :: Nil
33+
}
34+
35+
val testRelation = LocalRelation('a.int, 'b.double)
36+
val testRelation2 = LocalRelation('c.int, 'd.double)
37+
38+
test("PullupCorrelatedPredicates should not produce unresolved plan") {
39+
val correlatedSubquery =
40+
testRelation2
41+
.where('b < 'd)
42+
.select('c)
43+
val outerQuery =
44+
testRelation
45+
.where(In('a, Seq(ListQuery(correlatedSubquery))))
46+
.select('a).analyze
47+
assert(outerQuery.resolved)
48+
49+
val optimized = Optimize.execute(outerQuery)
50+
assert(optimized.resolved)
51+
}
52+
}

0 commit comments

Comments
 (0)