Skip to content

Commit dfd7ac9

Browse files
viiryagatorsmile
authored andcommitted
[SPARK-24781][SQL] Using a reference from Dataset in Filter/Sort might not work
## What changes were proposed in this pull request? When we use a reference from Dataset in filter or sort, which was not used in the prior select, an AnalysisException occurs, e.g., ```scala val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") df.select(df("name")).filter(df("id") === 0).show() ``` ```scala org.apache.spark.sql.AnalysisException: Resolved attribute(s) id#6 missing from name#5 in operator !Filter (id#6 = 0).;; !Filter (id#6 = 0) +- AnalysisBarrier +- Project [name#5] +- Project [_1#2 AS name#5, _2#3 AS id#6] +- LocalRelation [_1#2, _2#3] ``` This change updates the rule `ResolveMissingReferences` so `Filter` and `Sort` with non-empty `missingInputs` will also be transformed. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh <[email protected]> Closes #21745 from viirya/SPARK-24781.
1 parent 0f24c6f commit dfd7ac9

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,7 +1132,8 @@ class Analyzer(
11321132
case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
11331133
case sa @ Sort(_, _, child: Aggregate) => sa
11341134

1135-
case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
1135+
case s @ Sort(order, _, child)
1136+
if (!s.resolved || s.missingInput.nonEmpty) && child.resolved =>
11361137
val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
11371138
val ordering = newOrder.map(_.asInstanceOf[SortOrder])
11381139
if (child.output == newChild.output) {
@@ -1143,7 +1144,7 @@ class Analyzer(
11431144
Project(child.output, newSort)
11441145
}
11451146

1146-
case f @ Filter(cond, child) if !f.resolved && child.resolved =>
1147+
case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved =>
11471148
val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
11481149
if (child.output == newChild.output) {
11491150
f.copy(condition = newCond.head)
@@ -1154,10 +1155,17 @@ class Analyzer(
11541155
}
11551156
}
11561157

1158+
/**
1159+
* This method tries to resolve expressions and find missing attributes recursively. Specially,
1160+
* when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved
1161+
* attributes which are missed from child output. This method tries to find the missing
1162+
* attributes out and add into the projection.
1163+
*/
11571164
private def resolveExprsAndAddMissingAttrs(
11581165
exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
1159-
if (exprs.forall(_.resolved)) {
1160-
// All given expressions are resolved, no need to continue anymore.
1166+
// Missing attributes can be unresolved attributes or resolved attributes which are not in
1167+
// the output attributes of the plan.
1168+
if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) {
11611169
(exprs, plan)
11621170
} else {
11631171
plan match {
@@ -1168,15 +1176,19 @@ class Analyzer(
11681176
(newExprs, AnalysisBarrier(newChild))
11691177

11701178
case p: Project =>
1179+
// Resolving expressions against current plan.
11711180
val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
1181+
// Recursively resolving expressions on the child of current plan.
11721182
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
1173-
val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
1183+
// If some attributes used by expressions are resolvable only on the rewritten child
1184+
// plan, we need to add them into original projection.
1185+
val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet)
11741186
(newExprs, Project(p.projectList ++ missingAttrs, newChild))
11751187

11761188
case a @ Aggregate(groupExprs, aggExprs, child) =>
11771189
val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
11781190
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
1179-
val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
1191+
val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet)
11801192
if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
11811193
// All the missing attributes are grouping expressions, valid case.
11821194
(newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
@@ -1526,7 +1538,11 @@ class Analyzer(
15261538

15271539
// Try resolving the ordering as though it is in the aggregate clause.
15281540
try {
1529-
val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
1541+
// If a sort order is unresolved, containing references not in aggregate, or containing
1542+
// `AggregateExpression`, we need to push down it to the underlying aggregate operator.
1543+
val unresolvedSortOrders = sortOrder.filter { s =>
1544+
!s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
1545+
}
15301546
val aliasedOrdering =
15311547
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
15321548
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2387,4 +2387,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
23872387
val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
23882388
checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
23892389
}
2390+
2391+
test("SPARK-24781: Using a reference from Dataset in Filter/Sort") {
2392+
val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
2393+
val filter1 = df.select(df("name")).filter(df("id") === 0)
2394+
val filter2 = df.select(col("name")).filter(col("id") === 0)
2395+
checkAnswer(filter1, filter2.collect())
2396+
2397+
val sort1 = df.select(df("name")).orderBy(df("id"))
2398+
val sort2 = df.select(col("name")).orderBy(col("id"))
2399+
checkAnswer(sort1, sort2.collect())
2400+
}
2401+
2402+
test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") {
2403+
withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
2404+
val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
2405+
2406+
val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name"))
2407+
val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name"))
2408+
checkAnswer(aggPlusSort1, aggPlusSort2.collect())
2409+
2410+
val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0)
2411+
val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0)
2412+
checkAnswer(aggPlusFilter1, aggPlusFilter2.collect())
2413+
}
2414+
}
23902415
}

0 commit comments

Comments
 (0)