diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 828a129a876ec..62a75e753455a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -188,7 +188,7 @@ case class AdaptiveSparkPlanExec(
 
   @volatile private var currentPhysicalPlan = initialPlan
 
-  private var isFinalPlan = false
+  @volatile private var _isFinalPlan = false
 
   private var currentStageId = 0
 
@@ -205,6 +205,8 @@ case class AdaptiveSparkPlanExec(
 
   def executedPlan: SparkPlan = currentPhysicalPlan
 
+  def isFinalPlan: Boolean = _isFinalPlan
+
   override def conf: SQLConf = context.session.sessionState.conf
 
   override def output: Seq[Attribute] = inputPlan.output
@@ -329,7 +331,7 @@ case class AdaptiveSparkPlanExec(
         optimizeQueryStage(result.newPlan, isFinalStage = true),
         postStageCreationRules(supportsColumnar),
         Some((planChangeLogger, "AQE Post Stage Creation")))
-      isFinalPlan = true
+      _isFinalPlan = true
       executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
       currentPhysicalPlan
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index da9316efdb417..0f00a6a3559b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.columnar.CachedBatch
 import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -111,10 +112,15 @@ case class InMemoryTableScanExec(
 
   override def output: Seq[Attribute] = attributes
 
+  private def cachedPlan = relation.cachedPlan match {
+    case adaptive: AdaptiveSparkPlanExec if adaptive.isFinalPlan => adaptive.executedPlan
+    case other => other
+  }
+
   private def updateAttribute(expr: Expression): Expression = {
     // attributes can be pruned so using relation's output.
     // E.g., relation.output is [id, item] but this scan's output can be [item] only.
-    val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output))
+    val attrMap = AttributeMap(cachedPlan.output.zip(relation.output))
     expr.transform {
       case attr: Attribute => attrMap.getOrElse(attr, attr)
     }
@@ -123,7 +129,7 @@ case class InMemoryTableScanExec(
   // The cached version does not change the outputPartitioning of the original SparkPlan.
   // But the cached version could alias output, so we need to replace output.
   override def outputPartitioning: Partitioning = {
-    relation.cachedPlan.outputPartitioning match {
+    cachedPlan.outputPartitioning match {
       case e: Expression => updateAttribute(e).asInstanceOf[Partitioning]
       case other => other
     }
@@ -132,7 +138,7 @@ case class InMemoryTableScanExec(
   // The cached version does not change the outputOrdering of the original SparkPlan.
   // But the cached version could alias output, so we need to replace output.
   override def outputOrdering: Seq[SortOrder] =
-    relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
+    cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
 
   lazy val enableAccumulatorsForTest: Boolean = conf.inMemoryTableScanStatisticsEnabled
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index fabd0a4e1a951..6be1e424719e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -40,10 +40,10 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.FakeV2Provider
-import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.expressions.{Aggregator, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -3513,6 +3513,25 @@ class DataFrameSuite extends QueryTest
       assert(df.queryExecution.executedPlan.execute().getNumPartitions == 2)
     }
   }
+
+  test("SPARK-41048: Improve output partitioning and ordering with AQE cache") {
+    withSQLConf(
+      SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df1 = spark.range(10).selectExpr("cast(id as string) c1")
+      val df2 = spark.range(10).selectExpr("cast(id as string) c2")
+      val cached = df1.join(df2, $"c1" === $"c2").cache()
+      cached.count()
+      val executedPlan = cached.groupBy("c1").agg(max($"c2")).queryExecution.executedPlan
+      // Before this change, the executed plan contained 2 sorts and 1 shuffle.
+      assert(collect(executedPlan) {
+        case s: ShuffleExchangeLike => s
+      }.isEmpty)
+      assert(collect(executedPlan) {
+        case s: SortExec => s
+      }.isEmpty)
+    }
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)
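
For reviewers who want to reproduce the behavior outside the test suite, below is a minimal standalone sketch of the scenario the new test asserts on. The object name, local master setting, and explain() call are illustrative and not part of the patch; the two config keys are the ones matching SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING and SQLConf.AUTO_BROADCASTJOIN_THRESHOLD. Once the cached join has been materialized and its AQE plan is final, the follow-up aggregation can reuse the cache's hash partitioning and ordering on c1, so the printed physical plan should contain no additional Exchange or Sort below the aggregate.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.max

object AqeCachePartitioningDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SPARK-41048 demo")
      // Let AQE change the output partitioning of cached plans
      // (key of SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING).
      .config("spark.sql.optimizer.canChangeCachedPlanOutputPartitioning", "true")
      // Disable broadcast joins so the cached join is shuffle-based.
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .getOrCreate()
    import spark.implicits._

    val df1 = spark.range(10).selectExpr("cast(id as string) c1")
    val df2 = spark.range(10).selectExpr("cast(id as string) c2")

    // Materialize the cache; its final AQE plan is partitioned by the join key.
    val cached = df1.join(df2, $"c1" === $"c2").cache()
    cached.count()

    // With this change the aggregate can reuse the cached plan's partitioning,
    // so the physical plan should contain no extra shuffle or sort.
    cached.groupBy("c1").agg(max($"c2")).explain()

    spark.stop()
  }
}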