diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala
index 4d9d69d14fe5f..cfe229945929c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala
@@ -95,7 +95,7 @@ trait AliasAwareQueryOutputOrdering[T <: QueryPlan[T]]
   }
 
   override final def outputOrdering: Seq[SortOrder] = {
-    if (hasAlias) {
+    val newOrdering: Iterator[Option[SortOrder]] = if (hasAlias) {
       // Take the first `SortOrder`s only until they can be projected.
       // E.g. we have child ordering `Seq(SortOrder(a), SortOrder(b))` then
       // if only `a AS x` can be projected then we can return Seq(SortOrder(x))`
@@ -112,9 +112,21 @@ trait AliasAwareQueryOutputOrdering[T <: QueryPlan[T]]
         } else {
           None
         }
-      }.takeWhile(_.isDefined).flatten.toSeq
+      }
     } else {
-      orderingExpressions
+      // Make sure the returned orderings are valid (only reference output attributes of the
+      // current plan node). Same as above (the if branch), we take the first ordering
+      // expressions that are all valid.
+      val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
+      orderingExpressions.iterator.map { order =>
+        val validChildren = order.children.filter(_.references.subsetOf(outputSet))
+        if (validChildren.nonEmpty) {
+          Some(order.copy(child = validChildren.head, sameOrderExpressions = validChildren.tail))
+        } else {
+          None
+        }
+      }
     }
+    newOrdering.takeWhile(_.isDefined).flatten.toSeq
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
index a09c719cf840d..e1dcab80af307 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
 import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering}
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
 
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningC
 trait PartitioningPreservingUnaryExecNode extends UnaryExecNode
   with AliasAwareOutputExpression {
   final override def outputPartitioning: Partitioning = {
-    if (hasAlias) {
+    val partitionings: Seq[Partitioning] = if (hasAlias) {
       flattenPartitioning(child.outputPartitioning).flatMap {
         case e: Expression =>
           // We need unique partitionings but if the input partitioning is
@@ -44,13 +44,19 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode
             .take(aliasCandidateLimit)
             .asInstanceOf[Stream[Partitioning]]
         case o => Seq(o)
-      } match {
-        case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
-        case Seq(p) => p
-        case ps => PartitioningCollection(ps)
       }
     } else {
-      child.outputPartitioning
+      // Filter valid partitionings (only reference output attributes of the current plan node)
+      val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
+      flattenPartitioning(child.outputPartitioning).filter {
+        case e: Expression => e.references.subsetOf(outputSet)
+        case _ => true
+      }
+    }
+    partitionings match {
+      case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
+      case Seq(p) => p
+      case ps => PartitioningCollection(ps)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index e9cb77ec95c4d..4b3d3a4b8058a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -1129,9 +1129,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       assert(sortNodes.size == 3)
       val outputOrdering = planned.outputOrdering
       assert(outputOrdering.size == 1)
-      // Sort order should have 3 childrens, not 4. This is because t1.id*2 and 2*t1.id are same
-      assert(outputOrdering.head.children.size == 3)
-      assert(outputOrdering.head.children.count(_.isInstanceOf[AttributeReference]) == 2)
+      // Sort order should have 2 children, not 4, because t1.id*2 and 2*t1.id are the same
+      // and t2.id is not a valid ordering expression (the final plan doesn't output t2.id)
+      assert(outputOrdering.head.children.size == 2)
+      assert(outputOrdering.head.children.count(_.isInstanceOf[AttributeReference]) == 1)
       assert(outputOrdering.head.children.count(_.isInstanceOf[Multiply]) == 1)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
index e4ecdb9c44595..f5839e9975602 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
@@ -177,4 +177,18 @@ class ProjectedOrderingAndPartitioningSuite
     val outputOrdering3 = df3.queryExecution.optimizedPlan.outputOrdering
     assert(outputOrdering3.size == 0)
   }
+
+  test("SPARK-42049: Improve AliasAwareOutputExpression - no alias but still prune expressions") {
+    val df = spark.range(2).select($"id" + 1 as "a", $"id" + 2 as "b")
+
+    val df1 = df.repartition($"a", $"b").selectExpr("a")
+    val outputPartitioning = stripAQEPlan(df1.queryExecution.executedPlan).outputPartitioning
+    assert(outputPartitioning.isInstanceOf[UnknownPartitioning])
+
+    val df2 = df.orderBy("a", "b").select("a")
+    val outputOrdering = df2.queryExecution.optimizedPlan.outputOrdering
+    assert(outputOrdering.size == 1)
+    assert(outputOrdering.head.child.asInstanceOf[Attribute].name == "a")
+    assert(outputOrdering.head.sameOrderExpressions.size == 0)
+  }
 }
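
Reviewer note: below is a minimal, self-contained sketch of the pruning rule the catalyst hunk introduces; it is plain Scala, not Spark code. The names `Attr`, `Order`, and `prune` are hypothetical stand-ins for Spark's `Attribute`, `SortOrder`, and the `outputOrdering` logic above, and `Order.children` conflates `SortOrder.child` with its `sameOrderExpressions`.

// Sketch: keep the longest prefix of the child's sort orders whose
// expressions reference only attributes the current node outputs; within
// each order, drop the invalid same-order alternatives.
object OrderingPruningSketch {
  case class Attr(name: String)
  // `children` plays the role of SortOrder.child +: sameOrderExpressions.
  case class Order(children: Seq[Attr])

  def prune(ordering: Seq[Order], outputSet: Set[Attr]): Seq[Order] = {
    ordering.iterator.map { order =>
      val valid = order.children.filter(outputSet.contains)
      if (valid.nonEmpty) Some(Order(valid)) else None
    }.takeWhile(_.isDefined) // an ordering is only usable as a prefix
      .flatten
      .toSeq
  }

  def main(args: Array[String]): Unit = {
    val (a, b) = (Attr("a"), Attr("b"))
    // Child sorted by (a, b), node outputs only `a`: result is sorted by `a`.
    assert(prune(Seq(Order(Seq(a)), Order(Seq(b))), Set(a)) == Seq(Order(Seq(a))))
    // Leading sort expression invalid: no output ordering can be guaranteed.
    assert(prune(Seq(Order(Seq(b)), Order(Seq(a))), Set(a)).isEmpty)
  }
}

The partitioning branch in PartitioningPreservingUnaryExecNode applies the same subset-of-output check but with a plain `filter` rather than `takeWhile`: unlike sort orders, partitionings are not prefix-dependent, so each valid one can be kept independently.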