Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,31 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
(leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
}

/**
 * Splits the conjuncts of a join condition into three groups, taking the join type into
 * account when a predicate could be evaluated on either side:
 *
 *  1. predicates that can be evaluated using only the output of `left`,
 *  2. predicates that can be evaluated using only the output of `right`,
 *  3. everything else (predicates needing both sides, plus all non-deterministic ones).
 *
 * A predicate whose reference set is a subset of both children's output (presumably the
 * case for a predicate with no attribute references at all, such as an uncorrelated
 * EXISTS subquery) is claimed by whichever side is partitioned first. The order therefore
 * depends on `joinType`: left first for inner-like, left-semi and right-outer joins; right
 * first for left-outer, left-anti and existence joins.
 *
 * @param condition conjunctive predicates of the join condition
 * @param joinType  the type of the join the condition belongs to
 * @param left      left child of the join
 * @param right     right child of the join
 * @return `(leftOnly, rightOnly, common)`. For join types this method does not split
 *         (full outer, natural, using) nothing is classified for push-down and the whole
 *         `condition` is returned, unchanged, as the common part.
 */
private def split(
    condition: Seq[Expression],
    joinType: JoinType,
    left: LogicalPlan,
    right: LogicalPlan): (Seq[Expression], Seq[Expression], Seq[Expression]) = {
  // Only deterministic predicates are candidates for being moved below the join.
  val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic)

  joinType match {
    case _: InnerLike | LeftSemi | RightOuter =>
      // Left side gets first pick of the candidates.
      val (leftEvaluateCondition, rest) =
        pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
      val (rightEvaluateCondition, commonCondition) =
        rest.partition(expr => expr.references.subsetOf(right.outputSet))
      (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)

    case LeftOuter | LeftAnti | ExistenceJoin(_) =>
      // Right side gets first pick of the candidates.
      val (rightEvaluateCondition, rest) =
        pushDownCandidates.partition(_.references.subsetOf(right.outputSet))
      val (leftEvaluateCondition, commonCondition) =
        rest.partition(expr => expr.references.subsetOf(left.outputSet))
      (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)

    case FullOuter | NaturalJoin(_) | UsingJoin(_, _) =>
      // No push-down classification for these join types. Return empty sequences rather
      // than null so that callers can safely use the result without null checks.
      (Nil, Nil, condition)
  }
}

/** Rewrites `plan` by running `transform` with the [[applyLocally]] partial function. */
def apply(plan: LogicalPlan): LogicalPlan = plan.transform(applyLocally)

val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
Expand Down Expand Up @@ -1348,7 +1373,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
// push down the join filter into sub query scanning if applicable
case j @ Join(left, right, joinType, joinCondition, hint) =>
val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), joinType, left, right)

joinType match {
case _: InnerLike | LeftSemi =>
Expand Down Expand Up @@ -1376,7 +1401,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)

Join(newLeft, newRight, joinType, newJoinCond, hint)
case FullOuter => j
case FullOuter =>
j
case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
case UsingJoin(_, _) => sys.error("Untransformed Using join node")
}
Expand Down
90 changes: 90 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,96 @@ class SubquerySuite extends QueryTest with SharedSparkSession {
}
}


// Checks query results when an uncorrelated EXISTS / NOT EXISTS subquery is used as an
// additional conjunct of the ON clause, for inner, right outer, left semi, left anti and
// left outer joins.
//
// Fixture: s1 = {1, 3, 5, 7, 9}, s2 = {1, 3, 4, 6, 9}, s3 = {3, 4, 6, 9}.
// s3 contains a row with id > 6 (namely 9), so in every query below the EXISTS predicate
// is true and the NOT EXISTS predicate is false. The ids common to s1 and s2 are 1, 3, 9.
test("SPARK-29769: JOIN Condition use EXISTS/NOT EXISTS") {
withTempView("s1", "s2", "s3") {
Seq(1, 3, 5, 7, 9).toDF("id").createOrReplaceTempView("s1")
Seq(1, 3, 4, 6, 9).toDF("id").createOrReplaceTempView("s2")
Seq(3, 4, 6, 9).toDF("id").createOrReplaceTempView("s3")

// Inner join, EXISTS is true: behaves like a plain equi-join on id.
checkAnswer(
sql(
"""
| SELECT s1.id FROM s1
| JOIN s2 ON s1.id = s2.id
| AND EXISTS (SELECT * from s3 where s3.id > 6)
""".stripMargin),
Row(1) :: Row(3) :: Row(9) :: Nil)

// Right outer join, EXISTS is true: every s2 row is kept, unmatched ones get a null s1.id.
checkAnswer(
sql(
"""
| SELECT s1.id, s2.id as id2 FROM s1
| RIGHT OUTER JOIN s2 ON s1.id = s2.id
| AND EXISTS (SELECT * from s3 where s3.id > 6)
""".stripMargin),
Row(1, 1) :: Row(3, 3) :: Row(null, 4) :: Row(null, 6) :: Row(9, 9) :: Nil)

// Right outer join, NOT EXISTS is false: nothing matches, so every s2 row is null-extended.
checkAnswer(
sql(
"""
| SELECT s1.id, s2.id as id2 FROM s1
| RIGHT OUTER JOIN s2 ON s1.id = s2.id
| AND NOT EXISTS (SELECT * from s3 where s3.id > 6)
""".stripMargin),
Row(null, 1) :: Row(null, 3) :: Row(null, 4) :: Row(null, 6) :: Row(null, 9) :: Nil)

// Left semi join, EXISTS is true: s1 rows that have a match in s2.
checkAnswer(
sql(
"""
| SELECT s1.id FROM s1
| LEFT SEMI JOIN s2 ON s1.id = s2.id
| AND EXISTS (SELECT * from s3 where s3.id > 6)
""".stripMargin),
Row(1) :: Row(3) :: Row(9) :: Nil)

// Left semi join, NOT EXISTS is false: the condition never holds, so no row qualifies.
checkAnswer(
sql(
"""
| SELECT s1.id FROM s1
| LEFT SEMI JOIN s2 ON s1.id = s2.id
| AND NOT EXISTS (SELECT * from s3 where s3.id > 6)
""".stripMargin),
Nil)

// Left anti join, EXISTS is true: s1 rows that have no match in s2.
checkAnswer(
sql(
"""
| SELECT s1.id FROM s1
| LEFT ANTI JOIN s2 ON s1.id = s2.id
| AND EXISTS (SELECT * from s3 where s3.id > 6)
""".stripMargin),
Row(5) :: Row(7) :: Nil)

// Left anti join, NOT EXISTS is false: the condition never holds, so every s1 row is kept.
checkAnswer(
sql(
"""
| SELECT s1.id FROM s1
| LEFT ANTI JOIN s2 ON s1.id = s2.id
| AND NOT EXISTS (SELECT * from s3 where s3.id > 6)
""".stripMargin),
Row(1) :: Row(3):: Row(5) :: Row(7) :: Row(9) :: Nil)

// Left outer join, EXISTS is true: every s1 row is kept, unmatched ones get a null s2.id.
checkAnswer(
sql(
"""
| SELECT s1.id, s2.id as id2 FROM s1
| LEFT OUTER JOIN s2 ON s1.id = s2.id
| AND EXISTS (SELECT * from s3 where s3.id > 6)
""".stripMargin),
Row(1, 1) :: Row(3, 3) :: Row(5, null) :: Row(7, null) :: Row(9, 9) :: Nil)

// Left outer join, NOT EXISTS is false: nothing matches, so every s1 row is null-extended.
checkAnswer(
sql(
"""
| SELECT s1.id, s2.id as id2 FROM s1
| LEFT OUTER JOIN s2 ON s1.id = s2.id
| AND NOT EXISTS (SELECT * from s3 where s3.id > 6)
""".stripMargin),
Row(1, null) :: Row(3, null) :: Row(5, null) :: Row(7, null) :: Row(9, null) :: Nil)
}
}

test("SPARK-14791: scalar subquery inside broadcast join") {
val df = sql("select a, sum(b) as s from l group by a having a > (select avg(a) from l)")
val expected = Row(3, 2.0, 3, 3.0) :: Row(6, null, 6, null) :: Nil
Expand Down