Skip to content

Commit 2b7bfba

Browse files
committed
[SPARK-26078][SQL][BACKPORT-2.4] Dedup self-join attributes on IN subqueries
When there is a self-join as a result of an IN subquery, the join condition may be invalid, resulting in trivially true predicates and wrong results. This PR deduplicates the subquery output in order to avoid the issue. Added a unit test. Closes #23057 from mgaido91/SPARK-26078. Authored-by: Marco Gaido <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent c0f4082 commit 2b7bfba

File tree

2 files changed

+97
-38
lines changed

2 files changed

+97
-38
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
1919

2020
import scala.collection.mutable.ArrayBuffer
2121

22-
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.AnalysisException
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
2525
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -43,31 +43,53 @@ import org.apache.spark.sql.types._
4343
* condition.
4444
*/
4545
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
46-
private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {
46+
47+
private def buildJoin(
48+
outerPlan: LogicalPlan,
49+
subplan: LogicalPlan,
50+
joinType: JoinType,
51+
condition: Option[Expression]): Join = {
52+
// Deduplicate conflicting attributes if any.
53+
val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition)
54+
Join(outerPlan, dedupSubplan, joinType, condition)
55+
}
56+
57+
private def dedupSubqueryOnSelfJoin(
58+
outerPlan: LogicalPlan,
59+
subplan: LogicalPlan,
60+
valuesOpt: Option[Seq[Expression]],
61+
condition: Option[Expression] = None): LogicalPlan = {
4762
// SPARK-21835: It is possible that the two sides of the join have conflicting attributes,
4863
// the produced join then becomes unresolved and breaks structural integrity. We should
49-
// de-duplicate conflicting attributes. We don't use transformation here because we only
50-
// care about the most top join converted from correlated predicate subquery.
51-
case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) =>
52-
val duplicates = right.outputSet.intersect(left.outputSet)
53-
if (duplicates.nonEmpty) {
54-
val aliasMap = AttributeMap(duplicates.map { dup =>
55-
dup -> Alias(dup, dup.toString)()
56-
}.toSeq)
57-
val aliasedExpressions = right.output.map { ref =>
58-
aliasMap.getOrElse(ref, ref)
59-
}
60-
val newRight = Project(aliasedExpressions, right)
61-
val newJoinCond = joinCond.map { condExpr =>
62-
condExpr transform {
63-
case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
64+
// de-duplicate conflicting attributes.
65+
// SPARK-26078: it may also happen that the subquery has conflicting attributes with the outer
66+
// values. In this case, the resulting join would contain trivially true conditions (eg.
67+
// id#3 = id#3) which cannot be de-duplicated after. In this method, if there are conflicting
68+
// attributes in the join condition, the subquery's conflicting attributes are changed using
69+
// a projection which aliases them and resolves the problem.
70+
val outerReferences = valuesOpt.map(values =>
71+
AttributeSet.fromAttributeSets(values.map(_.references))).getOrElse(AttributeSet.empty)
72+
val outerRefs = outerPlan.outputSet ++ outerReferences
73+
val duplicates = outerRefs.intersect(subplan.outputSet)
74+
if (duplicates.nonEmpty) {
75+
condition.foreach { e =>
76+
val conflictingAttrs = e.references.intersect(duplicates)
77+
if (conflictingAttrs.nonEmpty) {
78+
throw new AnalysisException("Found conflicting attributes " +
79+
s"${conflictingAttrs.mkString(",")} in the condition joining outer plan:\n " +
80+
s"$outerPlan\nand subplan:\n $subplan")
6481
}
65-
}
66-
Join(left, newRight, joinType, newJoinCond)
67-
} else {
68-
j
6982
}
70-
case _ => joinPlan
83+
val rewrites = AttributeMap(duplicates.map { dup =>
84+
dup -> Alias(dup, dup.toString)()
85+
}.toSeq)
86+
val aliasedExpressions = subplan.output.map { ref =>
87+
rewrites.getOrElse(ref, ref)
88+
}
89+
Project(aliasedExpressions, subplan)
90+
} else {
91+
subplan
92+
}
7193
}
7294

7395
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -85,25 +107,27 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
85107
withSubquery.foldLeft(newFilter) {
86108
case (p, Exists(sub, conditions, _)) =>
87109
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
88-
// Deduplicate conflicting attributes if any.
89-
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
110+
buildJoin(outerPlan, sub, LeftSemi, joinCond)
90111
case (p, Not(Exists(sub, conditions, _))) =>
91112
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
92-
// Deduplicate conflicting attributes if any.
93-
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
113+
buildJoin(outerPlan, sub, LeftAnti, joinCond)
94114
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
95-
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
96-
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
97115
// Deduplicate conflicting attributes if any.
98-
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
116+
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
117+
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
118+
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
119+
Join(outerPlan, newSub, LeftSemi, joinCond)
99120
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
100121
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
101122
// Construct the condition. A NULL in one of the conditions is regarded as a positive
102123
// result; such a row will be filtered out by the Anti-Join operator.
103124

104125
// Note that this will almost certainly be planned as a Broadcast Nested Loop join.
105126
// Use EXISTS if performance matters to you.
106-
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
127+
128+
// Deduplicate conflicting attributes if any.
129+
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
130+
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
107131
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
108132
// Expand the NOT IN expression with the NULL-aware semantic
109133
// to its full form. That is from:
@@ -118,8 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
118142
// will have the final conditions in the LEFT ANTI as
119143
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
120144
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
121-
// Deduplicate conflicting attributes if any.
122-
dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
145+
Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond))
123146
case (p, predicate) =>
124147
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
125148
Project(p.output, Filter(newCond.get, inputPlan))
@@ -140,16 +163,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
140163
e transformUp {
141164
case Exists(sub, conditions, _) =>
142165
val exists = AttributeReference("exists", BooleanType, nullable = false)()
143-
// Deduplicate conflicting attributes if any.
144-
newPlan = dedupJoin(
145-
Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
166+
newPlan =
167+
buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
146168
exists
147169
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
148170
val exists = AttributeReference("exists", BooleanType, nullable = false)()
149-
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
150-
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
151171
// Deduplicate conflicting attributes if any.
152-
newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions))
172+
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
173+
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
174+
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
175+
newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions)
153176
exists
154177
}
155178
}

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,4 +1268,40 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
12681268
assert(getNumSortsInQuery(query5) == 1)
12691269
}
12701270
}
1271+
1272+
test("SPARK-26078: deduplicate fake self joins for IN subqueries") {
1273+
withTempView("a", "b") {
1274+
Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a")
1275+
Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b")
1276+
1277+
val df1 = spark.sql(
1278+
"""
1279+
|SELECT id,num,source FROM (
1280+
| SELECT id, num, 'a' as source FROM a
1281+
| UNION ALL
1282+
| SELECT id, num, 'b' as source FROM b
1283+
|) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
1284+
""".stripMargin)
1285+
checkAnswer(df1, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
1286+
val df2 = spark.sql(
1287+
"""
1288+
|SELECT id,num,source FROM (
1289+
| SELECT id, num, 'a' as source FROM a
1290+
| UNION ALL
1291+
| SELECT id, num, 'b' as source FROM b
1292+
|) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2)
1293+
""".stripMargin)
1294+
checkAnswer(df2, Seq(Row("b", 1, "a"), Row("b", 1, "b")))
1295+
val df3 = spark.sql(
1296+
"""
1297+
|SELECT id,num,source FROM (
1298+
| SELECT id, num, 'a' as source FROM a
1299+
| UNION ALL
1300+
| SELECT id, num, 'b' as source FROM b
1301+
|) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR
1302+
|c.id IN (SELECT id FROM b WHERE num = 3)
1303+
""".stripMargin)
1304+
checkAnswer(df3, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
1305+
}
1306+
}
12711307
}

0 commit comments

Comments
 (0)