Commit 540e912

sameeragarwal authored and yhuai committed
[SPARK-17244] Catalyst should not pushdown non-deterministic join conditions
## What changes were proposed in this pull request?

Given that non-deterministic expressions can be stateful, pushing them down the query plan during the optimization phase can cause incorrect behavior. This patch fixes that issue by explicitly disabling that pushdown.

## How was this patch tested?

A new test in `FilterPushdownSuite` that checks catalyst behavior for both deterministic and non-deterministic join conditions.

Author: Sameer Agarwal <[email protected]>

Closes #14815 from sameeragarwal/constraint-inputfile.
1 parent f64a1dd commit 540e912

2 files changed: +28 −7 lines changed

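Not part of the commit, but a minimal sketch of the kind of query this change targets, written against the public Dataset API with an assumed local SparkSession (the object, column names, and thresholds are illustrative, not taken from the patch). The deterministic conjuncts may still be pushed into the join's children, while the rand(10) conjunct is non-deterministic and must stay in the join condition:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

object NonDeterministicJoinConditionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()

    val left = spark.range(0, 100).toDF("a")
    val right = spark.range(0, 100).toDF("b")

    // The first two conjuncts are deterministic and may be pushed below the join;
    // rand(10) is non-deterministic, so the optimizer must not push it into either side.
    val joined = left.join(right,
      left("a") === right("b") && left("a") > 5 && rand(10) > 0.5)

    // Inspect the optimized plan to see where each predicate ended up.
    joined.explain(true)

    spark.stop()
  }
}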

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 14 additions & 7 deletions
@@ -1379,18 +1379,25 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
  */
 object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
   /**
-   * Splits join condition expressions into three categories based on the attributes required
-   * to evaluate them.
+   * Splits join condition expressions or filter predicates (on a given join's output) into three
+   * categories based on the attributes required to evaluate them. Note that we explicitly exclude
+   * non-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or
+   * canEvaluateInRight to prevent pushing these predicates on either side of the join.
    *
    * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth)
    */
   private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
+    // Note: In order to ensure correctness, it's important to not change the relative ordering of
+    // any deterministic expression that follows a non-deterministic expression. To achieve this,
+    // we only consider pushing down those expressions that precede the first non-deterministic
+    // expression in the condition.
+    val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic)
     val (leftEvaluateCondition, rest) =
-      condition.partition(_.references subsetOf left.outputSet)
+      pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
     val (rightEvaluateCondition, commonCondition) =
-      rest.partition(_.references subsetOf right.outputSet)
+      rest.partition(expr => expr.references.subsetOf(right.outputSet))
 
-    (leftEvaluateCondition, rightEvaluateCondition, commonCondition)
+    (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ containingNonDeterministic)
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -1441,7 +1448,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
       }
 
     // push down the join filter into sub query scanning if applicable
-    case f @ Join(left, right, joinType, joinCondition) =>
+    case j @ Join(left, right, joinType, joinCondition) =>
       val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
         split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
 
@@ -1471,7 +1478,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
          val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
 
          Join(newLeft, newRight, LeftOuter, newJoinCond)
-        case FullOuter => f
+        case FullOuter => j
         case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
         case UsingJoin(_, _) => sys.error("Untransformed Using join node")
       }
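
The ordering guarantee comes from Seq.span. Below is a minimal, self-contained sketch of that splitting step (Pred and its deterministic flag are invented stand-ins for Catalyst expressions, not Spark code): everything from the first non-deterministic conjunct onward is kept out of the push-down candidates, even if later conjuncts are themselves deterministic.

// Hypothetical stand-in for a Catalyst expression: only a name and a
// deterministic flag are needed for this illustration.
case class Pred(name: String, deterministic: Boolean)

object SpanSplitDemo {
  def main(args: Array[String]): Unit = {
    val condition = Seq(
      Pred("x.a = 5", deterministic = true),
      Pred("y.a = 5", deterministic = true),
      Pred("x.a = rand(10)", deterministic = false),
      Pred("y.b = 5", deterministic = true))

    // span splits at the first element failing the predicate, preserving order:
    // the deterministic conjunct after the rand(10) conjunct stays on the right side.
    val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic)

    println(pushDownCandidates.map(_.name))          // List(x.a = 5, y.a = 5)
    println(containingNonDeterministic.map(_.name))  // List(x.a = rand(10), y.b = 5)
  }
}

The predicate values here mirror the ones used by the new FilterPushdownSuite test below.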

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 14 additions & 0 deletions
@@ -987,4 +987,18 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer)
   }
+
+  test("join condition pushdown: deterministic and non-deterministic") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    // Verify that all conditions preceding the first non-deterministic condition are pushed down
+    // by the optimizer and others are not.
+    val originalQuery = x.join(y, condition = Some("x.a".attr === 5 && "y.a".attr === 5 &&
+      "x.a".attr === Rand(10) && "y.b".attr === 5))
+    val correctAnswer = x.where("x.a".attr === 5).join(y.where("y.a".attr === 5),
+      condition = Some("x.a".attr === Rand(10) && "y.b".attr === 5))
+
+    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+  }
 }
