Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,8 @@ class Analyzer(
case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
case sa @ Sort(_, _, child: Aggregate) => sa

case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
case s @ Sort(order, _, child)
if (!s.resolved || s.missingInput.nonEmpty) && child.resolved =>
val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
val ordering = newOrder.map(_.asInstanceOf[SortOrder])
if (child.output == newChild.output) {
Expand All @@ -1140,7 +1141,7 @@ class Analyzer(
Project(child.output, newSort)
}

case f @ Filter(cond, child) if !f.resolved && child.resolved =>
case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved =>
val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
if (child.output == newChild.output) {
f.copy(condition = newCond.head)
Expand All @@ -1151,10 +1152,17 @@ class Analyzer(
}
}

/**
* This method tries to resolve expressions and find missing attributes recursively. Specifically,
* when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved
* attributes which are missing from the child output, this method tries to find the missing
* attributes and add them to the projection.
*/
private def resolveExprsAndAddMissingAttrs(
exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
if (exprs.forall(_.resolved)) {
// All given expressions are resolved, no need to continue anymore.
// Missing attributes can be unresolved attributes or resolved attributes which are not in
// the output attributes of the plan.
if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) {
(exprs, plan)
} else {
plan match {
Expand All @@ -1165,15 +1173,19 @@ class Analyzer(
(newExprs, AnalysisBarrier(newChild))

case p: Project =>
// Resolving expressions against current plan.
val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
// Recursively resolving expressions on the child of current plan.
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
// If some attributes used by the expressions are resolvable only on the rewritten child
// plan, we need to add them to the original projection.
val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if we do not do the .intersect(newChild.outputSet)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this intersect, some tests fail, e.g., group-analytics.sql in SQLQueryTestSuite. Some attributes are resolved on parent plans, not on child plans. We can't add them as missing attributes here.

(newExprs, Project(p.projectList ++ missingAttrs, newChild))

case a @ Aggregate(groupExprs, aggExprs, child) =>
val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet)
if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
// All the missing attributes are grouping expressions, valid case.
(newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
Expand Down Expand Up @@ -1493,7 +1505,11 @@ class Analyzer(

// Try resolving the ordering as though it is in the aggregate clause.
try {
val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
// If a sort order is unresolved, contains references not in the aggregate output, or contains
// an `AggregateExpression`, we need to push it down to the underlying aggregate operator.
val unresolvedSortOrders = sortOrder.filter { s =>
!s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
}
val aliasedOrdering =
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
Expand Down
25 changes: 25 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2387,4 +2387,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
}

test("SPARK-24781: Using a reference from Dataset in Filter/Sort") {
  val source = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")

  // Both projections keep only `name`; the Filter/Sort below then refer to `id`, which is
  // not part of the projected output. One variant refers to columns through the Dataset
  // (`source(...)`), the other through `col(...)`.
  val namesViaDataset = source.select(source("name"))
  val namesViaCol = source.select(col("name"))

  // Filter on the dropped column: both variants must return the same rows.
  val filteredViaDataset = namesViaDataset.filter(source("id") === 0)
  val filteredViaCol = namesViaCol.filter(col("id") === 0)
  checkAnswer(filteredViaDataset, filteredViaCol.collect())

  // Sort on the dropped column: both variants must return the same rows.
  val sortedViaDataset = namesViaDataset.orderBy(source("id"))
  val sortedViaCol = namesViaCol.orderBy(col("id"))
  checkAnswer(sortedViaDataset, sortedViaCol.collect())
}

test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") {
  // NOTE(review): with DATAFRAME_RETAIN_GROUP_COLUMNS disabled the grouping column `name` is
  // presumably dropped from the aggregate's output, so the Sort/Filter below reference a
  // column that the aggregate does not expose -- confirm against the SQLConf definition.
  withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
    val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")

    // Sort on the grouping column, referenced via the Dataset and via `col(...)`.
    val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name"))
    val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name"))
    checkAnswer(aggPlusSort1, aggPlusSort2.collect())

    // Filter on the grouping column. `name` is a string column, so it is compared with a
    // string value that exists in the data: comparing it with the integer literal 0 can
    // match no row (or fail the cast under strict casting rules), leaving both sides empty
    // and the comparison below vacuous.
    val aggPlusFilter1 =
      df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === "test1")
    val aggPlusFilter2 =
      df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === "test1")
    checkAnswer(aggPlusFilter1, aggPlusFilter2.collect())
  }
}
}