Skip to content

Commit 73ee739

Browse files
committed
[SPARK-18609][SPARK-18841][SQL] Fix redundant Alias removal in the optimizer
## What changes were proposed in this pull request? The optimizer tries to remove redundant alias-only projections from the query plan using the `RemoveAliasOnlyProject` rule. The current rule identifies and removes such a project and rewrites the project's attributes in the **entire** tree. This causes problems when parts of the tree are duplicated (for instance a self join on a temporary view/CTE) and the duplicated part contains the alias-only project; in this case the rewrite will break the tree. This PR fixes these problems by using a blacklist for attributes that are not to be moved, and by making sure that attribute remapping is only done for the parent tree, and not for unrelated parts of the query plan. The current tree transformation infrastructure works very well if the transformation at hand requires little contextual information, or only global contextual information. In this case we need to know both the attributes that are not to be moved and which child attributes were modified. This cannot be done easily using the current infrastructure, and solutions typically involve traversing the query plan multiple times (which is super slow). I have moved around some code in `TreeNode`, `QueryPlan` and `LogicalPlan` to make this much more straightforward; this basically allows you to manually traverse the tree. This PR subsumes the following PRs by windpiger: Closes #16267 Closes #16255 ## How was this patch tested? I have added unit tests to `RemoveRedundantAliasAndProjectSuite` and I have added integration tests to the `SQLQueryTestSuite.union` and `SQLQueryTestSuite.cte` test cases. Author: Herman van Hovell <[email protected]> Closes #16757 from hvanhovell/SPARK-18609.
1 parent b7277e0 commit 73ee739

File tree

9 files changed

+302
-115
lines changed

9 files changed

+302
-115
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 84 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
110110
SimplifyCaseConversionExpressions,
111111
RewriteCorrelatedScalarSubquery,
112112
EliminateSerialization,
113-
RemoveAliasOnlyProject,
113+
RemoveRedundantAliases,
114+
RemoveRedundantProject,
114115
SimplifyCreateStructOps,
115116
SimplifyCreateArrayOps,
116117
SimplifyCreateMapOps) ::
@@ -157,56 +158,98 @@ class SimpleTestOptimizer extends Optimizer(
157158
new SimpleCatalystConf(caseSensitiveAnalysis = true))
158159

159160
/**
160-
* Removes the Project only conducting Alias of its child node.
161-
* It is created mainly for removing extra Project added in EliminateSerialization rule,
162-
* but can also benefit other operators.
161+
* Remove redundant aliases from a query plan. A redundant alias is an alias that does not change
162+
* the name or metadata of a column, and does not deduplicate it.
163163
*/
164-
object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
164+
object RemoveRedundantAliases extends Rule[LogicalPlan] {
165+
165166
/**
166-
* Returns true if the project list is semantically same as child output, after strip alias on
167-
* attribute.
167+
* Create an attribute mapping from the old to the new attributes. This function will only
168+
* return the attribute pairs that have changed.
168169
*/
169-
private def isAliasOnly(
170-
projectList: Seq[NamedExpression],
171-
childOutput: Seq[Attribute]): Boolean = {
172-
if (projectList.length != childOutput.length) {
173-
false
174-
} else {
175-
stripAliasOnAttribute(projectList).zip(childOutput).forall {
176-
case (a: Attribute, o) if a semanticEquals o => true
177-
case _ => false
178-
}
170+
private def createAttributeMapping(current: LogicalPlan, next: LogicalPlan)
171+
: Seq[(Attribute, Attribute)] = {
172+
current.output.zip(next.output).filterNot {
173+
case (a1, a2) => a1.semanticEquals(a2)
179174
}
180175
}
181176

182-
private def stripAliasOnAttribute(projectList: Seq[NamedExpression]) = {
183-
projectList.map {
184-
// Alias with metadata can not be stripped, or the metadata will be lost.
185-
// If the alias name is different from attribute name, we can't strip it either, or we may
186-
// accidentally change the output schema name of the root plan.
187-
case a @ Alias(attr: Attribute, name) if a.metadata == Metadata.empty && name == attr.name =>
188-
attr
189-
case other => other
190-
}
177+
/**
178+
* Remove the top-level alias from an expression when it is redundant.
179+
*/
180+
private def removeRedundantAlias(e: Expression, blacklist: AttributeSet): Expression = e match {
181+
// Alias with metadata can not be stripped, or the metadata will be lost.
182+
// If the alias name is different from attribute name, we can't strip it either, or we
183+
// may accidentally change the output schema name of the root plan.
184+
case a @ Alias(attr: Attribute, name)
185+
if a.metadata == Metadata.empty && name == attr.name && !blacklist.contains(attr) =>
186+
attr
187+
case a => a
191188
}
192189

193-
def apply(plan: LogicalPlan): LogicalPlan = {
194-
val aliasOnlyProject = plan.collectFirst {
195-
case p @ Project(pList, child) if isAliasOnly(pList, child.output) => p
196-
}
190+
/**
191+
* Remove redundant alias expression from a LogicalPlan and its subtree. A blacklist is used to
192+
* prevent the removal of seemingly redundant aliases used to deduplicate the input for a (self)
193+
* join.
194+
*/
195+
private def removeRedundantAliases(plan: LogicalPlan, blacklist: AttributeSet): LogicalPlan = {
196+
plan match {
197+
// A join has to be treated differently, because the left and the right side of the join are
198+
// not allowed to use the same attributes. We use a blacklist to prevent us from creating a
199+
// situation in which this happens; the rule will only remove an alias if its child
200+
// attribute is not on the black list.
201+
case Join(left, right, joinType, condition) =>
202+
val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet)
203+
val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet)
204+
val mapping = AttributeMap(
205+
createAttributeMapping(left, newLeft) ++
206+
createAttributeMapping(right, newRight))
207+
val newCondition = condition.map(_.transform {
208+
case a: Attribute => mapping.getOrElse(a, a)
209+
})
210+
Join(newLeft, newRight, joinType, newCondition)
211+
212+
case _ =>
213+
// Remove redundant aliases in the subtree(s).
214+
val currentNextAttrPairs = mutable.Buffer.empty[(Attribute, Attribute)]
215+
val newNode = plan.mapChildren { child =>
216+
val newChild = removeRedundantAliases(child, blacklist)
217+
currentNextAttrPairs ++= createAttributeMapping(child, newChild)
218+
newChild
219+
}
197220

198-
aliasOnlyProject.map { case proj =>
199-
val attributesToReplace = proj.output.zip(proj.child.output).filterNot {
200-
case (a1, a2) => a1 semanticEquals a2
201-
}
202-
val attrMap = AttributeMap(attributesToReplace)
203-
plan transform {
204-
case plan: Project if plan eq proj => plan.child
205-
case plan => plan transformExpressions {
206-
case a: Attribute if attrMap.contains(a) => attrMap(a)
221+
// Create the attribute mapping. Note that the currentNextAttrPairs can contain duplicate
222+
// keys in case of Union (this is caused by the PushProjectionThroughUnion rule); in this
223+
// case we use the first mapping (which should be provided by the first child).
224+
val mapping = AttributeMap(currentNextAttrPairs)
225+
226+
// Create an expression cleaning function for nodes that can actually produce redundant
227+
// aliases, use identity otherwise.
228+
val clean: Expression => Expression = plan match {
229+
case _: Project => removeRedundantAlias(_, blacklist)
230+
case _: Aggregate => removeRedundantAlias(_, blacklist)
231+
case _: Window => removeRedundantAlias(_, blacklist)
232+
case _ => identity[Expression]
207233
}
208-
}
209-
}.getOrElse(plan)
234+
235+
// Transform the expressions.
236+
newNode.mapExpressions { expr =>
237+
clean(expr.transform {
238+
case a: Attribute => mapping.getOrElse(a, a)
239+
})
240+
}
241+
}
242+
}
243+
244+
def apply(plan: LogicalPlan): LogicalPlan = removeRedundantAliases(plan, AttributeSet.empty)
245+
}
246+
247+
/**
248+
* Remove projections from the query plan that do not make any modifications.
249+
*/
250+
object RemoveRedundantProject extends Rule[LogicalPlan] {
251+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
252+
case p @ Project(_, child) if p.output == child.output => child
210253
}
211254
}
212255

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -242,31 +242,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
242242
* @param rule the rule to be applied to every expression in this operator.
243243
*/
244244
def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = {
245-
var changed = false
246-
247-
@inline def transformExpressionDown(e: Expression): Expression = {
248-
val newE = e.transformDown(rule)
249-
if (newE.fastEquals(e)) {
250-
e
251-
} else {
252-
changed = true
253-
newE
254-
}
255-
}
256-
257-
def recursiveTransform(arg: Any): AnyRef = arg match {
258-
case e: Expression => transformExpressionDown(e)
259-
case Some(e: Expression) => Some(transformExpressionDown(e))
260-
case m: Map[_, _] => m
261-
case d: DataType => d // Avoid unpacking Structs
262-
case seq: Traversable[_] => seq.map(recursiveTransform)
263-
case other: AnyRef => other
264-
case null => null
265-
}
266-
267-
val newArgs = mapProductIterator(recursiveTransform)
268-
269-
if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
245+
mapExpressions(_.transformDown(rule))
270246
}
271247

272248
/**
@@ -276,10 +252,18 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
276252
* @return
277253
*/
278254
def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = {
255+
mapExpressions(_.transformUp(rule))
256+
}
257+
258+
/**
259+
* Apply a map function to each expression present in this query operator, and return a new
260+
* query operator based on the mapped expressions.
261+
*/
262+
def mapExpressions(f: Expression => Expression): this.type = {
279263
var changed = false
280264

281-
@inline def transformExpressionUp(e: Expression): Expression = {
282-
val newE = e.transformUp(rule)
265+
@inline def transformExpression(e: Expression): Expression = {
266+
val newE = f(e)
283267
if (newE.fastEquals(e)) {
284268
e
285269
} else {
@@ -289,8 +273,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
289273
}
290274

291275
def recursiveTransform(arg: Any): AnyRef = arg match {
292-
case e: Expression => transformExpressionUp(e)
293-
case Some(e: Expression) => Some(transformExpressionUp(e))
276+
case e: Expression => transformExpression(e)
277+
case Some(e: Expression) => Some(transformExpression(e))
294278
case m: Map[_, _] => m
295279
case d: DataType => d // Avoid unpacking Structs
296280
case seq: Traversable[_] => seq.map(recursiveTransform)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
5656
*/
5757
def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
5858
if (!analyzed) {
59-
val afterRuleOnChildren = transformChildren(rule, (t, r) => t.resolveOperators(r))
59+
val afterRuleOnChildren = mapChildren(_.resolveOperators(rule))
6060
if (this fastEquals afterRuleOnChildren) {
6161
CurrentOrigin.withOrigin(origin) {
6262
rule.applyOrElse(this, identity[LogicalPlan])

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -190,26 +190,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
190190
arr
191191
}
192192

193-
/**
194-
* Returns a copy of this node where `f` has been applied to all the nodes children.
195-
*/
196-
def mapChildren(f: BaseType => BaseType): BaseType = {
197-
var changed = false
198-
val newArgs = mapProductIterator {
199-
case arg: TreeNode[_] if containsChild(arg) =>
200-
val newChild = f(arg.asInstanceOf[BaseType])
201-
if (newChild fastEquals arg) {
202-
arg
203-
} else {
204-
changed = true
205-
newChild
206-
}
207-
case nonChild: AnyRef => nonChild
208-
case null => null
209-
}
210-
if (changed) makeCopy(newArgs) else this
211-
}
212-
213193
/**
214194
* Returns a copy of this node with the children replaced.
215195
* TODO: Validate somewhere (in debug mode?) that children are ordered correctly.
@@ -289,9 +269,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
289269

290270
// Check if unchanged and then possibly return old copy to avoid gc churn.
291271
if (this fastEquals afterRule) {
292-
transformChildren(rule, (t, r) => t.transformDown(r))
272+
mapChildren(_.transformDown(rule))
293273
} else {
294-
afterRule.transformChildren(rule, (t, r) => t.transformDown(r))
274+
afterRule.mapChildren(_.transformDown(rule))
295275
}
296276
}
297277

@@ -303,7 +283,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
303283
* @param rule the function use to transform this nodes children
304284
*/
305285
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
306-
val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r))
286+
val afterRuleOnChildren = mapChildren(_.transformUp(rule))
307287
if (this fastEquals afterRuleOnChildren) {
308288
CurrentOrigin.withOrigin(origin) {
309289
rule.applyOrElse(this, identity[BaseType])
@@ -316,26 +296,22 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
316296
}
317297

318298
/**
319-
* Returns a copy of this node where `rule` has been recursively applied to all the children of
320-
* this node. When `rule` does not apply to a given node it is left unchanged.
321-
* @param rule the function used to transform this nodes children
299+
* Returns a copy of this node where `f` has been applied to all of the node's children.
322300
*/
323-
protected def transformChildren(
324-
rule: PartialFunction[BaseType, BaseType],
325-
nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = {
301+
def mapChildren(f: BaseType => BaseType): BaseType = {
326302
if (children.nonEmpty) {
327303
var changed = false
328304
val newArgs = mapProductIterator {
329305
case arg: TreeNode[_] if containsChild(arg) =>
330-
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
306+
val newChild = f(arg.asInstanceOf[BaseType])
331307
if (!(newChild fastEquals arg)) {
332308
changed = true
333309
newChild
334310
} else {
335311
arg
336312
}
337313
case Some(arg: TreeNode[_]) if containsChild(arg) =>
338-
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
314+
val newChild = f(arg.asInstanceOf[BaseType])
339315
if (!(newChild fastEquals arg)) {
340316
changed = true
341317
Some(newChild)
@@ -344,7 +320,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
344320
}
345321
case m: Map[_, _] => m.mapValues {
346322
case arg: TreeNode[_] if containsChild(arg) =>
347-
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
323+
val newChild = f(arg.asInstanceOf[BaseType])
348324
if (!(newChild fastEquals arg)) {
349325
changed = true
350326
newChild
@@ -356,16 +332,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
356332
case d: DataType => d // Avoid unpacking Structs
357333
case args: Traversable[_] => args.map {
358334
case arg: TreeNode[_] if containsChild(arg) =>
359-
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
335+
val newChild = f(arg.asInstanceOf[BaseType])
360336
if (!(newChild fastEquals arg)) {
361337
changed = true
362338
newChild
363339
} else {
364340
arg
365341
}
366342
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
367-
val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
368-
val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
343+
val newChild1 = f(arg1.asInstanceOf[BaseType])
344+
val newChild2 = f(arg2.asInstanceOf[BaseType])
369345
if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
370346
changed = true
371347
(newChild1, newChild2)

0 commit comments

Comments
 (0)