diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala
new file mode 100644
index 0000000000000..b7d501725705c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.physical._
+
+trait AliasAwareOutputPartitioning extends UnaryExecNode {
+
+  protected def outputExpressions: Seq[NamedExpression]
+
+  // If a project or an aggregate has aliases in its output expressions, we should respect
+  // the aliases when checking whether the operator satisfies its required distribution;
+  // otherwise, the EnsureRequirements rule adds unnecessary shuffle operations, e.g.,
+  //
+  //   spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df1")
+  //   spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df2")
+  //   sql("""
+  //     SELECT * FROM
+  //       (SELECT key AS k from df1) t1
+  //     INNER JOIN
+  //       (SELECT key AS k from df2) t2
+  //     ON t1.k = t2.k
+  //   """).explain
+  //
+  //   == Physical Plan ==
+  //   *SortMergeJoin [k#56L], [k#57L], Inner
+  //   :- *Sort [k#56L ASC NULLS FIRST], false, 0
+  //   :  +- Exchange hashpartitioning(k#56L, 200)  // <--- Unnecessary shuffle operation
+  //   :     +- *Project [key#39L AS k#56L]
+  //   :        +- Exchange hashpartitioning(key#39L, 200)
+  //   :           +- *Project [id#36L AS key#39L]
+  //   :              +- *Range (0, 10, step=1, splits=Some(4))
+  //   +- *Sort [k#57L ASC NULLS FIRST], false, 0
+  //      +- ReusedExchange [k#57L], Exchange hashpartitioning(k#56L, 200)
+  final override def outputPartitioning: Partitioning = if (hasAlias(outputExpressions)) {
+    resolveOutputPartitioningByAliases(outputExpressions, child.outputPartitioning)
+  } else {
+    child.outputPartitioning
+  }
+
+  private def hasAlias(exprs: Seq[NamedExpression]): Boolean =
+    exprs.exists(_.collectFirst { case _: Alias => true }.isDefined)
+
+  private def resolveOutputPartitioningByAliases(
+      exprs: Seq[NamedExpression],
+      partitioning: Partitioning): Partitioning = {
+    val aliasSeq = exprs.flatMap(_.collectFirst {
+      case a @ Alias(child, _) => (child, a.toAttribute)
+    })
+    def mayReplaceExprWithAlias(e: Expression): Expression = {
+      aliasSeq.find { case (c, _) => c.semanticEquals(e) }.map(_._2).getOrElse(e)
+    }
+    def mayReplacePartitioningExprsWithAliases(p: Partitioning): Partitioning = p match {
+      case hash @ HashPartitioning(exprs, _) =>
+        hash.copy(expressions = exprs.map(mayReplaceExprWithAlias))
+      case range @ RangePartitioning(ordering, _) =>
+        range.copy(ordering = ordering.map { order =>
+          order.copy(
+            child = mayReplaceExprWithAlias(order.child),
+            sameOrderExpressions = order.sameOrderExpressions.map(mayReplaceExprWithAlias)
+          )
+        })
+      case _ => p
+    }
+
+    partitioning match {
+      case pc @ PartitioningCollection(ps) =>
+        pc.copy(partitionings = ps.map(mayReplacePartitioningExprsWithAliases))
+      case _ =>
+        mayReplacePartitioningExprsWithAliases(partitioning)
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 2cac0cfce28de..885f28ae6d65c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -45,7 +45,9 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {
+
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -66,8 +68,6 @@ case class HashAggregateExec(
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-
   override def producedAttributes: AttributeSet =
     AttributeSet(aggregateAttributes) ++
     AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index 66955b8ef723c..aefb8e428ebda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -65,7 +65,9 @@ case class ObjectHashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {
+
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -95,8 +97,6 @@
     }
   }
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
     val aggTime = longMetric("aggTime")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index fc87de2c52e41..3d3425d413ed7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.Utils
 
@@ -38,7 +38,9 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {
+
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -66,8 +68,6 @@ case class SortAggregateExec(
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-
   override def outputOrdering: Seq[SortOrder] = {
     groupingExpressions.map(SortOrder(_, Ascending))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 9434ceb7cd16c..0f9a137717754 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -33,10 +33,12 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
 /** Physical plan for Project. */
 case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {
 
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
 
+  override protected def outputExpressions: Seq[NamedExpression] = projectList
+
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()
   }
@@ -76,8 +78,6 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-
-  override def outputPartitioning: Partitioning = child.outputPartitioning
 }
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index d2d5011bbcb97..fb7edc18c1046 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
-  SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index 2c18d6aaabdba..268b734d4b696 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -73,3 +73,9 @@ where b.z != b.z;
 -- SPARK-24369 multiple distinct aggregations having the same argument set
 SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
   FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y);
+
+-- SPARK-19981 Correctly resolve partitioning when output has aliases
+CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 'a'), (1, 'b') AS t1(a, b) DISTRIBUTE BY a;
+EXPLAIN SELECT k, MAX(b) FROM (SELECT a AS k, b FROM t1) t GROUP BY k;
+CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 2), (0, 3) AS t2(a, b);
+EXPLAIN SELECT k, COUNT(v) FROM (SELECT a AS k, MAX(b) AS v FROM t2 GROUP BY a) t GROUP BY k;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql
index 38739cb950582..2c5f5e01e5984 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql
@@ -15,3 +15,10 @@ SELECT a, 'b' AS tag FROM t4;
 
 -- SPARK-19766 Constant alias columns in INNER JOIN should not be folded by FoldablePropagation rule
 SELECT tb.* FROM ta INNER JOIN tb ON ta.a = tb.a AND ta.tag = tb.tag;
+
+-- SPARK-19981 Correctly resolve partitioning when output has aliases
+SET spark.sql.autoBroadcastJoinThreshold = -1;
+CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES (1, 1), (3, 0) AS t5(k, v) DISTRIBUTE BY (k);
+CREATE TEMPORARY VIEW t6 AS SELECT * FROM VALUES (1, 1), (5, 1) AS t6(k, v) DISTRIBUTE BY (k);
+EXPLAIN SELECT * FROM (SELECT k AS k1 FROM t5) t5a
+INNER JOIN (SELECT k AS k1 FROM t6) t6a ON t5a.k1 = t6a.k1;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index 581aa1754ce14..37c809104ddbe 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 27
+-- Number of queries: 31
 
 
 -- !query 0
@@ -250,3 +250,47 @@ SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
 struct<corr(DISTINCT CAST(x AS DOUBLE), CAST(y AS DOUBLE)):double,corr(DISTINCT CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,count(1):bigint>
 -- !query 26 output
 1.0	1.0	3
+
+
+-- !query 27
+CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 'a'), (1, 'b') AS t1(a, b) DISTRIBUTE BY a
+-- !query 27 schema
+struct<>
+-- !query 27 output
+
+
+
+-- !query 28
+EXPLAIN SELECT k, MAX(b) FROM (SELECT a AS k, b FROM t1) t GROUP BY k
+-- !query 28 schema
+struct<plan:string>
+-- !query 28 output
+== Physical Plan ==
+SortAggregate(key=[k#x], functions=[max(b#x)])
++- SortAggregate(key=[k#x], functions=[partial_max(b#x)])
+   +- *Sort [k#x ASC NULLS FIRST], false, 0
+      +- *Project [a#x AS k#x, b#x]
+         +- Exchange hashpartitioning(a#x, 200)
+            +- LocalTableScan [a#x, b#x]
+
+
+-- !query 29
+CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 2), (0, 3) AS t2(a, b)
+-- !query 29 schema
+struct<>
+-- !query 29 output
+
+
+
+-- !query 30
+EXPLAIN SELECT k, COUNT(v) FROM (SELECT a AS k, MAX(b) AS v FROM t2 GROUP BY a) t GROUP BY k
+-- !query 30 schema
+struct<plan:string>
+-- !query 30 output
+== Physical Plan ==
+*HashAggregate(keys=[k#x], functions=[count(v#x)])
++- *HashAggregate(keys=[k#x], functions=[partial_count(v#x)])
+   +- *HashAggregate(keys=[a#x], functions=[max(b#x)])
+      +- Exchange hashpartitioning(a#x, 200)
+         +- *HashAggregate(keys=[a#x], functions=[partial_max(b#x)])
+            +- LocalTableScan [a#x, b#x]
diff --git a/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out b/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out
index 8d56ebe9fd3b4..38c3e20bdb4d1 100644
--- a/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 7
+-- Number of queries: 11
 
 
 -- !query 0
@@ -65,3 +65,45 @@ struct<a:int,tag:string>
 1	a
 1	b
 1	b
+
+
+-- !query 7
+SET spark.sql.autoBroadcastJoinThreshold = -1
+-- !query 7 schema
+struct<key:string,value:string>
+-- !query 7 output
+spark.sql.autoBroadcastJoinThreshold	-1
+
+
+-- !query 8
+CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES (1, 1), (3, 0) AS t5(k, v) DISTRIBUTE BY (k)
+-- !query 8 schema
+struct<>
+-- !query 8 output
+
+
+
+-- !query 9
+CREATE TEMPORARY VIEW t6 AS SELECT * FROM VALUES (1, 1), (5, 1) AS t6(k, v) DISTRIBUTE BY (k)
+-- !query 9 schema
+struct<>
+-- !query 9 output
+
+
+
+-- !query 10
+EXPLAIN SELECT * FROM (SELECT k AS k1 FROM t5) t5a
+INNER JOIN (SELECT k AS k1 FROM t6) t6a ON t5a.k1 = t6a.k1
+-- !query 10 schema
+struct<plan:string>
+-- !query 10 output
+== Physical Plan ==
+*SortMergeJoin [k1#x], [k1#x], Inner
+:- *Sort [k1#x ASC NULLS FIRST], false, 0
+:  +- *Project [k#x AS k1#x]
+:     +- Exchange hashpartitioning(k#x, 200)
+:        +- LocalTableScan [k#x]
++- *Sort [k1#x ASC NULLS FIRST], false, 0
+   +- *Project [k#x AS k1#x]
+      +- Exchange hashpartitioning(k#x, 200)
+         +- LocalTableScan [k#x]
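
For reviewers: a minimal spark-shell session to sanity-check the new trait end-to-end, adapted from the example in the AliasAwareOutputPartitioning comment above (the table names df1/df2 and the expected plan shape come from that comment; expression IDs will differ from run to run):

// Both inputs are pre-partitioned by `key`, and the subqueries merely rename
// `key` to `k`, so with this patch each Project should report
// hashpartitioning on `k` and EnsureRequirements should add no extra Exchange.
spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df1")
spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df2")
sql("""
  SELECT * FROM
    (SELECT key AS k FROM df1) t1
  INNER JOIN
    (SELECT key AS k FROM df2) t2
  ON t1.k = t2.k
""").explain
// Expected after the fix: SortMergeJoin reads each side through the existing
// Exchange hashpartitioning(key#..., 200); the extra
// Exchange hashpartitioning(k#..., 200) from the "before" plan is gone.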