
Commit ec29f20

cloud-fan authored and rxin committed
[SPARK-9634] [SPARK-9323] [SQL] cleanup unnecessary Aliases in LogicalPlan at the end of analysis
Also alias the ExtractValue instead of wrapping it with UnresolvedAlias when resolving an attribute in LogicalPlan, as this alias will be trimmed if it is unnecessary.

Based on #7957 without the changes to mllib, but instead maintaining the earlier behavior when using `withColumn` on expressions that already have metadata.

Author: Wenchen Fan <[email protected]>
Author: Michael Armbrust <[email protected]>

Closes #8215 from marmbrus/pr/7957.
1 parent 37586e5 commit ec29f20
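
In plan terms, the new CleanupAliases rule removes Alias nodes that are not top-level expressions of a Project, Aggregate, or Window. A minimal sketch of the effect, written in the Catalyst test DSL (testRelation and checkAnalysis are the AnalysisSuite fixtures that appear later in this diff, not new API):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

val a = testRelation.output.head
// The inner alias "a+1" never shows up in the output schema, so analysis
// may drop it; only the top-level alias "col" has to survive.
val plan     = testRelation.select(((a + 1).as("a+1") + 2).as("col"))
val expected = testRelation.select((a + 1 + 2).as("col"))
checkAnalysis(plan, expected)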

9 files changed: 120 additions, 24 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 63 additions & 12 deletions
@@ -82,7 +82,9 @@ class Analyzer(
       HiveTypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
     Batch("Nondeterministic", Once,
-      PullOutNondeterministic)
+      PullOutNondeterministic),
+    Batch("Cleanup", fixedPoint,
+      CleanupAliases)
   )

   /**
@@ -146,8 +148,6 @@ class Analyzer(
         child match {
           case _: UnresolvedAttribute => u
           case ne: NamedExpression => ne
-          case g: GetStructField => Alias(g, g.field.name)()
-          case g: GetArrayStructFields => Alias(g, g.field.name)()
           case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
           case e if !e.resolved => u
           case other => Alias(other, s"_c$i")()
@@ -384,9 +384,7 @@ class Analyzer(
       case u @ UnresolvedAttribute(nameParts) =>
         // Leave unchanged if resolution fails. Hopefully will be resolved next round.
         val result =
-          withPosition(u) {
-            q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
-          }
+          withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
         logDebug(s"Resolving $u to $result")
         result
       case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -412,11 +410,6 @@ class Analyzer(
       exprs.exists(_.collect { case _: Star => true }.nonEmpty)
     }

-    private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
-      case UnresolvedAlias(child) => child
-      case other => other
-    }
-
     private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
       ordering.map { order =>
         // Resolve SortOrder in one round.
@@ -426,7 +419,7 @@ class Analyzer(
       try {
         val newOrder = order transformUp {
           case u @ UnresolvedAttribute(nameParts) =>
-            plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
+            plan.resolve(nameParts, resolver).getOrElse(u)
           case UnresolvedExtractValue(child, fieldName) if child.resolved =>
             ExtractValue(child, fieldName, resolver)
         }
@@ -968,3 +961,61 @@ object EliminateSubQueries extends Rule[LogicalPlan] {
     case Subquery(_, child) => child
   }
 }
+
+/**
+ * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top-level
+ * expression in Project (project list), Aggregate (aggregate expressions) or
+ * Window (window expressions).
+ */
+object CleanupAliases extends Rule[LogicalPlan] {
+  private def trimAliases(e: Expression): Expression = {
+    var stop = false
+    e.transformDown {
+      // CreateStruct is a special case: we need to retain its top-level Aliases as they decide
+      // the names of its StructFields. We also need to stop transforming down into this
+      // expression, or the Aliases under CreateStruct will be mistakenly trimmed.
+      case c: CreateStruct if !stop =>
+        stop = true
+        c.copy(children = c.children.map(trimNonTopLevelAliases))
+      case c: CreateStructUnsafe if !stop =>
+        stop = true
+        c.copy(children = c.children.map(trimNonTopLevelAliases))
+      case Alias(child, _) if !stop => child
+    }
+  }
+
+  def trimNonTopLevelAliases(e: Expression): Expression = e match {
+    case a: Alias =>
+      Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata)
+    case other => trimAliases(other)
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case Project(projectList, child) =>
+      val cleanedProjectList =
+        projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+      Project(cleanedProjectList, child)
+
+    case Aggregate(grouping, aggs, child) =>
+      val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+      Aggregate(grouping.map(trimAliases), cleanedAggs, child)
+
+    case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
+      val cleanedWindowExprs =
+        windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
+      Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
+        orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
+
+    case other =>
+      var stop = false
+      other transformExpressionsDown {
+        case c: CreateStruct if !stop =>
+          stop = true
+          c.copy(children = c.children.map(trimNonTopLevelAliases))
+        case c: CreateStructUnsafe if !stop =>
+          stop = true
+          c.copy(children = c.children.map(trimNonTopLevelAliases))
+        case Alias(child, _) if !stop => child
+      }
+  }
+}
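
The CreateStruct carve-out in trimAliases matters because inner Aliases there determine struct field names. A sketch in the same test DSL (same fixtures as above):

// Trimming "a+1" here would rename the second struct field to an
// auto-generated name, so CleanupAliases must leave this plan alone.
val structPlan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
checkAnalysis(structPlan, structPlan)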

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 0 additions & 2 deletions
@@ -75,8 +75,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {

   override def foldable: Boolean = children.forall(_.foldable)

-  override lazy val resolved: Boolean = childrenResolved
-
   override lazy val dataType: StructType = {
     val fields = children.zipWithIndex.map { case (child, idx) =>
       child match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 6 additions & 3 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.optimizer

 import scala.collection.immutable.HashSet
-import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -260,8 +260,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
       val substitutedProjection = projectList1.map(_.transform {
         case a: Attribute => aliasMap.getOrElse(a, a)
       }).asInstanceOf[Seq[NamedExpression]]
-
-      Project(substitutedProjection, child)
+      // Collapsing two Projects may introduce unnecessary Aliases, trim them here.
+      val cleanedProjection = substitutedProjection.map(p =>
+        CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
+      )
+      Project(cleanedProjection, child)
     }
   }
 }
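
To see why ProjectCollapsing needs this, note that substituting a lower projection's aliases into the upper list buries those Aliases mid-expression. A hedged illustration in the Catalyst expression DSL ('b is just an unresolved attribute invented for the example):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.analysis.CleanupAliases

// The lower project had `'b as "x"`; the upper project computed `'x + 1 as "y"`.
// After substitution, the inner Alias "x" is no longer top-level:
val substituted = ('b.as("x") + 1).as("y")
val cleaned = CleanupAliases.trimNonTopLevelAliases(substituted)
// cleaned is equivalent to ('b + 1).as("y"): the buried Alias is gone.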

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 4 additions & 4 deletions
@@ -259,13 +259,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
       // One match, but we also need to extract the requested nested field.
       case Seq((a, nestedFields)) =>
         // The foldLeft adds ExtractValues for every remaining parts of the identifier,
-        // and wrap it with UnresolvedAlias which will be removed later.
+        // and aliases it with the last part of the name.
         // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
-        // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
-        // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
+        // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
+        // expression as "c".
         val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
           ExtractValue(expr, Literal(fieldName), resolver))
-        Some(UnresolvedAlias(fieldExprs))
+        Some(Alias(fieldExprs, nestedFields.last)())

       // No matches.
       case Seq() =>
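
In user-facing terms, a resolved nested reference now carries a usable name everywhere, not just in projection lists. A small sketch against the public API, built the same way as the SPARK-9323 regression test at the bottom of this diff:

val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
  """{"a": {"b": {"c": 1}}}""" :: Nil))
df.select("a.b.c").columns.head   // "c": the alias is the last path segment
df.orderBy("a.b.c").collect()     // and the expression resolves inside a Sort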

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ case class Window(
     child: LogicalPlan) extends UnaryNode {

   override def output: Seq[Attribute] =
-    (projectList ++ windowExpressions).map(_.toAttribute)
+    projectList ++ windowExpressions.map(_.toAttribute)
 }

 /**
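
Since Window's projectList is already a Seq[Attribute] (the output type makes this explicit), mapping toAttribute over it was a no-op; the new form converts only the named window expressions.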

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 17 additions & 0 deletions
@@ -119,4 +119,21 @@ class AnalysisSuite extends AnalysisTest {
       Project(testRelation.output :+ projected, testRelation)))
     checkAnalysis(plan, expected)
   }
+
+  test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") {
+    val a = testRelation.output.head
+    var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col"))
+    var expected = testRelation.select((a + 1 + 2).as("col"))
+    checkAnalysis(plan, expected)
+
+    plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col"))
+    expected = testRelation.groupBy(a)((min(a) + 1).as("col"))
+    checkAnalysis(plan, expected)
+
+    // CreateStruct is a special case: we should not trim its Aliases.
+    plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
+    checkAnalysis(plan, plan)
+    plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col"))
+    checkAnalysis(plan, plan)
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 14 additions & 2 deletions
@@ -753,10 +753,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    *   df.select($"colA".as("colB"))
    * }}}
    *
+   * If the current column has metadata associated with it, this metadata will be propagated
+   * to the new column. If this is not desired, use `as` with explicitly empty metadata.
+   *
    * @group expr_ops
    * @since 1.3.0
    */
-  def as(alias: String): Column = Alias(expr, alias)()
+  def as(alias: String): Column = expr match {
+    case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
+    case other => Alias(other, alias)()
+  }

   /**
    * (Scala-specific) Assigns the given aliases to the results of a table generating function.
@@ -789,10 +795,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    *   df.select($"colA".as('colB))
    * }}}
    *
+   * If the current column has metadata associated with it, this metadata will be propagated
+   * to the new column. If this is not desired, use `as` with explicitly empty metadata.
+   *
    * @group expr_ops
    * @since 1.3.0
    */
-  def as(alias: Symbol): Column = Alias(expr, alias.name)()
+  def as(alias: Symbol): Column = expr match {
+    case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata))
+    case other => Alias(other, alias.name)()
+  }

   /**
    * Gives the column an alias with metadata.
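
A usage sketch of the new propagation, mirroring the ColumnExpressionSuite test below (assumes the usual sqlContext.implicits._ in scope):

import org.apache.spark.sql.types.MetadataBuilder

val md = new MetadataBuilder().putString("comment", "user id").build()
val c1 = $"id".as("uid", md)  // attach metadata explicitly
val c2 = c1.as("uid2")        // a plain rename now carries md along
// To drop the metadata instead, alias with explicitly empty metadata:
val c3 = c1.as("uid2", new MetadataBuilder().build())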

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 9 additions & 0 deletions
@@ -17,6 +17,7 @@

 package org.apache.spark.sql

+import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.scalatest.Matchers._

 import org.apache.spark.sql.execution.{Project, TungstenProject}
@@ -110,6 +111,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     assert(df.select(df("a").alias("b")).columns.head === "b")
   }

+  test("as propagates metadata") {
+    val metadata = new MetadataBuilder
+    metadata.putString("key", "value")
+    val origCol = $"a".as("b", metadata.build())
+    val newCol = origCol.as("c")
+    assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
+  }
+
   test("single explode") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
     checkAnswer(

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 6 additions & 0 deletions
@@ -867,4 +867,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val actual = df.sort(rand(seed)).collect().map(_.getInt(0))
     assert(expected === actual)
   }
+
+  test("SPARK-9323: DataFrame.orderBy should support nested column name") {
+    val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+      """{"a": {"b": 1}}""" :: Nil))
+    checkAnswer(df.orderBy("a.b"), Row(Row(1)))
+  }
 }
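
This regression test is the user-visible half of the LogicalPlan.resolve change: previously the nested field came back wrapped in UnresolvedAlias, which only projection-list analysis knows how to strip, so a Sort over it never resolved; with resolve returning a real Alias named "b", orderBy("a.b") analyzes cleanly.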
