Commit 488eda8

Code review
1 parent ae5f6ee commit 488eda8

3 files changed: +218 -202 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 1 addition & 5 deletions
@@ -270,14 +270,10 @@ object ScalarSubquery {
 
   def hasScalarSubquery(e: Expression): Boolean = {
     e.find {
-      case s: ScalarSubquery => true
+      case _: ScalarSubquery => true
       case _ => false
     }.isDefined
   }
-
-  def hasScalarSubquery(e: Seq[Expression]): Boolean = {
-    e.find(hasScalarSubquery(_)).isDefined
-  }
 }
 
 /**
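Note, illustrative only: the removed Seq[Expression] overload is equivalent to an exists over the sequence, so any call site that needs the check can inline it. The anyScalarSubquery name and the exprs parameter below are hypothetical stand-ins, not part of this commit.

import org.apache.spark.sql.catalyst.expressions.{Expression, ScalarSubquery}

// Hypothetical call-site helper: true if any expression in the list contains a scalar subquery.
def anyScalarSubquery(exprs: Seq[Expression]): Boolean =
  exprs.exists(ScalarSubquery.hasScalarSubquery)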

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 20 additions & 197 deletions
@@ -1017,24 +1017,13 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     // This also applies to Aggregate.
     case Filter(condition, project @ Project(fields, grandChild))
       if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) =>
-
-      // Create a map of Aliases to their values from the child projection.
-      // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
-      val aliasMap = AttributeMap(fields.collect {
-        case a: Alias => (a.toAttribute, a.child)
-      })
-
+      val aliasMap = getAliasMap(project)
       project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
 
     case filter @ Filter(condition, aggregate: Aggregate)
       if aggregate.aggregateExpressions.forall(_.deterministic)
         && aggregate.groupingExpressions.nonEmpty =>
-      // Find all the aliased expressions in the aggregate list that don't include any actual
-      // AggregateExpression, and create a map from the alias to the expression
-      val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
-        case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
-          (a.toAttribute, a.child)
-      })
+      val aliasMap = getAliasMap(aggregate)
 
       // For each filter, expand the alias and check if the filter can be evaluated using
       // attributes produced by the aggregate operator's child operator.
@@ -1132,6 +1121,24 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
+  def getAliasMap(plan: LogicalPlan): AttributeMap[Expression] = {
+    val aliasMap = plan match {
+      case p: Project =>
+        // Create a map of Aliases to their values from the child projection.
+        // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
+        p.projectList.collect { case a: Alias => (a.toAttribute, a.child) }
+      case a: Aggregate =>
+        // Find all the aliased expressions in the aggregate list that don't include any actual
+        // AggregateExpression, and create a map from the alias to the expression
+        a.aggregateExpressions.collect {
+          case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
+            (a.toAttribute, a.child)
+        }
+    }
+    AttributeMap(aliasMap)
+  }
+
   def canPushThrough(p: UnaryNode): Boolean = p match {
     // Note that some operators (e.g. project, aggregate, union) are being handled separately
     // (earlier in this rule).
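Illustrative sketch, not part of the commit: what the new getAliasMap helper collects for a simple projection, written in the style of the Catalyst optimizer test suites. The relation, the attribute names, and the DSL imports below are assumptions for illustration only.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Expression}
import org.apache.spark.sql.catalyst.optimizer.PushDownPredicate
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}

val a = 'a.int
val b = 'b.int
val d = 'd.int
val relation = LocalRelation(a, b, d)

// Roughly 'SELECT a + b AS c, d FROM relation'.
val project = Project(Seq((a + b).as("c"), d), relation)

// Expected to map the output attribute c to the expression a + b. PushDownPredicate feeds this
// map to replaceAlias so a Filter on c becomes a Filter on a + b that can sit below the Project
// (and, for an Aggregate, only aliases without AggregateExpressions are collected).
val aliasMap: AttributeMap[Expression] = PushDownPredicate.getAliasMap(project)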
@@ -1189,190 +1196,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
   }
 }
 
-object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Similar to the above Filter over Project
-    // LeftSemi/LeftAnti over Project
-    case join @ Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
-        if pList.forall(_.deterministic) && !ScalarSubquery.hasScalarSubquery(pList) &&
-          canPushThroughCondition(Seq(gChild), joinCond, rightOp) =>
-      if (joinCond.isEmpty) {
-        // No join condition, just push down the Join below Project
-        Project(pList, Join(gChild, rightOp, joinType, joinCond, hint))
-      } else {
-        // Create a map of Aliases to their values from the child projection.
-        // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
-        val aliasMap = AttributeMap(pList.collect {
-          case a: Alias => (a.toAttribute, a.child)
-        })
-        val newJoinCond = if (aliasMap.nonEmpty) {
-          Option(replaceAlias(joinCond.get, aliasMap))
-        } else {
-          joinCond
-        }
-        Project(pList, Join(gChild, rightOp, joinType, newJoinCond, hint))
-      }
-
-    // Similar to the above Filter over Aggregate
-    // LeftSemi/LeftAnti over Aggregate
-    case join @ Join(aggregate: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
-        if aggregate.aggregateExpressions.forall(_.deterministic)
-          && aggregate.groupingExpressions.nonEmpty =>
-      if (joinCond.isEmpty) {
-        // No join condition, just push down Join below Aggregate
-        aggregate.copy(child = Join(aggregate.child, rightOp, joinType, joinCond, hint))
-      } else {
-        // Find all the aliased expressions in the aggregate list that don't include any actual
-        // AggregateExpression, and create a map from the alias to the expression
-        val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
-          case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
-            (a.toAttribute, a.child)
-        })
-
-        // For each join condition, expand the alias and
-        // check if the condition can be evaluated using
-        // attributes produced by the aggregate operator's child operator.
-        val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
-          val replaced = replaceAlias(cond, aliasMap)
-          cond.references.nonEmpty &&
-            replaced.references.subsetOf(aggregate.child.outputSet ++ rightOp.outputSet)
-        }
-
-        // Check if the remaining predicates do not contain columns from subquery
-        val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
-
-        if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
-          val pushDownPredicate = pushDown.reduce(And)
-          val replaced = replaceAlias(pushDownPredicate, aliasMap)
-          val newAggregate = aggregate.copy(child =
-            Join(aggregate.child, rightOp, joinType, Option(replaced), hint))
-          // If there is no more filter to stay up, just return the Aggregate over Join.
-          // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)".
-          if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate)
-        } else {
-          // The join condition is not a subset of the Aggregate's GROUP BY columns,
-          // no push down.
-          join
-        }
-      }
-
-    // Similar to the above Filter over Window
-    // LeftSemi/LeftAnti over Window
-    case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
-        if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
-      if (joinCond.isEmpty) {
-        // No join condition, just push down Join below Window
-        w.copy(child = Join(w.child, rightOp, joinType, joinCond, hint))
-      } else {
-        val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++
-          rightOp.outputSet
-
-        val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
-          cond.references.subsetOf(partitionAttrs)
-        }
-
-        // Check if the remaining predicates do not contain columns from subquery
-        val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
-
-        if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
-          val pushDownPredicate = pushDown.reduce(And)
-          val newPlan =
-            w.copy(child = Join(w.child, rightOp, joinType, Option(pushDownPredicate), hint))
-          if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan)
-        } else {
-          // The join condition is not a subset of the Window's PARTITION BY clause,
-          // no push down.
-          join
-        }
-      }
-
-    // Similar to the above Filter over Union
-    // LeftSemi/LeftAnti over Union
-    case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
-        if canPushThroughCondition(union.children, joinCond, rightOp) =>
-      if (joinCond.isEmpty) {
-        // Push down the Join below Union
-        val newGrandChildren = union.children.map { grandchild =>
-          Join(grandchild, rightOp, joinType, joinCond, hint)
-        }
-        union.withNewChildren(newGrandChildren)
-      } else {
-        val pushDown = splitConjunctivePredicates(joinCond.get)
-
-        if (pushDown.nonEmpty) {
-          val pushDownCond = pushDown.reduceLeft(And)
-          val output = union.output
-          val newGrandChildren = union.children.map { grandchild =>
-            val newCond = pushDownCond transform {
-              case e if output.exists(_.semanticEquals(e)) =>
-                grandchild.output(output.indexWhere(_.semanticEquals(e)))
-            }
-            assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet))
-            Join(grandchild, rightOp, joinType, Option(newCond), hint)
-          }
-          union.withNewChildren(newGrandChildren)
-        } else {
-          // Nothing to push down
-          join
-        }
-      }
-
-    // Similar to the above Filter over UnaryNode
-    // LeftSemi/LeftAnti over UnaryNode
-    case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
-        if PushDownPredicate.canPushThrough(u) =>
-      pushDownJoin(join, u.child) { joinCond =>
-        u.withNewChildren(Seq(Join(u.child, rightOp, joinType, Option(joinCond), hint)))
-      }
-  }
-
-  /**
-   * Check if we can safely push a join through a project or union by making sure that predicate
-   * subqueries in the condition do not contain the same attributes as the plan they are moved
-   * into. This can happen when the plan and predicate subquery have the same source.
-   */
-  private def canPushThroughCondition(plans: Seq[LogicalPlan], condition: Option[Expression],
-      rightOp: LogicalPlan): Boolean = {
-    val attributes = AttributeSet(plans.flatMap(_.output))
-    if (condition.isDefined) {
-      val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes)
-      matched.isEmpty
-    } else true
-  }
-
-  private def pushDownJoin(
-      join: Join,
-      grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
-    // Only push down the join when the join condition is deterministic and all the referenced
-    // attributes come from the children of the left and right legs of the join.
-    val (candidates, containingNonDeterministic) = if (join.condition.isDefined) {
-      splitConjunctivePredicates(join.condition.get).partition(_.deterministic)
-    } else {
-      (Nil, Nil)
-    }
-
-    val (pushDown, rest) = candidates.partition { cond =>
-      cond.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)
-    }
-
-    val stayUp = rest ++ containingNonDeterministic
-
-    if (pushDown.nonEmpty) {
-      val newChild = insertFilter(pushDown.reduceLeft(And))
-      if (stayUp.nonEmpty) {
-        Filter(stayUp.reduceLeft(And), newChild)
-      } else {
-        newChild
-      }
-    } else {
-      join
-    }
-  }
-}
-
 /**
  * Pushes down [[Filter]] operators where the `condition` can be
  * evaluated using only the attributes of the left or right side of a join. Other
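For reference, a rough sketch (not part of this commit) of the rewrite that the removed rule's LeftSemi/LeftAnti-over-Project case performs, again in the style of the Catalyst test DSL. The relations, attribute names, and imports are assumptions, and the plan is left unanalyzed for brevity.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val a = 'a.int
val b = 'b.int
val x = 'x.int
val left = LocalRelation(a, b)
val right = LocalRelation(x)

// Before: Join(Project(a AS c, left), right, LeftSemi, c = x)
val before = left.select(a.as("c")).join(right, LeftSemi, Some('c === x))

// After the rule fires, the alias c in the join condition is expanded to a and the join is
// pushed below the projection: Project(a AS c, Join(left, right, LeftSemi, a = x)).
// The guard only allows this when the project list is deterministic and free of scalar
// subqueries, and canPushThroughCondition rules out predicate subqueries that share
// attributes with the plan being pushed into.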
