diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c078efdfc0000..83eb52d05af64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1129,7 +1129,8 @@ class Analyzer( case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, child) + if (!s.resolved || s.missingInput.nonEmpty) && child.resolved => val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) val ordering = newOrder.map(_.asInstanceOf[SortOrder]) if (child.output == newChild.output) { @@ -1140,7 +1141,7 @@ class Analyzer( Project(child.output, newSort) } - case f @ Filter(cond, child) if !f.resolved && child.resolved => + case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved => val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) if (child.output == newChild.output) { f.copy(condition = newCond.head) @@ -1151,10 +1152,17 @@ class Analyzer( } } + /** + * This method tries to resolve expressions and find missing attributes recursively. Specially, + * when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved + * attributes which are missed from child output. This method tries to find the missing + * attributes out and add into the projection. + */ private def resolveExprsAndAddMissingAttrs( exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { - if (exprs.forall(_.resolved)) { - // All given expressions are resolved, no need to continue anymore. + // Missing attributes can be unresolved attributes or resolved attributes which are not in + // the output attributes of the plan. + if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { (exprs, plan) } else { plan match { @@ -1165,15 +1173,19 @@ class Analyzer( (newExprs, AnalysisBarrier(newChild)) case p: Project => + // Resolving expressions against current plan. val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + // Recursively resolving expressions on the child of current plan. val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + // If some attributes used by expressions are resolvable only on the rewritten child + // plan, we need to add them into original projection. + val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) (newExprs, Project(p.projectList ++ missingAttrs, newChild)) case a @ Aggregate(groupExprs, aggExprs, child) => val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { // All the missing attributes are grouping expressions, valid case. (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) @@ -1493,7 +1505,11 @@ class Analyzer( // Try resolving the ordering as though it is in the aggregate clause. try { - val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + // If a sort order is unresolved, containing references not in aggregate, or containing + // `AggregateExpression`, we need to push down it to the underlying aggregate operator. + val unresolvedSortOrders = sortOrder.filter { s => + !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + } val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9d7645d232d08..5babdf6f33b99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2387,4 +2387,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) } + + test("SPARK-24781: Using a reference from Dataset in Filter/Sort") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + val filter1 = df.select(df("name")).filter(df("id") === 0) + val filter2 = df.select(col("name")).filter(col("id") === 0) + checkAnswer(filter1, filter2.collect()) + + val sort1 = df.select(df("name")).orderBy(df("id")) + val sort2 = df.select(col("name")).orderBy(col("id")) + checkAnswer(sort1, sort2.collect()) + } + + test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + + val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name")) + val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name")) + checkAnswer(aggPlusSort1, aggPlusSort2.collect()) + + val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0) + val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0) + checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) + } + } }