Skip to content

Commit acb2b80

Browse files
committed
review
1 parent 0ca08ca commit acb2b80

File tree

5 files changed

+64
-57
lines changed

5 files changed

+64
-57
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 48 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -137,56 +137,59 @@ object Analyzer {
137137
*/
138138
def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
139139
: (LogicalPlan, Seq[(Attribute, Attribute)]) = {
140-
if (plan.resolved) {
141-
val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
142-
val newChildren = plan.children.map { child =>
143-
// If not, we'd rewrite child plan recursively until we find the
144-
// conflict node or reach the leaf node.
145-
val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
146-
attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
147-
// `attrMapping` is not only used to replace the attributes of the current `plan`,
148-
// but also to be propagated to the parent plans of the current `plan`. Therefore,
149-
// the `oldAttr` must be part of either `plan.references` (so that it can be used to
150-
// replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
151-
// used by those parent plans).
140+
val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
141+
val newChildren = plan.children.map { child =>
142+
// If not, we'd rewrite child plan recursively until we find the
143+
// conflict node or reach the leaf node.
144+
val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
145+
attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
146+
// `attrMapping` is not only used to replace the attributes of the current `plan`,
147+
// but also to be propagated to the parent plans of the current `plan`. Therefore,
148+
// the `oldAttr` must be part of either `plan.references` (so that it can be used to
149+
// replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
150+
// used by those parent plans).
151+
if (plan.resolved) {
152152
(plan.outputSet ++ plan.references).contains(oldAttr)
153+
} else {
154+
plan.references.filter(_.resolved).contains(oldAttr)
153155
}
154-
newChild
155156
}
157+
newChild
158+
}
156159

157-
val newPlan = if (rewritePlanMap.contains(plan)) {
158-
rewritePlanMap(plan).withNewChildren(newChildren)
159-
} else {
160-
plan.withNewChildren(newChildren)
161-
}
160+
val newPlan = if (rewritePlanMap.contains(plan)) {
161+
rewritePlanMap(plan).withNewChildren(newChildren)
162+
} else {
163+
plan.withNewChildren(newChildren)
164+
}
162165

163-
assert(!attrMapping.groupBy(_._1.exprId)
164-
.exists(_._2.map(_._2.exprId).distinct.length > 1),
165-
"Found duplicate rewrite attributes")
166+
assert(!attrMapping.groupBy(_._1.exprId)
167+
.exists(_._2.map(_._2.exprId).distinct.length > 1),
168+
"Found duplicate rewrite attributes")
166169

167-
val attributeRewrites = AttributeMap(attrMapping)
168-
// Using attrMapping from the children plans to rewrite their parent node.
169-
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
170-
val p = newPlan.transformExpressions {
171-
case a: Attribute =>
172-
updateAttr(a, attributeRewrites)
173-
case s: SubqueryExpression =>
174-
s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
175-
}
170+
val attributeRewrites = AttributeMap(attrMapping)
171+
// Using attrMapping from the children plans to rewrite their parent node.
172+
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
173+
val p = newPlan.transformExpressions {
174+
case a: Attribute =>
175+
updateAttr(a, attributeRewrites)
176+
case s: SubqueryExpression =>
177+
s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
178+
}
179+
if (plan.resolved) {
176180
attrMapping ++= plan.output.zip(p.output)
177181
.filter { case (a1, a2) => a1.exprId != a2.exprId }
178-
p -> attrMapping
179-
} else {
180-
// Just passes through unresolved nodes
181-
plan.mapChildren {
182-
rewritePlan(_, rewritePlanMap)._1
183-
} -> Nil
184182
}
183+
p -> attrMapping
185184
}
186185

187186
private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
188-
val exprId = attrMap.getOrElse(attr, attr).exprId
189-
attr.withExprId(exprId)
187+
if (attr.resolved) {
188+
val exprId = attrMap.getOrElse(attr, attr).exprId
189+
attr.withExprId(exprId)
190+
} else {
191+
attr
192+
}
190193
}
191194

192195
/**
@@ -2694,13 +2697,13 @@ class Analyzer(
26942697
case ne: NamedExpression =>
26952698
// If a named expression is not in regularExpressions, add it to
26962699
// extractedExprBuffer and replace it with an AttributeReference.
2700+
val attr = ne.toAttribute
26972701
val missingExpr =
2698-
AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
2702+
AttributeSet(Seq(attr)) -- (regularExpressions ++ extractedExprBuffer)
26992703
if (missingExpr.nonEmpty) {
27002704
extractedExprBuffer += ne
27012705
}
2702-
// alias will be cleaned in the rule CleanupAliases
2703-
ne
2706+
attr
27042707
case e: Expression if e.foldable =>
27052708
e // No need to create an attribute reference if it will be evaluated as a Literal.
27062709
case e: Expression =>
@@ -2831,7 +2834,7 @@ class Analyzer(
28312834
val windowOps =
28322835
groupedWindowExpressions.foldLeft(child) {
28332836
case (last, ((partitionSpec, orderSpec, _), windowExpressions)) =>
2834-
Window(windowExpressions.toSeq, partitionSpec, orderSpec, last)
2837+
Window(windowExpressions, partitionSpec, orderSpec, last)
28352838
}
28362839

28372840
// Finally, we create a Project to output windowOps's output
@@ -2853,8 +2856,8 @@ class Analyzer(
28532856
// a resolved Aggregate will not have Window Functions.
28542857
case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
28552858
if child.resolved &&
2856-
hasWindowFunction(aggregateExprs) &&
2857-
a.expressions.forall(_.resolved) =>
2859+
hasWindowFunction(aggregateExprs) &&
2860+
a.expressions.forall(_.resolved) =>
28582861
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
28592862
// Create an Aggregate operator to evaluate aggregation functions.
28602863
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
@@ -2871,7 +2874,7 @@ class Analyzer(
28712874
// Aggregate without Having clause.
28722875
case a @ Aggregate(groupingExprs, aggregateExprs, child)
28732876
if hasWindowFunction(aggregateExprs) &&
2874-
a.expressions.forall(_.resolved) =>
2877+
a.expressions.forall(_.resolved) =>
28752878
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
28762879
// Create an Aggregate operator to evaluate aggregation functions.
28772880
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -458,7 +458,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
458458
sys.error(s"Unexpected operator in scalar subquery: $lp")
459459
}
460460

461-
val resultMap = evalPlan(plan)
461+
val resultMap = evalPlan(plan).mapValues { _.transform {
462+
case a: Alias => a.newInstance() // Assigns a new `ExprId`
463+
}
464+
}
462465

463466
// By convention, the scalar subquery result is the leftmost field.
464467
resultMap.get(plan.output.head.exprId) match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -236,15 +236,16 @@ object LogicalPlanIntegrity {
236236
* with one of reference attributes, e.g., `a#1 + 1 AS a#1`.
237237
*/
238238
def checkIfSameExprIdNotReused(plan: LogicalPlan): Boolean = {
239-
plan.map { p =>
240-
p.expressions.filter(_.resolved).forall { e =>
241-
val namedExprs = e.collect {
242-
case ne: NamedExpression if !ne.isInstanceOf[LeafExpression] => ne
239+
plan.collect { case p if p.resolved =>
240+
val inputExprIds = p.inputSet.filter(_.resolved).map(_.exprId).toSet
241+
val newExprIds = p.expressions.filter(_.resolved).flatMap { e =>
242+
e.collect {
243+
// Only accepts the case of aliases renaming foldable expressions, e.g.,
244+
// `FoldablePropagation` generates this renaming pattern.
245+
case a: Alias if !a.child.foldable => a.exprId
243246
}
244-
namedExprs.forall { ne =>
245-
!ne.references.filter(_.resolved).map(_.exprId).exists(_ == ne.exprId)
246-
}
247-
}
247+
}.toSet
248+
inputExprIds.intersect(newExprIds).isEmpty
248249
}.forall(identity)
249250
}
250251

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -156,8 +156,8 @@ class FoldablePropagationSuite extends PlanTest {
156156
val query = expand.where(a1.isNotNull).select(a1, a2).analyze
157157
val optimized = Optimize.execute(query)
158158
val correctExpand = expand.copy(projections = Seq(
159-
Seq(Literal(null), c2),
160-
Seq(c1, Literal(null))))
159+
Seq(Literal(null), Literal(2)),
160+
Seq(Literal(1), Literal(null))))
161161
val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze
162162
comparePlans(optimized, correctAnswer)
163163
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ class LogicalPlanIntegritySuite extends PlanTest {
4343
val Seq(a, b) = t.output
4444
assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")())))
4545
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = a.exprId))))
46-
assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId))))
46+
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId))))
4747
assert(checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")())))
4848
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = a.exprId))))
4949
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = b.exprId))))

0 commit comments

Comments
 (0)