diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b78bdf082f33..e1bb36a4b638 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1296,6 +1296,31 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
     (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
   }
 
+  private def split(condition: Seq[Expression],
+      joinType: JoinType,
+      left: LogicalPlan,
+      right: LogicalPlan) = {
+    val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic)
+
+    joinType match {
+      case _: InnerLike | LeftSemi | RightOuter =>
+        val (leftEvaluateCondition, rest) =
+          pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
+        val (rightEvaluateCondition, commonCondition) =
+          rest.partition(expr => expr.references.subsetOf(right.outputSet))
+        (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
+      case LeftOuter | LeftAnti | ExistenceJoin(_) =>
+        val (rightEvaluateCondition, rest) =
+          pushDownCandidates.partition(_.references.subsetOf(right.outputSet))
+        val (leftEvaluateCondition, commonCondition) =
+          rest.partition(expr => expr.references.subsetOf(left.outputSet))
+        (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
+      case FullOuter => (Nil, Nil, Nil)
+      case NaturalJoin(_) => (Nil, Nil, Nil)
+      case UsingJoin(_, _) => (Nil, Nil, Nil)
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
 
   val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
@@ -1348,7 +1373,7 @@
     // push down the join filter into sub query scanning if applicable
     case j @ Join(left, right, joinType, joinCondition, hint) =>
       val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
-        split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
+        split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), joinType, left, right)
 
       joinType match {
         case _: InnerLike | LeftSemi =>
@@ -1376,7 +1401,8 @@
           val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
 
           Join(newLeft, newRight, joinType, newJoinCond, hint)
-        case FullOuter => j
+        case FullOuter =>
+          j
         case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
         case UsingJoin(_, _) => sys.error("Untransformed Using join node")
       }
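The crux of the Optimizer.scala change is that split now buckets the conjuncts
of a join condition with the join type in mind. A conjunct that references
neither child, such as an uncorrelated EXISTS subquery, is a subset of both
children's output sets, so the order of the two partition calls decides which
bucket it lands in: inner, left-semi, and right-outer joins try the left side
first, while left-outer, left-anti, and existence joins try the right side
first, because under those join types only the right child may legally be
filtered. The standalone sketch below mimics that bucketing; Pred,
SketchJoinType, Innerish, and Leftish are invented stand-ins for illustration,
not Catalyst types.

object SplitSketch extends App {
  // Stand-in for Catalyst's Expression: refs mimics Expression.references.
  final case class Pred(name: String, refs: Set[String], deterministic: Boolean = true) {
    override def toString: String = name
  }

  sealed trait SketchJoinType
  case object Innerish extends SketchJoinType // InnerLike, LeftSemi, RightOuter
  case object Leftish extends SketchJoinType  // LeftOuter, LeftAnti, ExistenceJoin

  def split(
      condition: Seq[Pred],
      joinType: SketchJoinType,
      leftOutput: Set[String],
      rightOutput: Set[String]): (Seq[Pred], Seq[Pred], Seq[Pred]) = {
    // Non-deterministic conjuncts can never be pushed below the join.
    val (candidates, nonDeterministic) = condition.partition(_.deterministic)
    joinType match {
      case Innerish =>
        val (leftEval, rest) = candidates.partition(_.refs.subsetOf(leftOutput))
        val (rightEval, common) = rest.partition(_.refs.subsetOf(rightOutput))
        (leftEval, rightEval, common ++ nonDeterministic)
      case Leftish =>
        // Right side first: a side-free predicate is a subset of *both*
        // output sets, and only the right child may be filtered here.
        val (rightEval, rest) = candidates.partition(_.refs.subsetOf(rightOutput))
        val (leftEval, common) = rest.partition(_.refs.subsetOf(leftOutput))
        (leftEval, rightEval, common ++ nonDeterministic)
    }
  }

  val equiKey = Pred("s1.id = s2.id", Set("s1.id", "s2.id"))
  val exists = Pred("EXISTS(...)", Set.empty) // uncorrelated: no outer references

  println(split(Seq(equiKey, exists), Innerish, Set("s1.id"), Set("s2.id")))
  // (List(EXISTS(...)),List(),List(s1.id = s2.id))
  println(split(Seq(equiKey, exists), Leftish, Set("s1.id"), Set("s2.id")))
  // (List(),List(EXISTS(...)),List(s1.id = s2.id))
}

Same inputs, different join type: the uncorrelated EXISTS moves from the
left-evaluable bucket to the right-evaluable one, which is what lets
PushPredicateThroughJoin push it below a LEFT OUTER or LEFT ANTI join instead
of leaving it stuck in the join condition. The SubquerySuite cases below
exercise exactly these combinations.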
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index e05af08dfb74..d07fadb8a109 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -352,6 +352,96 @@ class SubquerySuite extends QueryTest with SharedSparkSession {
     }
   }
 
+
+  test("SPARK-29769: JOIN condition using EXISTS/NOT EXISTS") {
+    withTempView("s1", "s2", "s3") {
+      Seq(1, 3, 5, 7, 9).toDF("id").createOrReplaceTempView("s1")
+      Seq(1, 3, 4, 6, 9).toDF("id").createOrReplaceTempView("s2")
+      Seq(3, 4, 6, 9).toDF("id").createOrReplaceTempView("s3")
+
+      checkAnswer(
+        sql(
+          """
+            | SELECT s1.id FROM s1
+            | JOIN s2 ON s1.id = s2.id
+            | AND EXISTS (SELECT * FROM s3 WHERE s3.id > 6)
+          """.stripMargin),
+        Row(1) :: Row(3) :: Row(9) :: Nil)
+
+      checkAnswer(
+        sql(
+          """
+            | SELECT s1.id, s2.id AS id2 FROM s1
+            | RIGHT OUTER JOIN s2 ON s1.id = s2.id
+            | AND EXISTS (SELECT * FROM s3 WHERE s3.id > 6)
+          """.stripMargin),
+        Row(1, 1) :: Row(3, 3) :: Row(null, 4) :: Row(null, 6) :: Row(9, 9) :: Nil)
+
+      checkAnswer(
+        sql(
+          """
+            | SELECT s1.id, s2.id AS id2 FROM s1
+            | RIGHT OUTER JOIN s2 ON s1.id = s2.id
+            | AND NOT EXISTS (SELECT * FROM s3 WHERE s3.id > 6)
+          """.stripMargin),
+        Row(null, 1) :: Row(null, 3) :: Row(null, 4) :: Row(null, 6) :: Row(null, 9) :: Nil)
+
+      checkAnswer(
+        sql(
+          """
+            | SELECT s1.id FROM s1
+            | LEFT SEMI JOIN s2 ON s1.id = s2.id
+            | AND EXISTS (SELECT * FROM s3 WHERE s3.id > 6)
+          """.stripMargin),
+        Row(1) :: Row(3) :: Row(9) :: Nil)
+
+      checkAnswer(
+        sql(
+          """
+            | SELECT s1.id FROM s1
+            | LEFT SEMI JOIN s2 ON s1.id = s2.id
+            | AND NOT EXISTS (SELECT * FROM s3 WHERE s3.id > 6)
+          """.stripMargin),
+        Nil)
+
+      checkAnswer(
+        sql(
+          """
+            | SELECT s1.id FROM s1
+            | LEFT ANTI JOIN s2 ON s1.id = s2.id
+            | AND EXISTS (SELECT * FROM s3 WHERE s3.id > 6)
+          """.stripMargin),
+        Row(5) :: Row(7) :: Nil)
+
+      checkAnswer(
+        sql(
+          """
+            | SELECT s1.id FROM s1
+            | LEFT ANTI JOIN s2 ON s1.id = s2.id
+            | AND NOT EXISTS (SELECT * FROM s3 WHERE s3.id > 6)
+          """.stripMargin),
+        Row(1) :: Row(3) :: Row(5) :: Row(7) :: Row(9) :: Nil)
+
+      checkAnswer(
+        sql(
+          """
+            | SELECT s1.id, s2.id AS id2 FROM s1
+            | LEFT OUTER JOIN s2 ON s1.id = s2.id
+            | AND EXISTS (SELECT * FROM s3 WHERE s3.id > 6)
+          """.stripMargin),
+        Row(1, 1) :: Row(3, 3) :: Row(5, null) :: Row(7, null) :: Row(9, 9) :: Nil)
+
+      checkAnswer(
+        sql(
+          """
+            | SELECT s1.id, s2.id AS id2 FROM s1
+            | LEFT OUTER JOIN s2 ON s1.id = s2.id
+            | AND NOT EXISTS (SELECT * FROM s3 WHERE s3.id > 6)
+          """.stripMargin),
+        Row(1, null) :: Row(3, null) :: Row(5, null) :: Row(7, null) :: Row(9, null) :: Nil)
+    }
+  }
+
   test("SPARK-14791: scalar subquery inside broadcast join") {
     val df = sql("select a, sum(b) as s from l group by a having a > (select avg(a) from l)")
     val expected = Row(3, 2.0, 3, 3.0) :: Row(6, null, 6, null) :: Nil