@@ -137,56 +137,59 @@ object Analyzer {
137137 */
138138 def rewritePlan (plan : LogicalPlan , rewritePlanMap : Map [LogicalPlan , LogicalPlan ])
139139 : (LogicalPlan , Seq [(Attribute , Attribute )]) = {
140- if (plan.resolved) {
141- val attrMapping = new mutable. ArrayBuffer [( Attribute , Attribute )]()
142- val newChildren = plan.children.map { child =>
143- // If not, we'd rewrite child plan recursively until we find the
144- // conflict node or reach the leaf node.
145- val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
146- attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
147- // `attrMapping` is not only used to replace the attributes of the current `plan`,
148- // but also to be propagated to the parent plans of the current `plan`. Therefore,
149- // the `oldAttr` must be part of either `plan.references ` (so that it can be used to
150- // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
151- // used by those parent plans).
140+ val attrMapping = new mutable. ArrayBuffer [( Attribute , Attribute )]()
141+ val newChildren = plan.children.map { child =>
142+ // If not, we'd rewrite child plan recursively until we find the
143+ // conflict node or reach the leaf node.
144+ val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
145+ attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
146+ // ` attrMapping` is not only used to replace the attributes of the current `plan`,
147+ // but also to be propagated to the parent plans of the current `plan`. Therefore ,
148+ // the `oldAttr` must be part of either `plan.references` (so that it can be used to
149+ // replace attributes of the current `plan`) or `plan.outputSet ` (so that it can be
150+ // used by those parent plans).
151+ if (plan.resolved) {
152152 (plan.outputSet ++ plan.references).contains(oldAttr)
153+ } else {
154+ plan.references.filter(_.resolved).contains(oldAttr)
153155 }
154- newChild
155156 }
157+ newChild
158+ }
156159
157- val newPlan = if (rewritePlanMap.contains(plan)) {
158- rewritePlanMap(plan).withNewChildren(newChildren)
159- } else {
160- plan.withNewChildren(newChildren)
161- }
160+ val newPlan = if (rewritePlanMap.contains(plan)) {
161+ rewritePlanMap(plan).withNewChildren(newChildren)
162+ } else {
163+ plan.withNewChildren(newChildren)
164+ }
162165
163- assert(! attrMapping.groupBy(_._1.exprId)
164- .exists(_._2.map(_._2.exprId).distinct.length > 1 ),
165- " Found duplicate rewrite attributes" )
166+ assert(! attrMapping.groupBy(_._1.exprId)
167+ .exists(_._2.map(_._2.exprId).distinct.length > 1 ),
168+ " Found duplicate rewrite attributes" )
166169
167- val attributeRewrites = AttributeMap (attrMapping)
168- // Using attrMapping from the children plans to rewrite their parent node.
169- // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
170- val p = newPlan.transformExpressions {
171- case a : Attribute =>
172- updateAttr(a, attributeRewrites)
173- case s : SubqueryExpression =>
174- s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
175- }
170+ val attributeRewrites = AttributeMap (attrMapping)
171+ // Using attrMapping from the children plans to rewrite their parent node.
172+ // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
173+ val p = newPlan.transformExpressions {
174+ case a : Attribute =>
175+ updateAttr(a, attributeRewrites)
176+ case s : SubqueryExpression =>
177+ s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
178+ }
179+ if (plan.resolved) {
176180 attrMapping ++= plan.output.zip(p.output)
177181 .filter { case (a1, a2) => a1.exprId != a2.exprId }
178- p -> attrMapping
179- } else {
180- // Just passes through unresolved nodes
181- plan.mapChildren {
182- rewritePlan(_, rewritePlanMap)._1
183- } -> Nil
184182 }
183+ p -> attrMapping
185184 }
186185
187186 private def updateAttr (attr : Attribute , attrMap : AttributeMap [Attribute ]): Attribute = {
188- val exprId = attrMap.getOrElse(attr, attr).exprId
189- attr.withExprId(exprId)
187+ if (attr.resolved) {
188+ val exprId = attrMap.getOrElse(attr, attr).exprId
189+ attr.withExprId(exprId)
190+ } else {
191+ attr
192+ }
190193 }
191194
192195 /**
@@ -2694,13 +2697,13 @@ class Analyzer(
26942697 case ne : NamedExpression =>
26952698 // If a named expression is not in regularExpressions, add it to
26962699 // extractedExprBuffer and replace it with an AttributeReference.
2700+ val attr = ne.toAttribute
26972701 val missingExpr =
2698- AttributeSet (Seq (expr )) -- (regularExpressions ++ extractedExprBuffer)
2702+ AttributeSet (Seq (attr )) -- (regularExpressions ++ extractedExprBuffer)
26992703 if (missingExpr.nonEmpty) {
27002704 extractedExprBuffer += ne
27012705 }
2702- // alias will be cleaned in the rule CleanupAliases
2703- ne
2706+ attr
27042707 case e : Expression if e.foldable =>
27052708 e // No need to create an attribute reference if it will be evaluated as a Literal.
27062709 case e : Expression =>
@@ -2831,7 +2834,7 @@ class Analyzer(
28312834 val windowOps =
28322835 groupedWindowExpressions.foldLeft(child) {
28332836 case (last, ((partitionSpec, orderSpec, _), windowExpressions)) =>
2834- Window (windowExpressions.toSeq , partitionSpec, orderSpec, last)
2837+ Window (windowExpressions, partitionSpec, orderSpec, last)
28352838 }
28362839
28372840 // Finally, we create a Project to output windowOps's output
@@ -2853,8 +2856,8 @@ class Analyzer(
28532856 // a resolved Aggregate will not have Window Functions.
28542857 case f @ UnresolvedHaving (condition, a @ Aggregate (groupingExprs, aggregateExprs, child))
28552858 if child.resolved &&
2856- hasWindowFunction(aggregateExprs) &&
2857- a.expressions.forall(_.resolved) =>
2859+ hasWindowFunction(aggregateExprs) &&
2860+ a.expressions.forall(_.resolved) =>
28582861 val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
28592862 // Create an Aggregate operator to evaluate aggregation functions.
28602863 val withAggregate = Aggregate (groupingExprs, aggregateExpressions, child)
@@ -2871,7 +2874,7 @@ class Analyzer(
28712874 // Aggregate without Having clause.
28722875 case a @ Aggregate (groupingExprs, aggregateExprs, child)
28732876 if hasWindowFunction(aggregateExprs) &&
2874- a.expressions.forall(_.resolved) =>
2877+ a.expressions.forall(_.resolved) =>
28752878 val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
28762879 // Create an Aggregate operator to evaluate aggregation functions.
28772880 val withAggregate = Aggregate (groupingExprs, aggregateExpressions, child)
0 commit comments