Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ class Analyzer(
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
PullOutNondeterministic)
PullOutNondeterministic),
Batch("Cleanup", fixedPoint,
CleanupAliases)
)

/**
Expand Down Expand Up @@ -146,8 +148,6 @@ class Analyzer(
child match {
case _: UnresolvedAttribute => u
case ne: NamedExpression => ne
case g: GetStructField => Alias(g, g.field.name)()
case g: GetArrayStructFields => Alias(g, g.field.name)()
case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
case e if !e.resolved => u
case other => Alias(other, s"_c$i")()
Expand Down Expand Up @@ -384,9 +384,7 @@ class Analyzer(
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
withPosition(u) {
q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
}
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
Expand All @@ -412,11 +410,6 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}

// Unwraps an UnresolvedAlias to the expression it wraps; any other NamedExpression is
// returned unchanged.
private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
case UnresolvedAlias(child) => child
case other => other
}

private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
Expand All @@ -426,7 +419,7 @@ class Analyzer(
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
plan.resolve(nameParts, resolver).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
Expand Down Expand Up @@ -968,3 +961,61 @@ object EliminateSubQueries extends Rule[LogicalPlan] {
case Subquery(_, child) => child
}
}

/**
 * Cleans up unnecessary Aliases inside the plan. An Alias is only kept as a top level
 * expression in Project(project list), Aggregate(aggregate expressions) or
 * Window(window expressions); in every other position it is replaced by its child.
 *
 * The children of CreateStruct / CreateStructUnsafe are the one exception: their top level
 * Aliases are retained because they decide the names of the resulting StructFields.
 */
object CleanupAliases extends Rule[LogicalPlan] {

  /**
   * Builds a fresh alias-trimming rule intended for a single top-down traversal.
   *
   * Each call returns a rule with its own private `stop` flag. The first CreateStruct or
   * CreateStructUnsafe that the rule matches has its children cleaned through
   * [[trimNonTopLevelAliases]] (so their top level Aliases survive) and then sets the flag, so
   * that the traversal does not mistakenly trim the Aliases sitting directly under that struct.
   *
   * NOTE(review): once the flag is set, the rule no longer matches ANY node for the remainder
   * of the traversal it is used in, not only the descendants of that struct. Confirm that
   * leaving Aliases visited after the struct untouched is acceptable.
   */
  private def newTrimRule(): PartialFunction[Expression, Expression] = {
    var stop = false
    val rule: PartialFunction[Expression, Expression] = {
      case c: CreateStruct if !stop =>
        stop = true
        c.copy(children = c.children.map(trimNonTopLevelAliases))
      case c: CreateStructUnsafe if !stop =>
        stop = true
        c.copy(children = c.children.map(trimNonTopLevelAliases))
      case Alias(child, _) if !stop => child
    }
    rule
  }

  /** Replaces the Aliases found in `e` (top level included) by their children. */
  private def trimAliases(e: Expression): Expression = e.transformDown(newTrimRule())

  /**
   * Trims the Aliases nested inside `e` while keeping `e` itself aliased.
   *
   * If `e` is an Alias, a new Alias is built around its cleaned child, reusing the original
   * name, exprId, qualifiers and explicit metadata. Any other expression is passed to
   * `trimAliases` as a whole.
   */
  def trimNonTopLevelAliases(e: Expression): Expression = e match {
    case a: Alias =>
      Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata)
    case other => trimAliases(other)
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    // Project list entries keep their top level Alias.
    case Project(projectList, child) =>
      val cleanedProjectList =
        projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
      Project(cleanedProjectList, child)

    // Aggregate expressions keep their top level Alias; grouping expressions are trimmed
    // through trimAliases, top level included.
    case Aggregate(grouping, aggs, child) =>
      val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
      Aggregate(grouping.map(trimAliases), cleanedAggs, child)

    // Window expressions keep their top level Alias; the partition and order specs are
    // trimmed through trimAliases. The project list is passed through unchanged.
    case Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
      val cleanedWindowExprs =
        windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
      Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
        orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

    // Any other operator: one rule instance (and therefore one `stop` flag) is shared by all
    // of the operator's expressions.
    case other =>
      other transformExpressionsDown newTrimRule()
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {

override def foldable: Boolean = children.forall(_.foldable)

override lazy val resolved: Boolean = childrenResolved

override lazy val dataType: StructType = {
val fields = children.zipWithIndex.map { case (child, idx) =>
child match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
Expand Down Expand Up @@ -260,8 +260,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
val substitutedProjection = projectList1.map(_.transform {
case a: Attribute => aliasMap.getOrElse(a, a)
}).asInstanceOf[Seq[NamedExpression]]

Project(substitutedProjection, child)
// Collapsing two projects may introduce unnecessary Aliases; trim them here.
val cleanedProjection = substitutedProjection.map(p =>
CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
)
Project(cleanedProjection, child)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining parts of the identifier,
// and wrap it with UnresolvedAlias which will be removed later.
// and aliases the result with the last part of the name.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
// UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
// Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
// expression as "c".
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
Some(UnresolvedAlias(fieldExprs))
Some(Alias(fieldExprs, nestedFields.last)())

// No matches.
case Seq() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ case class Window(
child: LogicalPlan) extends UnaryNode {

override def output: Seq[Attribute] =
(projectList ++ windowExpressions).map(_.toAttribute)
projectList ++ windowExpressions.map(_.toAttribute)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,21 @@ class AnalysisSuite extends AnalysisTest {
Project(testRelation.output :+ projected, testRelation)))
checkAnalysis(plan, expected)
}

test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") {
  val a = testRelation.output.head

  // An Alias nested below the top level of a project list entry is removed.
  val projectPlan = testRelation.select(((a + 1).as("a+1") + 2).as("col"))
  val projectExpected = testRelation.select((a + 1 + 2).as("col"))
  checkAnalysis(projectPlan, projectExpected)

  // Grouping expressions lose their Aliases; aggregate expressions keep only the outermost one.
  val aggregatePlan =
    testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col"))
  val aggregateExpected = testRelation.groupBy(a)((min(a) + 1).as("col"))
  checkAnalysis(aggregatePlan, aggregateExpected)

  // CreateStruct and CreateStructUnsafe are special cases: the Aliases of their children must
  // not be trimmed, so each plan is expected to come out of analysis unchanged.
  val structPlan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
  checkAnalysis(structPlan, structPlan)
  val unsafeStructPlan =
    testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col"))
  checkAnalysis(unsafeStructPlan, unsafeStructPlan)
}
}
16 changes: 14 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -753,10 +753,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* df.select($"colA".as("colB"))
* }}}
*
* If the current column has metadata associated with it, this metadata will be propagated
* to the new column. If this is not desired, use `as` with explicitly empty metadata.
*
* @group expr_ops
* @since 1.3.0
*/
def as(alias: String): Column = Alias(expr, alias)()
def as(alias: String): Column = expr match {
// A named expression exposes metadata; hand it to the new Alias explicitly so it is kept.
case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
// Any other expression is aliased using Alias's default arguments.
case other => Alias(other, alias)()
}

/**
* (Scala-specific) Assigns the given aliases to the results of a table generating function.
Expand Down Expand Up @@ -789,10 +795,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* df.select($"colA".as('colB))
* }}}
*
* If the current column has metadata associated with it, this metadata will be propagated
* to the new column. If this is not desired, use `as` with explicitly empty metadata.
*
* @group expr_ops
* @since 1.3.0
*/
def as(alias: Symbol): Column = Alias(expr, alias.name)()
def as(alias: Symbol): Column = expr match {
// A named expression exposes metadata; hand it to the new Alias explicitly so it is kept.
case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata))
// Any other expression is aliased (under the symbol's name) using Alias's default arguments.
case other => Alias(other, alias.name)()
}

/**
* Gives the column an alias with metadata.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.scalatest.Matchers._

import org.apache.spark.sql.execution.{Project, TungstenProject}
Expand Down Expand Up @@ -106,6 +107,14 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
assert(df.select(df("a").alias("b")).columns.head === "b")
}

test("as propagates metadata") {
  // Attach metadata to a column through the explicit-metadata overload of `as` ...
  val builder = new MetadataBuilder
  builder.putString("key", "value")
  val aliasedWithMetadata = $"a".as("b", builder.build())

  // ... then re-alias it with the plain overload and check that the metadata is still there.
  val realiased = aliasedWithMetadata.as("c")
  val propagated = realiased.expr.asInstanceOf[NamedExpression].metadata
  assert(propagated.getString("key") === "value")
}

test("single explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -871,4 +871,10 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
val actual = df.sort(rand(seed)).collect().map(_.getInt(0))
assert(expected === actual)
}

test("SPARK-9323: DataFrame.orderBy should support nested column name") {
  // A single JSON record whose column "a" is a struct with one field "b".
  val jsonRows = sqlContext.sparkContext.makeRDD(
    """{"a": {"b": 1}}""" :: Nil)
  val df = sqlContext.read.json(jsonRows)

  // Ordering by the nested field must still return the original nested row.
  checkAnswer(df.orderBy("a.b"), Row(Row(1)))
}
}