Skip to content

Commit ea76e29

Browse files
committed
Code review
1 parent 488eda8 commit ea76e29

File tree

2 files changed

+17
-28
lines changed

2 files changed

+17
-28
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,20 +1121,18 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
11211121
}
11221122
}
11231123

1124+
def getAliasMap(plan: Project): AttributeMap[Expression] = {
1125+
// Create a map of Aliases to their values from the child projection.
1126+
// e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
1127+
AttributeMap(plan.projectList.collect { case a: Alias => (a.toAttribute, a.child) })
1128+
}
11241129

1125-
def getAliasMap(plan: LogicalPlan): AttributeMap[Expression] = {
1126-
val aliasMap = plan match {
1127-
case p: Project =>
1128-
// Create a map of Aliases to their values from the child projection.
1129-
// e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
1130-
p.projectList.collect { case a: Alias => (a.toAttribute, a.child) }
1131-
case a: Aggregate =>
1132-
// Find all the aliased expressions in the aggregate list that don't include any actual
1133-
// AggregateExpression, and create a map from the alias to the expression
1134-
a.aggregateExpressions.collect {
1135-
case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
1136-
(a.toAttribute, a.child)
1137-
}
1130+
def getAliasMap(plan: Aggregate): AttributeMap[Expression] = {
1131+
// Find all the aliased expressions in the aggregate list that don't include any actual
1132+
// AggregateExpression, and create a map from the alias to the expression
1133+
val aliasMap = plan.aggregateExpressions.collect {
1134+
case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
1135+
(a.toAttribute, a.child)
11381136
}
11391137
AttributeMap(aliasMap)
11401138
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import org.apache.spark.sql.catalyst.expressions._
21-
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
22-
import org.apache.spark.sql.catalyst.plans.LeftSemiOrAnti
21+
import org.apache.spark.sql.catalyst.plans._
2322
import org.apache.spark.sql.catalyst.plans.logical._
2423
import org.apache.spark.sql.catalyst.rules.Rule
2524

@@ -31,14 +30,12 @@ import org.apache.spark.sql.catalyst.rules.Rule
3130
* 4) Aggregate
3231
* 5) Other permissible unary operators. please see [[PushDownPredicate.canPushThrough]].
3332
*/
34-
3533
object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
3634
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
37-
// Similar to the above Filter over Project
3835
// LeftSemi/LeftAnti over Project
3936
case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
4037
if pList.forall(_.deterministic) &&
41-
!pList.find(ScalarSubquery.hasScalarSubquery(_)).isDefined &&
38+
!pList.exists(ScalarSubquery.hasScalarSubquery)&&
4239
canPushThroughCondition(Seq(gChild), joinCond, rightOp) =>
4340
if (joinCond.isEmpty) {
4441
// No join condition, just push down the Join below Project
@@ -53,20 +50,17 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
5350
p.copy(child = Join(gChild, rightOp, joinType, newJoinCond, hint))
5451
}
5552

56-
// Similar to the above Filter over Aggregate
5753
// LeftSemi/LeftAnti over Aggregate
5854
case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
59-
if agg.aggregateExpressions.forall(_.deterministic)
60-
&& agg.groupingExpressions.nonEmpty =>
55+
if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty =>
6156
if (joinCond.isEmpty) {
6257
// No join condition, just push down Join below Aggregate
6358
agg.copy(child = Join(agg.child, rightOp, joinType, joinCond, hint))
6459
} else {
6560
val aliasMap = PushDownPredicate.getAliasMap(agg)
6661

67-
// For each join condition, expand the alias and
68-
// check if the condition can be evaluated using
69-
// attributes produced by the aggregate operator's child operator.
62+
// For each join condition, expand the alias and check if the condition can be evaluated
63+
// using attributes produced by the aggregate operator's child operator.
7064
val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
7165
val replaced = replaceAlias(cond, aliasMap)
7266
cond.references.nonEmpty &&
@@ -90,7 +84,6 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
9084
}
9185
}
9286

93-
// Similar to the above Filter over Window
9487
// LeftSemi/LeftAnti over Window
9588
case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
9689
if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
@@ -119,7 +112,6 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
119112
}
120113
}
121114

122-
// Similar to the above Filter over Union
123115
// LeftSemi/LeftAnti over Union
124116
case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
125117
if canPushThroughCondition(union.children, joinCond, rightOp) =>
@@ -148,10 +140,9 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
148140
}
149141
}
150142

151-
// Similar to the above Filter over UnaryNode
152143
// LeftSemi/LeftAnti over UnaryNode
153144
case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
154-
if PushDownPredicate.canPushThrough(u) =>
145+
if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
155146
pushDownJoin(join, u.child) { joinCond =>
156147
u.withNewChildren(Seq(Join(u.child, rightOp, joinType, Option(joinCond), hint)))
157148
}

0 commit comments

Comments
 (0)