Skip to content

Commit ef64abf

Browse files
author
Tanel Kiis
committed
Handle aliases
1 parent 4ce0644 commit ef64abf

File tree

4 files changed

+102
-40
lines changed

4 files changed

+102
-40
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr
9797
}
9898
}
9999

100-
trait PredicateHelper extends Logging {
100+
trait PredicateHelper extends Logging with AliasHelper {
101101
protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
102102
condition match {
103103
case And(cond1, cond2) =>
@@ -150,18 +150,6 @@ trait PredicateHelper extends Logging {
150150
}
151151
}
152152

153-
// Substitute any known alias from a map.
154-
protected def replaceAlias(
155-
condition: Expression,
156-
aliases: AttributeMap[Expression]): Expression = {
157-
// Use transformUp to prevent infinite recursion when the replacement expression
158-
// redefines the same ExprId,
159-
condition.transformUp {
160-
case a: Attribute =>
161-
aliases.getOrElse(a, a)
162-
}
163-
}
164-
165153
/**
166154
* Returns true if `expr` can be evaluated using only the output of `plan`. This method
167155
* can be used to determine when it is acceptable to move expression evaluation within a query
@@ -249,6 +237,41 @@ trait PredicateHelper extends Logging {
249237
}
250238
}
251239

240+
/**
241+
* Helper methods for collecting and replacing aliases.
242+
*/
243+
trait AliasHelper {
244+
245+
protected def getAliasMap(plan: Project): AttributeMap[Expression] = {
246+
// Create a map of Aliases to their values from the child projection.
247+
// e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
248+
AttributeMap(plan.projectList.collect { case a: Alias => (a.toAttribute, a.child) })
249+
}
250+
251+
protected def getAliasMap(plan: Aggregate): AttributeMap[Expression] = {
252+
// Find all the aliased expressions in the aggregate list that don't include any actual
253+
// AggregateExpression or PythonUDF, and create a map from the alias to the expression
254+
val aliasMap = plan.aggregateExpressions.collect {
255+
case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
256+
PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
257+
(a.toAttribute, a.child)
258+
}
259+
AttributeMap(aliasMap)
260+
}
261+
262+
// Substitute any known alias from a map.
263+
protected def replaceAlias(
264+
condition: Expression,
265+
aliases: AttributeMap[Expression]): Expression = {
266+
// Use transformUp to prevent infinite recursion when the replacement expression
267+
// redefines the same ExprId,
268+
condition.transformUp {
269+
case a: Attribute =>
270+
aliases.getOrElse(a, a)
271+
}
272+
}
273+
}
274+
252275
@ExpressionDescription(
253276
usage = "_FUNC_ expr - Logical not.",
254277
examples = """

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -482,18 +482,51 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
482482
* Remove redundant aggregates from a query plan. A redundant aggregate is an aggregate whose
483483
* only goal is to keep distinct values, while its parent aggregate would ignore duplicate values.
484484
*/
485-
object RemoveRedundantAggregates extends Rule[LogicalPlan] {
485+
object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
486486
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
487487
case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
488-
upper.copy(child = lower.child)
488+
val aliasMap = getAliasMap(lower)
489+
upper.copy(
490+
child = lower.child,
491+
groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
492+
aggregateExpressions = upper.aggregateExpressions.map(
493+
replaceAliasButKeepOuter(_, aliasMap))
494+
)
489495
}
490496

491497
private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
492-
val upperReferencesOnlyGrouping = upper.references
493-
.subsetOf(AttributeSet(lower.groupingExpressions))
498+
val isDeterministic = upper.aggregateExpressions.forall(_.deterministic) &&
499+
lower.aggregateExpressions.forall(_.deterministic)
500+
501+
val upperReferencesOnlyGrouping = upper.references.subsetOf(AttributeSet(
502+
lower.aggregateExpressions.filter(!isAggregate(_)).map(_.toAttribute)))
503+
494504
val upperHasNoAggregateExpressions = upper.aggregateExpressions
495-
.forall(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
496-
upperReferencesOnlyGrouping && upperHasNoAggregateExpressions
505+
.forall(_.find(isAggregate).isEmpty)
506+
507+
isDeterministic && upperReferencesOnlyGrouping && upperHasNoAggregateExpressions
508+
}
509+
510+
private def isAggregate(expr: Expression): Boolean = {
511+
expr.find(e => e.isInstanceOf[AggregateExpression] ||
512+
PythonUDF.isGroupedAggPandasUDF(e)).isDefined
513+
}
514+
515+
/**
516+
* Replace all attributes, that reference an alias, with the aliased expression,
517+
* but keep the name of the name of the outmost attribute.
518+
*/
519+
private def replaceAliasButKeepOuter(
520+
expr: NamedExpression,
521+
aliasMap: AttributeMap[Expression]): NamedExpression = {
522+
523+
val replaced = expr match {
524+
case a: Attribute if aliasMap.contains(a) =>
525+
Alias(replaceAlias(a, aliasMap), a.name)(a.exprId, a.qualifier)
526+
case _ => replaceAlias(expr, aliasMap)
527+
}
528+
529+
replaced.asInstanceOf[NamedExpression]
497530
}
498531
}
499532

@@ -1258,23 +1291,6 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
12581291
}
12591292
}
12601293

1261-
def getAliasMap(plan: Project): AttributeMap[Expression] = {
1262-
// Create a map of Aliases to their values from the child projection.
1263-
// e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
1264-
AttributeMap(plan.projectList.collect { case a: Alias => (a.toAttribute, a.child) })
1265-
}
1266-
1267-
def getAliasMap(plan: Aggregate): AttributeMap[Expression] = {
1268-
// Find all the aliased expressions in the aggregate list that don't include any actual
1269-
// AggregateExpression or PythonUDF, and create a map from the alias to the expression
1270-
val aliasMap = plan.aggregateExpressions.collect {
1271-
case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
1272-
PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
1273-
(a.toAttribute, a.child)
1274-
}
1275-
AttributeMap(aliasMap)
1276-
}
1277-
12781294
def canPushThrough(p: UnaryNode): Boolean = p match {
12791295
// Note that some operators (e.g. project, aggregate, union) are being handled separately
12801296
// (earlier in this rule).

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
4242
// No join condition, just push down the Join below Project
4343
p.copy(child = Join(gChild, rightOp, joinType, joinCond, hint))
4444
} else {
45-
val aliasMap = PushPredicateThroughNonJoin.getAliasMap(p)
45+
val aliasMap = getAliasMap(p)
4646
val newJoinCond = if (aliasMap.nonEmpty) {
4747
Option(replaceAlias(joinCond.get, aliasMap))
4848
} else {
@@ -55,7 +55,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
5555
case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
5656
if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
5757
!agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
58-
val aliasMap = PushPredicateThroughNonJoin.getAliasMap(agg)
58+
val aliasMap = getAliasMap(agg)
5959
val canPushDownPredicate = (predicate: Expression) => {
6060
val replaced = replaceAlias(predicate, aliasMap)
6161
predicate.references.nonEmpty &&

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,37 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
7070
comparePlans(optimized, expected)
7171
}
7272

73-
test("Keep non-redundant aggregate") {
73+
test("Remove redundant aggregate with aliases") {
7474
val relation = LocalRelation('a.int, 'b.int)
7575
val query = relation
76-
.groupBy('a)('a, first('b) as 'b)
76+
.groupBy('a + 'b)(('a + 'b) as 'c, count('b))
77+
.groupBy('c)('c)
78+
.analyze
79+
val expected = relation
80+
.groupBy('a + 'b)(('a + 'b) as 'c)
81+
.analyze
82+
val optimized = Optimize.execute(query)
83+
comparePlans(optimized, expected)
84+
}
85+
86+
test("Keep non-redundant aggregate - upper has agg expression") {
87+
val relation = LocalRelation('a.int, 'b.int)
88+
val query = relation
89+
.groupBy('a, 'b)('a, 'b)
7790
// The count would change if we remove the first aggregate
7891
.groupBy('a)('a, count('b))
7992
.analyze
8093
val optimized = Optimize.execute(query)
8194
comparePlans(optimized, query)
8295
}
96+
97+
test("Keep non-redundant aggregate - upper references non-grouping") {
98+
val relation = LocalRelation('a.int, 'b.int)
99+
val query = relation
100+
.groupBy('a)('a, count('b) as 'c)
101+
.groupBy('c)('c)
102+
.analyze
103+
val optimized = Optimize.execute(query)
104+
comparePlans(optimized, query)
105+
}
83106
}

0 commit comments

Comments
 (0)