From 94f25884165e834742d09a5effcd10f2bbe7d4ba Mon Sep 17 00:00:00 2001
From: ulysses-you
Date: Wed, 18 Jan 2023 10:30:12 +0800
Subject: [PATCH] Improve AliasAwareOutputExpression

---
 .../expressions/stringExpressions.scala       |  24 ++++
 .../plans/AliasAwareOutputExpression.scala    | 125 ++++++++++++++++++
 .../spark/sql/catalyst/plans/QueryPlan.scala  |   5 +
 .../catalyst/plans/logical/LogicalPlan.scala  |  13 +-
 .../plans/logical/basicLogicalOperators.scala |   1 +
 .../apache/spark/sql/internal/SQLConf.scala   |   9 ++
 .../AliasAwareOutputExpression.scala          |  61 +++------
 .../spark/sql/execution/SparkPlan.scala       |   3 -
 .../datasources/FileFormatWriter.scala        |   2 +-
 .../sql/execution/datasources/V1Writes.scala  |  30 +----
 .../org/apache/spark/sql/ExplainSuite.scala   |   2 +-
 .../CoalesceShufflePartitionsSuite.scala      |   1 +
 .../spark/sql/execution/PlannerSuite.scala    |  84 +++++++++++-
 .../datasources/V1WriteCommandSuite.scala     |  27 ++--
 14 files changed, 288 insertions(+), 99 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index a3f2ff9e7e84d..b589a474c033f 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -3039,3 +3039,27 @@ case class SplitPart (
       partNum = newChildren.apply(2))
   }
 }
+
+/**
+ * An internal function that converts the empty string to null for partition values.
+ * This function should only be used in V1Writes.
+ */
+case class Empty2Null(child: Expression) extends UnaryExpression with String2StringExpression {
+  override def convert(v: UTF8String): UTF8String = if (v.numBytes() == 0) null else v
+
+  override def nullable: Boolean = true
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => {
+      s"""if ($c.numBytes() == 0) {
+         |  ${ev.isNull} = true;
+         |  ${ev.value} = null;
+         |} else {
+         |  ${ev.value} = $c;
+         |}""".stripMargin
+    })
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): Empty2Null =
+    copy(child = newChild)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala
new file mode 100644
index 0000000000000..cbde89929276b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Empty2Null, Expression, ExpressionSet, NamedExpression, SortOrder}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * A trait that provides functionality to handle aliases in the `outputExpressions`.
+ */
+trait AliasAwareOutputExpression extends SQLConfHelper {
+  private val aliasCandidateLimit = conf.getConf(SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT)
+  private var _hasAlias = false
+  protected def outputExpressions: Seq[NamedExpression]
+
+  /**
+   * Strips the parts of an expression that do not affect the result, for example:
+   * parts that are ordering agnostic when computing the output ordering.
+   */
+  protected def strip(expr: Expression): Expression = expr
+
+  protected lazy val aliasMap: Map[Expression, ArrayBuffer[Attribute]] = {
+    if (aliasCandidateLimit < 1) {
+      Map.empty
+    } else {
+      val outputExpressionSet = AttributeSet(outputExpressions.map(_.toAttribute))
+      val exprWithAliasMap = new mutable.HashMap[Expression, ArrayBuffer[Attribute]]()
+
+      def updateAttrWithAliasMap(key: Expression, target: Attribute): Unit = {
+        val aliasArray = exprWithAliasMap.getOrElseUpdate(
+          strip(key).canonicalized, new ArrayBuffer[Attribute]())
+        // Pre-filter: stop collecting once the number of aliases reaches the candidate limit.
+        if (aliasArray.size < aliasCandidateLimit) {
+          aliasArray.append(target)
+        }
+      }
+
+      outputExpressions.foreach {
+        case a @ Alias(child, _) =>
+          _hasAlias = true
+          updateAttrWithAliasMap(child, a.toAttribute)
+        case a: Attribute if outputExpressionSet.contains(a) =>
+          updateAttrWithAliasMap(a, a)
+        case _ =>
+      }
+      exprWithAliasMap.toMap
+    }
+  }
+
+  protected def hasAlias: Boolean = {
+    aliasMap // trigger the lazy initialization of `aliasMap`, which sets `_hasAlias`
+    _hasAlias
+  }
+
+  /**
+   * Returns the expressions obtained by normalizing the original expression to its aliased
+   * equivalents.
+   */
+  protected def normalizeExpression(expr: Expression): Seq[Expression] = {
+    val normalizedCandidates = expr.multiTransformDown {
+      case e: Expression if aliasMap.contains(e.canonicalized) =>
+        val candidates = aliasMap(e.canonicalized)
+        (candidates :+ e).toStream
+    }.take(aliasCandidateLimit)
+
+    if (normalizedCandidates.isEmpty) {
+      expr :: Nil
+    } else {
+      normalizedCandidates.toSeq
+    }
+  }
+}
+
+/**
+ * A trait that handles aliases in the `orderingExpressions` to produce `outputOrdering` that
+ * satisfies ordering requirements.
+ */
+trait AliasAwareQueryOutputOrdering[T <: QueryPlan[T]]
+  extends AliasAwareOutputExpression { self: QueryPlan[T] =>
+  protected def orderingExpressions: Seq[SortOrder]
+
+  override protected def strip(expr: Expression): Expression = expr match {
+    case e: Empty2Null => strip(e.child)
+    case _ => expr
+  }
+
+  override final def outputOrdering: Seq[SortOrder] = {
+    if (hasAlias) {
+      orderingExpressions.map { sortOrder =>
+        val normalized = normalizeExpression(sortOrder)
+        assert(normalized.forall(_.isInstanceOf[SortOrder]))
+        val pruned = ExpressionSet(normalized.flatMap {
+          case s: SortOrder => s.children.filter(_.references.subsetOf(outputSet))
+        })
+        if (pruned.isEmpty) {
+          sortOrder
+        } else {
+          // All expressions in `pruned` are semantically equal, so use the head to build a new
+          // SortOrder and the tail as the sameOrderExpressions.
+          SortOrder(pruned.head, sortOrder.direction, sortOrder.nullOrdering, pruned.tail.toSeq)
+        }
+      }
+    } else {
+      orderingExpressions
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0942919b17677..1d84139e11f15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -53,6 +53,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
 
   @transient lazy val outputSet: AttributeSet = AttributeSet(output)
 
+  /**
+   * Returns the output ordering that this plan generates.
+   */
+  def outputOrdering: Seq[SortOrder] = Nil
+
   // Override `treePatternBits` to propagate bits for its expressions.
   override lazy val treePatternBits: BitSet = {
     val bits: BitSet = getDefaultTreePatternBits
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 7640d9234c71f..eedc16c6998af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.{AliasAwareQueryOutputOrdering, QueryPlan}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, UnaryLike}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -141,11 +141,6 @@ abstract class LogicalPlan
    */
   def refresh(): Unit = children.foreach(_.refresh())
 
-  /**
-   * Returns the output ordering that this plan generates.
-   */
-  def outputOrdering: Seq[SortOrder] = Nil
-
   /**
    * Returns true iff `other`'s output is semantically the same, i.e.:
    *  - it contains the same number of `Attribute`s;
@@ -205,8 +200,10 @@ trait UnaryNode extends LogicalPlan with UnaryLike[LogicalPlan] {
  */
 trait BinaryNode extends LogicalPlan with BinaryLike[LogicalPlan]
 
-abstract class OrderPreservingUnaryNode extends UnaryNode {
-  override final def outputOrdering: Seq[SortOrder] = child.outputOrdering
+trait OrderPreservingUnaryNode extends UnaryNode
+  with AliasAwareQueryOutputOrdering[LogicalPlan] {
+  override protected def outputExpressions: Seq[NamedExpression] = child.output
+  override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering
 }
 
 object LogicalPlanIntegrity {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 586e344df5ee6..d3d70f4e59b80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -69,6 +69,7 @@ object Subquery {
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
   extends OrderPreservingUnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+  override protected def outputExpressions: Seq[NamedExpression] = projectList
   override def maxRows: Option[Long] = child.maxRows
   override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d949ec56632cd..eff72933249a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -435,6 +435,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val EXPRESSION_PROJECTION_CANDIDATE_LIMIT =
+    buildConf("spark.sql.optimizer.expressionProjectionCandidateLimit")
+      .doc("The maximum number of candidates for output expressions whose aliases are" +
+        " replaced. It can help preserve the output partitioning and ordering." +
+        " A negative value disables this optimization.")
+      .version("3.4.0")
+      .intConf
+      .createWithDefault(100)
+
   val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed")
     .doc("When set to true Spark SQL will automatically select a compression codec for each " +
       "column based on statistics of the data.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
index 92e86637eeccf..c11c0488d6716 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
@@ -16,48 +16,41 @@
  */
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, SortOrder}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
-
-/**
- * A trait that provides functionality to handle aliases in the `outputExpressions`.
- */
-trait AliasAwareOutputExpression extends UnaryExecNode {
-  protected def outputExpressions: Seq[NamedExpression]
-
-  private lazy val aliasMap = outputExpressions.collect {
-    case a @ Alias(child, _) => child.canonicalized -> a.toAttribute
-  }.toMap
-
-  protected def hasAlias: Boolean = aliasMap.nonEmpty
-
-  protected def normalizeExpression(exp: Expression): Expression = {
-    exp.transformDown {
-      case e: Expression => aliasMap.getOrElse(e.canonicalized, e)
-    }
-  }
-}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet}
+import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
 
 /**
  * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning` that
  * satisfies distribution requirements.
  */
-trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
+trait AliasAwareOutputPartitioning extends UnaryExecNode
+  with AliasAwareOutputExpression {
+
   final override def outputPartitioning: Partitioning = {
     val normalizedOutputPartitioning = if (hasAlias) {
       child.outputPartitioning match {
         case e: Expression =>
-          normalizeExpression(e).asInstanceOf[Partitioning]
+          val normalized = normalizeExpression(e)
+          if (normalized.isEmpty) {
+            UnknownPartitioning(child.outputPartitioning.numPartitions)
+          } else if (normalized.size == 1) {
+            normalized.head.asInstanceOf[Partitioning]
+          } else {
+            PartitioningCollection(normalized.asInstanceOf[Seq[Partitioning]])
+          }
         case other => other
       }
     } else {
       child.outputPartitioning
     }
-    flattenPartitioning(normalizedOutputPartitioning).filter {
-      case hashPartitioning: HashPartitioning => hashPartitioning.references.subsetOf(outputSet)
+    val (partitionWithExpr, other) = flattenPartitioning(normalizedOutputPartitioning).filter {
+      case e: Expression => e.references.subsetOf(outputSet)
       case _ => true
-    } match {
+    }.partition(_.isInstanceOf[Expression])
+    val pruned = ExpressionSet(partitionWithExpr.asInstanceOf[Seq[Expression]])
+    (pruned.toSeq.asInstanceOf[Seq[Partitioning]] ++ other) match {
       case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
       case Seq(singlePartitioning) => singlePartitioning
       case seqWithMultiplePartitionings => PartitioningCollection(seqWithMultiplePartitionings)
@@ -74,18 +67,4 @@ trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
   }
 }
 
-/**
- * A trait that handles aliases in the `orderingExpressions` to produce `outputOrdering` that
- * satisfies ordering requirements.
- */
-trait AliasAwareOutputOrdering extends AliasAwareOutputExpression {
-  protected def orderingExpressions: Seq[SortOrder]
-
-  final override def outputOrdering: Seq[SortOrder] = {
-    if (hasAlias) {
-      orderingExpressions.map(normalizeExpression(_).asInstanceOf[SortOrder])
-    } else {
-      orderingExpressions
-    }
-  }
-}
+trait AliasAwareOutputOrdering extends UnaryExecNode with AliasAwareQueryOutputOrdering[SparkPlan]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 5ca36a8a216af..bbd74a1fe7407 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -179,9 +179,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   def requiredChildDistribution: Seq[Distribution] =
     Seq.fill(children.size)(UnspecifiedDistribution)
 
-  /** Specifies how data is ordered in each partition. */
-  def outputOrdering: Seq[SortOrder] = Nil
-
   /** Specifies sort order for each partition requirements on the input data for this operator. */
   def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 5c4d662c14591..8efc52edbfd48 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -158,7 +158,7 @@ object FileFormatWriter extends Logging {
     // Use the output ordering from the original plan before adding the empty2null projection.
     val actualOrdering = writeFilesOpt.map(_.child)
       .getOrElse(materializeAdaptiveSparkPlan(plan))
-      .outputOrdering.map(_.child)
+      .outputOrdering
     val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering)
 
     SQLExecution.checkSQLExecutionId(sparkSession)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
index d52af64521855..76167b6004562 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
@@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder, String2StringExpression, UnaryExpression}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Empty2Null, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -29,7 +28,6 @@ import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.DataWritingCommand
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StringType
-import org.apache.spark.unsafe.types.UTF8String
 
 trait V1WriteCommand extends DataWritingCommand {
   /**
@@ -121,26 +119,6 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
 }
 
 object V1WritesUtils {
-
-  /** A function that converts the empty string to null for partition values. */
-  case class Empty2Null(child: Expression) extends UnaryExpression with String2StringExpression {
-    override def convert(v: UTF8String): UTF8String = if (v.numBytes() == 0) null else v
-    override def nullable: Boolean = true
-    override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-      nullSafeCodeGen(ctx, ev, c => {
-        s"""if ($c.numBytes() == 0) {
-           |  ${ev.isNull} = true;
-           |  ${ev.value} = null;
-           |} else {
-           |  ${ev.value} = $c;
-           |}""".stripMargin
-      })
-    }
-
-    override protected def withNewChildInternal(newChild: Expression): Empty2Null =
-      copy(child = newChild)
-  }
-
   def getWriterBucketSpec(
       bucketSpec: Option[BucketSpec],
       dataColumns: Seq[Attribute],
@@ -230,12 +208,14 @@ object V1WritesUtils {
 
   def isOrderingMatched(
       requiredOrdering: Seq[Expression],
-      outputOrdering: Seq[Expression]): Boolean = {
+      outputOrdering: Seq[SortOrder]): Boolean = {
     if (requiredOrdering.length > outputOrdering.length) {
       false
     } else {
       requiredOrdering.zip(outputOrdering).forall {
-        case (requiredOrder, outputOrder) => requiredOrder.semanticEquals(outputOrder)
+        case (requiredOrder, outputOrder) =>
+          // Follow `SortOrder.satisfies`, which respects `SortOrder.sameOrderExpressions`.
+          outputOrder.children.exists(_.semanticEquals(requiredOrder))
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index 9a75cc5ff8f71..5a7f9ac255c23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -673,7 +673,7 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
   test("SPARK-35133: explain codegen should work with AQE") {
     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
       withTempView("df") {
-        val df = spark.range(5).select(col("id").as("key"), col("id").as("value"))
+        val df = spark.range(5).selectExpr("id % 2 as key", "id as value")
         df.createTempView("df")
         val sqlText = "EXPLAIN CODEGEN SELECT key, MAX(value) FROM df GROUP BY key"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
index 81777f67f3701..ebf3840aee74e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
@@ -340,6 +340,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite {
         //   ShuffleQueryStage 2
         //     ReusedQueryStage 0
         val grouped = df.groupBy("key").agg(max("value").as("value"))
+          .repartition(col("key") + 10)
         val resultDf2 = grouped.groupBy(col("key") + 1).max("value")
           .union(grouped.groupBy(col("key") + 2).max("value"))
         QueryTest.checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Row(2, 1) :: Row(3, 1) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 248c68abd1e0e..66fc6c283af51 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -1072,7 +1072,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       assert(projects.exists(_.outputPartitioning match {
         case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
           HashPartitioning(Seq(k2: AttributeReference), _))) =>
-          k1.name == "t1id" && k2.name == "t2id"
+          Set(k1.name, k2.name) == Set("t1id", "t2id")
         case _ => false
       }))
     }
@@ -1101,9 +1101,9 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       val projects = collect(planned) { case p: ProjectExec => p }
       assert(projects.exists(_.outputOrdering match {
-        case Seq(SortOrder(_, Ascending, NullsFirst, sameOrderExprs)) =>
-          sameOrderExprs.size == 1 && sameOrderExprs.head.isInstanceOf[AttributeReference] &&
-            sameOrderExprs.head.asInstanceOf[AttributeReference].name == "t2id"
+        case Seq(SortOrder(child, Ascending, NullsFirst, sameOrderExprs)) =>
+          sameOrderExprs.isEmpty && child.isInstanceOf[AttributeReference] &&
+            child.asInstanceOf[AttributeReference].name == "t2id"
         case _ => false
       }))
     }
@@ -1249,7 +1249,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     assert(planned.outputPartitioning match {
       case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
         HashPartitioning(Seq(k2: AttributeReference), _))) =>
-        k1.name == "t1id" && k2.name == "t2id"
+        Set(k1.name, k2.name) == Set("t1id", "t2id")
     })
 
     val planned2 = sql(
@@ -1314,6 +1314,80 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     assert(topKs.size == 1)
     assert(sorts.isEmpty)
   }
+
+  test("SPARK-42049: Improve AliasAwareOutputExpression - ordering - multi-alias") {
+    Seq(0, 1, 5).foreach { limit =>
+      withSQLConf(SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT.key -> limit.toString) {
+        val df = spark.range(2).orderBy($"id").selectExpr("id as x", "id as y", "id as z")
+        val outputOrdering = df.queryExecution.optimizedPlan.outputOrdering
+        assert(outputOrdering.size == 1)
+        limit match {
+          case 5 =>
+            assert(outputOrdering.head.sameOrderExpressions.size == 2)
+            assert(outputOrdering.head.child.map(_.asInstanceOf[Attribute].name)
+              .toSet.subsetOf(Set("x", "y", "z")))
+            assert(outputOrdering.head.sameOrderExpressions.map(_.asInstanceOf[Attribute].name)
+              .toSet.subsetOf(Set("x", "y", "z")))
+          case 1 =>
+            assert(outputOrdering.head.sameOrderExpressions.isEmpty)
+            assert(outputOrdering.head.child.map(_.asInstanceOf[Attribute].name)
+              .toSet.subsetOf(Set("x", "y", "z")))
+          case 0 =>
+            assert(outputOrdering.head.sameOrderExpressions.isEmpty)
+        }
+      }
+    }
+  }
+
+  test("SPARK-42049: Improve AliasAwareOutputExpression - partitioning - multi-alias") {
+    Seq(0, 1, 5).foreach { limit =>
+      withSQLConf(SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT.key -> limit.toString) {
+        val df = spark.range(2).repartition($"id").selectExpr("id as x", "id as y", "id as z")
+        val outputPartitioning = stripAQEPlan(df.queryExecution.executedPlan).outputPartitioning
+        limit match {
+          case 5 =>
+            val p = outputPartitioning.asInstanceOf[PartitioningCollection].partitionings
+            assert(p.size == 3)
+            assert(p.flatMap(_.asInstanceOf[HashPartitioning].expressions
+              .map(_.asInstanceOf[Attribute].name)).toSet == Set("x", "y", "z"))
+          case 1 =>
+            val p = outputPartitioning.asInstanceOf[HashPartitioning]
+            assert(p.expressions.size == 1)
+            assert(p.expressions.map(_.asInstanceOf[Attribute].name)
+              .toSet.subsetOf(Set("x", "y", "z")))
+          case 0 =>
+            // The references of the child's output partitioning are not a subset of the
+            // output, so the partitioning has been pruned.
+            assert(outputPartitioning.isInstanceOf[UnknownPartitioning])
+        }
+      }
+    }
+  }
+
+  test("SPARK-42049: Improve AliasAwareOutputExpression - ordering - multi-references") {
+    val df = spark.range(2).selectExpr("id as a", "id as b")
+      .orderBy($"a" + $"b").selectExpr("a as x", "b as y")
+    val outputOrdering = df.queryExecution.optimizedPlan.outputOrdering
+    assert(outputOrdering.size == 1)
+    assert(outputOrdering.head.sameOrderExpressions.isEmpty)
+    // (a + b), (a + y) and (x + b) are pruned since their references are not a subset of the output
+    outputOrdering.head.child match {
+      case Add(l: Attribute, r: Attribute, _) => assert(l.name == "x" && r.name == "y")
+      case _ => fail(s"Unexpected ${outputOrdering.head.child}")
+    }
+  }
+
+  test("SPARK-42049: Improve AliasAwareOutputExpression - partitioning - multi-references") {
+    val df = spark.range(2).selectExpr("id as a", "id as b")
+      .repartition($"a" + $"b").selectExpr("a as x", "b as y")
+    val outputPartitioning = stripAQEPlan(df.queryExecution.executedPlan).outputPartitioning
+    // (a + b), (a + y) and (x + b) are pruned since their references are not a subset of the output
+    outputPartitioning match {
+      case HashPartitioning(Seq(Add(l: Attribute, r: Attribute, _)), _) =>
+        assert(l.name == "x" && r.name == "y")
+      case _ => fail(s"Unexpected $outputPartitioning")
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index 40574a8e73aa2..20a90cd94b65b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -90,9 +90,11 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils {
     sparkContext.listenerBus.waitUntilEmpty()
 
     assert(optimizedPlan != null)
-    // Check whether a logical sort node is at the top of the logical plan of the write query.
-    assert(optimizedPlan.isInstanceOf[Sort] == hasLogicalSort,
-      s"Expect hasLogicalSort: $hasLogicalSort, Actual: ${optimizedPlan.isInstanceOf[Sort]}")
+    // Check whether a logical sort node exists in the write query. If a user-specified sort
+    // matches the required ordering, the sort node may not be at the top of the query plan.
+    assert(optimizedPlan.exists(_.isInstanceOf[Sort]) == hasLogicalSort,
+      s"Expect hasLogicalSort: $hasLogicalSort, " +
+        s"Actual: ${optimizedPlan.exists(_.isInstanceOf[Sort])}")
 
     // Check empty2null conversion.
     val empty2nullExpr = optimizedPlan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
@@ -223,8 +225,8 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
           case s: SortExec => s
         }.exists {
           case SortExec(Seq(
-          SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _),
-          SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _)
+            SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _),
+            SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _)
           ), false, _, _) => true
           case _ => false
         }, plan)
@@ -268,16 +270,11 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
         // assert the outer most sort in the executed plan
         assert(plan.collectFirst {
           case s: SortExec => s
-        }.map(s => (enabled, s)).exists {
-          case (false, SortExec(Seq(
-            SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _),
-            SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _)
-          ), false, _, _)) => true
-
-          // SPARK-40885: this bug removes the in-partition sort, which manifests here
-          case (true, SortExec(Seq(
-            SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _)
-          ), false, _, _)) => true
+        }.exists {
+          case SortExec(Seq(
+            SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _),
+            SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _)
+          ), false, _, _) => true
           case _ => false
         }, plan)
       }
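
Illustration (not part of the patch itself): a minimal spark-shell sketch of the behavior this
change enables, assuming an active SparkSession named `spark` and the default candidate limit.
Before this change only a single alias per expression was tracked, so the output ordering and
partitioning were lost as soon as a column was projected under more than one alias.

  import spark.implicits._

  // Ordering: both aliases of `id` are tracked, so the optimized plan reports an
  // output ordering such as SortOrder(x, Ascending, NullsFirst, Seq(y)).
  val ordered = spark.range(5).orderBy($"id").selectExpr("id as x", "id as y")
  println(ordered.queryExecution.optimizedPlan.outputOrdering)

  // Partitioning: hash partitioning by `id` survives projecting `id` under two aliases and
  // is reported as a PartitioningCollection over hashpartitioning(x) and hashpartitioning(y).
  // With AQE enabled, the adaptive plan may need to be unwrapped first, as the PlannerSuite
  // tests above do via stripAQEPlan.
  val partitioned = spark.range(5).repartition($"id").selectExpr("id as x", "id as y")
  println(partitioned.queryExecution.executedPlan.outputPartitioning)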