Commit da88592

hvanhovell authored and davies committed
[SPARK-4226] [SQL] Support IN/EXISTS Subqueries
### What changes were proposed in this pull request?

This PR adds support for in/exists predicate subqueries to Spark. Predicate sub-queries are used as a filtering condition in a query (this is the only supported use case). A predicate sub-query comes in two forms:

- `[NOT] EXISTS(subquery)`
- `[NOT] IN (subquery)`

This PR is (loosely) based on the work of davies (#10706) and chenghao-intel (#9055). They should be credited for the work they did.

### How was this patch tested?

Modified parsing unit tests. Added tests to `org.apache.spark.sql.SQLQuerySuite`.

cc rxin, davies & chenghao-intel

Author: Herman van Hovell <[email protected]>

Closes #12306 from hvanhovell/SPARK-4226.
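For illustration, here is a minimal sketch of the two query shapes this commit enables, assuming a Spark build that includes this patch; the tables `a` and `b` (each with an `id` column) and the `sqlContext` handle are hypothetical stand-ins:

```scala
import org.apache.spark.sql.SQLContext

def demo(sqlContext: SQLContext): Unit = {
  // [NOT] EXISTS(subquery): keep rows of `a` with (without) a matching row in `b`.
  sqlContext.sql(
    """SELECT *
      |FROM a
      |WHERE EXISTS (SELECT 1 FROM b WHERE b.id = a.id)""".stripMargin).show()

  // [NOT] IN (subquery): keep rows of `a` whose id is (not) produced by the sub-query.
  sqlContext.sql(
    """SELECT *
      |FROM a
      |WHERE a.id NOT IN (SELECT id FROM b)""".stripMargin).show()
}
```

Both forms are accepted only as filter conditions (WHERE or HAVING); the analyzer rejects them anywhere else, and the optimizer rewrites them into left semi/anti joins (see `RewritePredicateSubquery` below).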
1 parent 3c91afe commit da88592

File tree: 12 files changed (+476, -42 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 20 additions & 10 deletions
@@ -855,25 +855,35 @@ class Analyzer(
   }
 
   /**
-   * This rule resolve subqueries inside expressions.
+   * This rule resolves sub-queries inside expressions.
    *
-   * Note: CTE are handled in CTESubstitution.
+   * Note: CTEs are handled in CTESubstitution.
    */
   object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
 
-    private def hasSubquery(e: Expression): Boolean = {
-      e.find(_.isInstanceOf[SubqueryExpression]).isDefined
-    }
-
-    private def hasSubquery(q: LogicalPlan): Boolean = {
-      q.expressions.exists(hasSubquery)
+    /**
+     * Resolve the correlated predicates in the [[Filter]] clauses (e.g. WHERE or HAVING) of a
+     * sub-query by using the plan the predicates should be correlated to.
+     */
+    private def resolveCorrelatedPredicates(q: LogicalPlan, p: LogicalPlan): LogicalPlan = {
+      q transformUp {
+        case f @ Filter(cond, child) if child.resolved && !f.resolved =>
+          val newCond = resolveExpression(cond, p, throws = false)
+          if (!cond.fastEquals(newCond)) {
+            Filter(newCond, child)
+          } else {
+            f
+          }
+      }
     }
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case q: LogicalPlan if q.childrenResolved && hasSubquery(q) =>
+      case q: LogicalPlan if q.childrenResolved =>
         q transformExpressions {
           case e: SubqueryExpression if !e.query.resolved =>
-            e.withNewPlan(execute(e.query))
+            // First resolve as much of the sub-query as possible. After that we use the children
+            // of this plan to resolve the remaining correlated predicates.
+            e.withNewPlan(q.children.foldLeft(execute(e.query))(resolveCorrelatedPredicates))
         }
     }
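To make the correlated-resolution step concrete, here is a self-contained toy sketch (invented types, not Catalyst): a name in a sub-query filter is first resolved against the sub-query itself, and only the leftovers are resolved against the outer plan's children, which is how a correlated reference such as `a.id` inside `EXISTS (... WHERE b.id = a.id)` gets bound.

```scala
object ResolveSketch {
  // A plan exposes some output attributes; a sub-query filter may reference names
  // that the sub-query itself does not produce (correlated references).
  case class Plan(output: Set[String])

  // Resolve a single name against one plan's output.
  def resolve(name: String, p: Plan): Option[String] =
    if (p.output.contains(name)) Some(s"$name(resolved)") else None

  // Mimics the rule above: try the sub-query first, then fall back to the outer
  // plan's children for whatever is still unresolved.
  def resolveAll(names: Seq[String], sub: Plan, outerChildren: Seq[Plan]): Seq[String] =
    names.map { n =>
      (sub +: outerChildren).flatMap(resolve(n, _)).headOption.getOrElse(s"$n(unresolved)")
    }

  def main(args: Array[String]): Unit = {
    val sub = Plan(Set("id"))          // the sub-query produces `id`
    val outer = Seq(Plan(Set("a_id"))) // the outer relation produces `a_id`
    // `id` resolves locally; `a_id` only resolves against the outer child.
    println(resolveAll(Seq("id", "a_id"), sub, outer))
    // -> List(id(resolved), a_id(resolved))
  }
}
```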

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 38 additions & 2 deletions
@@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.UsingJoin
+import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
 /**
  * Throws user facing errors when passed invalid queries that fail to analyze.
  */
-trait CheckAnalysis {
+trait CheckAnalysis extends PredicateHelper {
 
   /**
    * Override to provide additional checks for correct analysis.
@@ -110,6 +110,39 @@ trait CheckAnalysis {
             s"filter expression '${f.condition.sql}' " +
               s"of type ${f.condition.dataType.simpleString} is not a boolean.")
 
+          case f @ Filter(condition, child) =>
+            // Make sure that no correlated reference is below Aggregates, Outer Joins and on the
+            // right hand side of Unions.
+            lazy val attributes = child.outputSet
+            def failOnCorrelatedReference(
+                p: LogicalPlan,
+                message: String): Unit = p.transformAllExpressions {
+              case e: NamedExpression if attributes.contains(e) =>
+                failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
+            }
+            def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach {
+              case a @ Aggregate(_, _, source) =>
+                failOnCorrelatedReference(source, "an AGGREGATE")
+              case j @ Join(left, _, RightOuter, _) =>
+                failOnCorrelatedReference(left, "a RIGHT OUTER JOIN")
+              case j @ Join(_, right, jt, _) if jt != Inner =>
+                failOnCorrelatedReference(right, "a LEFT (OUTER) JOIN")
+              case Union(_ :: xs) =>
+                xs.foreach(failOnCorrelatedReference(_, "a UNION"))
+              case s: SetOperation =>
+                failOnCorrelatedReference(s.right, "an INTERSECT/EXCEPT")
+              case _ =>
+            }
+            splitConjunctivePredicates(condition).foreach {
+              case p: PredicateSubquery =>
+                checkForCorrelatedReferences(p)
+              case Not(p: PredicateSubquery) =>
+                checkForCorrelatedReferences(p)
+              case e if PredicateSubquery.hasPredicateSubquery(e) =>
+                failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e")
+              case e =>
+            }
+
           case j @ Join(_, _, UsingJoin(_, cols), _) =>
             val from = operator.inputSet.map(_.name).mkString(", ")
             failAnalysis(
@@ -209,6 +242,9 @@ trait CheckAnalysis {
                 | but one table has '${firstError.output.length}' columns and another table has
                 | '${s.children.head.output.length}' columns""".stripMargin)
 
+          case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
+            failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
+
           case _ => // Fallbacks to the following checks
         }
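The check only walks the top-level conjuncts of the filter condition, which is why `EXISTS (...) AND x > 1` is accepted while `EXISTS (...) OR x > 1` trips the "nested conditions" error. Below is a toy model of the `splitConjunctivePredicates` helper the check relies on; the mini `Expr` ADT is invented for illustration and is not Catalyst's `Expression`.

```scala
object SplitConjunctsSketch {
  sealed trait Expr
  case class And(left: Expr, right: Expr) extends Expr
  case class Leaf(sql: String) extends Expr

  // Mirrors PredicateHelper.splitConjunctivePredicates: split on top-level ANDs only,
  // so anything under an OR or NOT stays inside a single conjunct.
  def splitConjunctivePredicates(e: Expr): Seq[Expr] = e match {
    case And(l, r) => splitConjunctivePredicates(l) ++ splitConjunctivePredicates(r)
    case other => Seq(other)
  }

  def main(args: Array[String]): Unit = {
    val cond = And(Leaf("EXISTS(...)"), And(Leaf("x > 1"), Leaf("y = 2")))
    // -> List(Leaf(EXISTS(...)), Leaf(x > 1), Leaf(y = 2))
    println(splitConjunctivePredicates(cond))
  }
}
```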

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 82 additions & 2 deletions
@@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
 
 /**
  * An interface for subquery that is used in expressions.
  */
-abstract class SubqueryExpression extends LeafExpression {
+abstract class SubqueryExpression extends Expression {
 
   /**
    * The logical plan of the query.
@@ -61,6 +61,8 @@ case class ScalarSubquery(
 
   override def dataType: DataType = query.schema.fields.head.dataType
 
+  override def children: Seq[Expression] = Nil
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (query.schema.length != 1) {
       TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " +
@@ -77,3 +79,81 @@ case class ScalarSubquery(
 
   override def toString: String = s"subquery#${exprId.id}"
 }
+
+/**
+ * A predicate subquery checks the existence of a value in a sub-query. We currently only allow
+ * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This
+ * will be rewritten into a left semi/anti join during analysis.
+ */
+abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate {
+  override def nullable: Boolean = false
+  override def plan: LogicalPlan = SubqueryAlias(prettyName, query)
+}
+
+object PredicateSubquery {
+  def hasPredicateSubquery(e: Expression): Boolean = {
+    e.find(_.isInstanceOf[PredicateSubquery]).isDefined
+  }
+}
+
+/**
+ * The [[InSubQuery]] predicate checks the existence of a value in a sub-query. For example (SQL):
+ * {{{
+ *   SELECT *
+ *   FROM   a
+ *   WHERE  a.id IN (SELECT id
+ *                   FROM b)
+ * }}}
+ */
+case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSubquery {
+  override def children: Seq[Expression] = value :: Nil
+  override lazy val resolved: Boolean = value.resolved && query.resolved
+  override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan)
+
+  /**
+   * The unwrapped value side expressions.
+   */
+  lazy val expressions: Seq[Expression] = value match {
+    case CreateStruct(cols) => cols
+    case col => Seq(col)
+  }
+
+  /**
+   * Check if the number of columns and the data types on both sides match.
+   */
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // Check the number of arguments.
+    if (expressions.length != query.output.length) {
+      return TypeCheckResult.TypeCheckFailure(
+        s"The number of fields in the value (${expressions.length}) does not match with " +
+          s"the number of columns in the subquery (${query.output.length})")
+    }
+
+    // Check the argument types and report the first mismatch, if any.
+    expressions.zip(query.output).zipWithIndex.collectFirst {
+      case ((e, a), i) if e.dataType != a.dataType =>
+        TypeCheckResult.TypeCheckFailure(
+          s"The data type of value[$i](${e.dataType}) does not match " +
+            s"subquery column '${a.name}' (${a.dataType}).")
+    }.getOrElse(TypeCheckResult.TypeCheckSuccess)
+  }
+}
+
+/**
+ * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition.
+ * For example (SQL):
+ * {{{
+ *   SELECT *
+ *   FROM   a
+ *   WHERE  EXISTS (SELECT *
+ *                  FROM b
+ *                  WHERE b.id = a.id)
+ * }}}
+ */
+case class Exists(query: LogicalPlan) extends PredicateSubquery {
+  override def children: Seq[Expression] = Nil
+  override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan)
+}
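One detail worth spelling out: for a multi-column test such as `(a, b) IN (SELECT x, y FROM t)`, the value side arrives as a single struct-building expression, and `expressions` unwraps it so the columns can be matched pairwise against the sub-query output. A toy sketch of that unwrapping, with an invented ADT rather than Catalyst's `Expression`:

```scala
object InValueSketch {
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class CreateStruct(cols: Seq[Expr]) extends Expr

  // Mirrors InSubQuery.expressions: unwrap a struct, wrap a bare column.
  def valueExpressions(value: Expr): Seq[Expr] = value match {
    case CreateStruct(cols) => cols
    case col => Seq(col)
  }

  def main(args: Array[String]): Unit = {
    println(valueExpressions(Col("id")))                             // List(Col(id))
    println(valueExpressions(CreateStruct(Seq(Col("a"), Col("b"))))) // List(Col(a), Col(b))
  }
}
```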

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 114 additions & 1 deletion
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.annotation.tailrec
 import scala.collection.immutable.HashSet
+import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{InSubQuery, _}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
@@ -47,6 +48,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     // However, because we also use the analyzer to canonicalized queries (for view definition),
     // we do not eliminate subqueries or compute current time in the analyzer.
     Batch("Finish Analysis", Once,
+      RewritePredicateSubquery,
       EliminateSubqueryAliases,
       ComputeCurrentTime,
       GetCurrentDatabase(sessionCatalog),
@@ -1446,3 +1448,114 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * This rule rewrites predicate sub-queries into left semi/anti joins. The following predicates
+ * are supported:
+ * a. EXISTS/NOT EXISTS will be rewritten as semi/anti join, unresolved conditions in Filter
+ *    will be pulled out as the join conditions.
+ * b. IN/NOT IN will be rewritten as semi/anti join, unresolved conditions in the Filter will
+ *    be pulled out as join conditions, value = selected column will also be used as join
+ *    condition.
+ */
+object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
+  /**
+   * Pull out all correlated predicates from a given sub-query. This method removes the
+   * correlated predicates from sub-query [[Filter]]s and adds the references of these predicates
+   * to all intermediate [[Project]] clauses (if they are missing) in order to be able to
+   * evaluate the predicates in the join condition.
+   *
+   * This method returns the rewritten sub-query and the combined (AND) extracted predicate.
+   */
+  private def pullOutCorrelatedPredicates(
+      subquery: LogicalPlan,
+      query: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+    val references = query.outputSet
+    val predicateMap = mutable.Map.empty[LogicalPlan, Seq[Expression]]
+    val transformed = subquery transformUp {
+      case f @ Filter(cond, child) =>
+        // Find all correlated predicates.
+        val (correlated, local) = splitConjunctivePredicates(cond).partition { e =>
+          e.references.intersect(references).nonEmpty
+        }
+        // Rewrite the filter without the correlated predicates if any.
+        correlated match {
+          case Nil => f
+          case xs if local.nonEmpty =>
+            val newFilter = Filter(local.reduce(And), child)
+            predicateMap += newFilter -> correlated
+            newFilter
+          case xs =>
+            predicateMap += child -> correlated
+            child
+        }
+      case p @ Project(expressions, child) =>
+        // Find all pulled out predicates defined in the Project's subtree.
+        val localPredicates = p.collect(predicateMap).flatten
+
+        // Determine which correlated predicate references are missing from this project.
+        val localPredicateReferences = localPredicates
+          .map(_.references)
+          .reduceOption(_ ++ _)
+          .getOrElse(AttributeSet.empty)
+        val missingReferences = localPredicateReferences -- p.references -- query.outputSet
+
+        // Create a new project if we need to add missing references.
+        if (missingReferences.nonEmpty) {
+          Project(expressions ++ missingReferences, child)
+        } else {
+          p
+        }
+    }
+    (transformed, predicateMap.values.flatten.toSeq)
+  }
+
+  /**
+   * Prepare an [[InSubQuery]] by rewriting it (in case of correlated predicates) and by
+   * constructing the required join condition. Both the rewritten subquery and the constructed
+   * join condition are returned.
+   */
+  private def pullOutCorrelatedPredicates(
+      in: InSubQuery,
+      query: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+    val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query)
+    val conditions = joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)
+    (resolved, conditions)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case f @ Filter(condition, child) =>
+      val (withSubquery, withoutSubquery) =
+        splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery)
+
+      // Construct the pruned filter condition.
+      val newFilter: LogicalPlan = withoutSubquery match {
+        case Nil => child
+        case conditions => Filter(conditions.reduce(And), child)
+      }
+
+      // Filter the plan by applying left semi and left anti joins.
+      withSubquery.foldLeft(newFilter) {
+        case (p, Exists(sub)) =>
+          val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
+          Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+        case (p, Not(Exists(sub))) =>
+          val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
+          Join(p, resolved, LeftAnti, conditions.reduceOption(And))
+        case (p, in: InSubQuery) =>
+          val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+          Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+        case (p, Not(in: InSubQuery)) =>
+          val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+          // This is a NULL-aware (left) anti join (NAAJ).
+          // Construct the condition. A NULL in one of the conditions is regarded as a positive
+          // result; such a row will be filtered out by the Anti-Join operator.
+          val anyNull = conditions.map(IsNull).reduceLeft(Or)
+          val condition = conditions.reduceLeft(And)
+
+          // Note that this will almost certainly be planned as a Broadcast Nested Loop join.
+          // Use EXISTS if performance matters to you.
+          Join(p, resolved, LeftAnti, Option(Or(anyNull, condition)))
+      }
+  }
+}
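The NOT IN branch is the subtle one: under SQL's NULL semantics a single NULL comparison makes the IN test unknown, so the row must not survive NOT IN. The rewrite encodes this by extending the anti-join condition. Here is a sketch of the constructed condition using string placeholders instead of Catalyst expressions:

```scala
object NaajConditionSketch {
  // Given join conditions c1..cn, build
  //   (isnull(c1) OR ... OR isnull(cn)) OR (c1 AND ... AND cn).
  // The left anti join drops every left row for which this is true, which
  // reproduces NOT IN's behaviour in the presence of NULLs.
  def naajCondition(conditions: Seq[String]): String = {
    val anyNull = conditions.map(c => s"isnull($c)").reduceLeft((l, r) => s"$l OR $r")
    val allTrue = conditions.reduceLeft((l, r) => s"$l AND $r")
    s"($anyNull) OR ($allTrue)"
  }

  def main(args: Array[String]): Unit = {
    // For `a.id NOT IN (SELECT id FROM b)` the only condition is the value match:
    println(naajCondition(Seq("a.id = b.id")))
    // -> (isnull(a.id = b.id)) OR (a.id = b.id)
  }
}
```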

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 10 additions & 6 deletions
@@ -391,9 +391,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
 
     // Having
     val withHaving = withProject.optional(having) {
-      // Note that we added a cast to boolean. If the expression itself is already boolean,
-      // the optimizer will get rid of the unnecessary cast.
-      Filter(Cast(expression(having), BooleanType), withProject)
+      // Note that we add a cast to non-predicate expressions. If the expression itself is
+      // already boolean, the optimizer will get rid of the unnecessary cast.
+      val predicate = expression(having) match {
+        case p: Predicate => p
+        case e => Cast(e, BooleanType)
+      }
+      Filter(predicate, withProject)
     }
 
     // Distinct
@@ -866,10 +870,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
   }
 
   /**
-   * Create a filtering correlated sub-query. This is not supported yet.
+   * Create a filtering correlated sub-query (EXISTS).
    */
   override def visitExists(ctx: ExistsContext): Expression = {
-    throw new ParseException("EXISTS clauses are not supported.", ctx)
+    Exists(plan(ctx.query))
   }
 
   /**
@@ -944,7 +948,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
           GreaterThanOrEqual(e, expression(ctx.lower)),
           LessThanOrEqual(e, expression(ctx.upper))))
       case SqlBaseParser.IN if ctx.query != null =>
-        throw new ParseException("IN with a Sub-query is currently not supported.", ctx)
+        invertIfNotDefined(InSubQuery(e, plan(ctx.query)))
       case SqlBaseParser.IN =>
         invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
       case SqlBaseParser.LIKE =>
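The HAVING change matters for this patch because a predicate sub-query is `Unevaluable`: wrapping `EXISTS (...)` in `Cast(..., BooleanType)` would hide the predicate inside a cast and defeat the Filter-only placement check. A toy model of the predicate-or-cast decision, with an invented ADT rather than Catalyst's `Expression`:

```scala
object HavingSketch {
  sealed trait Expr
  trait Predicate extends Expr
  case class Cast(child: Expr) extends Expr
  case class Exists(query: String) extends Predicate
  case class Count(col: String) extends Expr

  // Mirrors the AstBuilder change: predicates pass through, anything else is cast.
  def havingPredicate(e: Expr): Expr = e match {
    case p: Predicate => p
    case other => Cast(other)
  }

  def main(args: Array[String]): Unit = {
    println(havingPredicate(Exists("SELECT ...")))  // Exists(SELECT ...)
    println(havingPredicate(Count("x")))            // Cast(Count(x))
  }
}
```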
