From 2f647e13d98ee4738a413c8941d193d1dba5d4a4 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 24 Mar 2017 01:12:28 +0900 Subject: [PATCH 1/3] Update partitioning info. in ProjectExec when having aliases --- .../execution/aggregate/AggregateExec.scala | 75 +++++++++++++++++++ .../aggregate/HashAggregateExec.scala | 47 ++++++------ .../aggregate/SortAggregateExec.scala | 47 ++++++------ .../execution/basicPhysicalOperators.scala | 17 ++++- .../exchange/EnsureRequirements.scala | 58 +++++++++++++- .../resources/sql-tests/inputs/group-by.sql | 6 ++ .../resources/sql-tests/inputs/inner-join.sql | 7 ++ .../sql-tests/results/group-by.sql.out | 46 +++++++++++- .../sql-tests/results/inner-join.sql.out | 44 ++++++++++- 9 files changed, 293 insertions(+), 54 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala new file mode 100644 index 0000000000000..3868af0498ac9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode + +/** + * A base class for aggregate implementation. + */ +abstract class AggregateExec extends UnaryExecNode { + + def requiredChildDistributionExpressions: Option[Seq[Expression]] + def groupingExpressions: Seq[NamedExpression] + def aggregateExpressions: Seq[AggregateExpression] + def aggregateAttributes: Seq[Attribute] + def initialInputBufferOffset: Int + def resultExpressions: Seq[NamedExpression] + def child: SparkPlan + def partitioning: Option[Partitioning] + + protected[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) ++ + partitioning.flatMap { + case e: Expression => Some(e.references) + case _ => None + }.getOrElse(AttributeSet.empty) + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def outputPartitioning: Partitioning = partitioning.getOrElse(child.outputPartitioning) + + def copy( + requiredChildDistributionExpressions: Option[Seq[Expression]] = + requiredChildDistributionExpressions, + 
groupingExpressions: Seq[NamedExpression] = groupingExpressions, + aggregateExpressions: Seq[AggregateExpression] = aggregateExpressions, + aggregateAttributes: Seq[Attribute] = aggregateAttributes, + initialInputBufferOffset: Int = initialInputBufferOffset, + resultExpressions: Seq[NamedExpression] = resultExpressions, + child: SparkPlan = child, + partitioning: Option[Partitioning] = partitioning): AggregateExec +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 2cac0cfce28de..8b883abc99329 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -44,12 +44,9 @@ case class HashAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryExecNode with CodegenSupport { - - private[this] val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } + child: SparkPlan, + partitioning: Option[Partitioning] = None) + extends AggregateExec with CodegenSupport { require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) @@ -64,23 +61,6 @@ case class HashAggregateExec( "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"), "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def outputPartitioning: Partitioning = child.outputPartitioning - - override def producedAttributes: AttributeSet = - AttributeSet(aggregateAttributes) ++ - AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ - AttributeSet(aggregateBufferAttributes) - - override def 
requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash // map and/or the sort-based aggregation once it has processed a given number of input rows. private val testFallbackStartsAt: Option[(Int, Int)] = { @@ -914,6 +894,27 @@ case class HashAggregateExec( """ } + def copy( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan, + partitioning: Option[Partitioning] = partitioning): AggregateExec = { + new HashAggregateExec( + requiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + child, + partitioning + ) + } + override def verboseString: String = toString(verbose = true) override def simpleString: String = toString(verbose = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index fc87de2c52e41..81eee47db909f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -37,37 +37,17 @@ case class SortAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryExecNode { - - private[this] val aggregateBufferAttributes = { - 
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } - - override def producedAttributes: AttributeSet = - AttributeSet(aggregateAttributes) ++ - AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ - AttributeSet(aggregateBufferAttributes) + child: SparkPlan, + partitioning: Option[Partitioning] = None) + extends AggregateExec { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { groupingExpressions.map(SortOrder(_, Ascending)) :: Nil } - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = { groupingExpressions.map(SortOrder(_, Ascending)) } @@ -107,6 +87,27 @@ case class SortAggregateExec( } } + def copy( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan, + partitioning: Option[Partitioning] = partitioning): AggregateExec = { + new SortAggregateExec( + requiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + child, + partitioning + ) + } + override def simpleString: String = toString(verbose = false) override def verboseString: String = toString(verbose = true) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 9434ceb7cd16c..e7ed5af1788fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -28,13 +28,21 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} /** Physical plan for Project. */ -case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) +case class ProjectExec( + projectList: Seq[NamedExpression], + child: SparkPlan, + partitioning: Option[Partitioning] = None) extends UnaryExecNode with CodegenSupport { + override def producedAttributes: AttributeSet = partitioning.flatMap { + case e: Expression => Some(e.references) + case _ => None + }.getOrElse(AttributeSet.empty) + override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def inputRDDs(): Seq[RDD[InternalRow]] = { @@ -77,7 +85,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputPartitioning: Partitioning = partitioning.getOrElse(child.outputPartitioning) + + override def simpleString: String = + s"Project ${Utils.truncatedString(projectList, "[", ", ", "]")}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index d2d5011bbcb97..a8eaedc12e497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -24,8 +24,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, - SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.aggregate.AggregateExec import org.apache.spark.sql.internal.SQLConf /** @@ -293,6 +293,58 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } } + private def updatePartitioningByAliases(exprs: Seq[NamedExpression], partioning: Partitioning) + : Partitioning = { + val aliasSeq = exprs.flatMap(_.collectFirst { + case a @ Alias(child, _) => (child, a.toAttribute) + }) + + def maybeReplaceExpr(e: Expression): Expression = aliasSeq.find { + case (c, _) => c.semanticEquals(e) + }.map(_._2).getOrElse(e) + + def maybeUpdatePartitioningExprs(p: Partitioning): Partitioning = p match { + case hash @ HashPartitioning(exprs, _) => + hash.copy(expressions = exprs.map(maybeReplaceExpr)) + case range @ RangePartitioning(ordering, _) => + range.copy(ordering = ordering.map { order => + order.copy( + child = maybeReplaceExpr(order.child), + sameOrderExpressions = order.sameOrderExpressions.map(maybeReplaceExpr) + ) + }) + case _ => p + } + + partioning match { + case pc @ PartitioningCollection(ps) => + pc.copy(partitionings = ps.map(maybeUpdatePartitioningExprs)) + case _ => + maybeUpdatePartitioningExprs(partioning) + } + } + + private def hasAlias(exprs: 
Seq[NamedExpression]): Boolean = { + exprs.exists(_.collectFirst { case _: Alias => true }.isDefined) + } + + // If children have alias names, `outputPartitioning`s ignore the alias expressions and + // `ensureDistributionAndOrdering` inserts unnecessary shuffles. + // To solve this, we update `outputPartitioning`s by using aliases here. + private def maybeUpdateChildrenOutputPartitioning(operator: SparkPlan): SparkPlan = { + val newChildren = operator.children.map { + case proj @ ProjectExec(projectList, _, _) if hasAlias(projectList) => + val newOutputPartitionig = updatePartitioningByAliases(projectList, proj.outputPartitioning) + proj.copy(partitioning = Some(newOutputPartitionig)) + case agg: AggregateExec if hasAlias(agg.resultExpressions) => + val newOutputPartitionig = + updatePartitioningByAliases(agg.resultExpressions, agg.outputPartitioning) + agg.copy(partitioning = Some(newOutputPartitionig)) + case p => p + } + operator.withNewChildren(newChildren) + } + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { // TODO: remove this after we create a physical operator for `RepartitionByExpression`. 
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) => @@ -301,6 +353,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case _ => operator } case operator: SparkPlan => - ensureDistributionAndOrdering(reorderJoinPredicates(operator)) + ensureDistributionAndOrdering(maybeUpdateChildrenOutputPartitioning(operator)) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 2c18d6aaabdba..268b734d4b696 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -73,3 +73,9 @@ where b.z != b.z; -- SPARK-24369 multiple distinct aggregations having the same argument set SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); + +-- SPARK-19981 Correctly resolve partitioning when output has aliases +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 'a'), (1, 'b') AS (a, b) DISTRIBUTE BY a; +EXPLAIN SELECT k, MAX(b) FROM (SELECT a AS k, b FROM t1) t GROUP BY k; +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 2), (0, 3) AS (a, b); +EXPLAIN SELECT k, COUNT(v) FROM (SELECT a AS k, MAX(b) AS v FROM t2 GROUP BY a) t GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql index 38739cb950582..2c5f5e01e5984 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql @@ -15,3 +15,10 @@ SELECT a, 'b' AS tag FROM t4; -- SPARK-19766 Constant alias columns in INNER JOIN should not be folded by FoldablePropagation rule SELECT tb.* FROM ta INNER JOIN tb ON ta.a = tb.a AND ta.tag = tb.tag; + +-- SPARK-19981 Correctly resolve partitioning when output has aliases +SET spark.sql.autoBroadcastJoinThreshold = -1; +CREATE TEMPORARY VIEW t5 AS SELECT 
* FROM VALUES (1, 1), (3, 0) AS (k, v) DISTRIBUTE BY (k); +CREATE TEMPORARY VIEW t6 AS SELECT * FROM VALUES (1, 1), (5, 1) AS (k, v) DISTRIBUTE BY (k); +EXPLAIN SELECT * FROM (SELECT k AS k1 FROM t5) t5a +INNER JOIN (SELECT k AS k1 FROM t6) t6a ON t5a.k1 = t6a.k1; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 581aa1754ce14..37c809104ddbe 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 27 +-- Number of queries: 31 -- !query 0 @@ -250,3 +250,47 @@ SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) struct -- !query 26 output 1.0 1.0 3 + + +-- !query 27 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 'a'), (1, 'b') AS (a, b) DISTRIBUTE BY a +-- !query 27 schema +struct<> +-- !query 27 output + + + +-- !query 28 +EXPLAIN SELECT k, MAX(b) FROM (SELECT a AS k, b FROM t1) t GROUP BY k +-- !query 28 schema +struct +-- !query 28 output +== Physical Plan == +SortAggregate(key=[k#x], functions=[max(b#x)]) ++- SortAggregate(key=[k#x], functions=[partial_max(b#x)]) + +- *Sort [k#x ASC NULLS FIRST], false, 0 + +- *Project [a#x AS k#x, b#x] + +- Exchange hashpartitioning(a#x, 200) + +- LocalTableScan [a#x, b#x] + + +-- !query 29 +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 2), (0, 3) AS (a, b) +-- !query 29 schema +struct<> +-- !query 29 output + + + +-- !query 30 +EXPLAIN SELECT k, COUNT(v) FROM (SELECT a AS k, MAX(b) AS v FROM t2 GROUP BY a) t GROUP BY k +-- !query 30 schema +struct +-- !query 30 output +== Physical Plan == +*HashAggregate(keys=[k#x], functions=[count(v#x)]) ++- *HashAggregate(keys=[k#x], functions=[partial_count(v#x)]) + +- *HashAggregate(keys=[a#x], functions=[max(b#x)]) + +- Exchange hashpartitioning(a#x, 200) + +- 
*HashAggregate(keys=[a#x], functions=[partial_max(b#x)]) + +- LocalTableScan [a#x, b#x] diff --git a/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out b/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out index 8d56ebe9fd3b4..38c3e20bdb4d1 100644 --- a/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 11 -- !query 0 @@ -65,3 +65,45 @@ struct 1 a 1 b 1 b + + +-- !query 7 +SET spark.sql.autoBroadcastJoinThreshold = -1 +-- !query 7 schema +struct +-- !query 7 output +spark.sql.autoBroadcastJoinThreshold -1 + + +-- !query 8 +CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES (1, 1), (3, 0) AS (k, v) DISTRIBUTE BY (k) +-- !query 8 schema +struct<> +-- !query 8 output + + + +-- !query 9 +CREATE TEMPORARY VIEW t6 AS SELECT * FROM VALUES (1, 1), (5, 1) AS (k, v) DISTRIBUTE BY (k) +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +EXPLAIN SELECT * FROM (SELECT k AS k1 FROM t5) t5a +INNER JOIN (SELECT k AS k1 FROM t6) t6a ON t5a.k1 = t6a.k1 +-- !query 10 schema +struct +-- !query 10 output +== Physical Plan == +*SortMergeJoin [k1#x], [k1#x], Inner +:- *Sort [k1#x ASC NULLS FIRST], false, 0 +: +- *Project [k#x AS k1#x] +: +- Exchange hashpartitioning(k#x, 200) +: +- LocalTableScan [k#x] ++- *Sort [k1#x ASC NULLS FIRST], false, 0 + +- *Project [k#x AS k1#x] + +- Exchange hashpartitioning(k#x, 200) + +- LocalTableScan [k#x] From ceb38062ae2c8c25506e4a8148a353b968acd670 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 21 Aug 2018 13:38:38 +0900 Subject: [PATCH 2/3] Another solution --- .../execution/aggregate/AggregateExec.scala | 75 ---------- .../aggregate/HashAggregateExec.scala | 47 +++--- .../aggregate/SortAggregateExec.scala | 47 +++--- .../execution/basicPhysicalOperators.scala | 17 +-- 
.../exchange/EnsureRequirements.scala | 135 +++++++++++------- 5 files changed, 129 insertions(+), 192 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala deleted file mode 100644 index 3868af0498ac9..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.UnaryExecNode - -/** - * A base class for aggregate implementation. 
- */ -abstract class AggregateExec extends UnaryExecNode { - - def requiredChildDistributionExpressions: Option[Seq[Expression]] - def groupingExpressions: Seq[NamedExpression] - def aggregateExpressions: Seq[AggregateExpression] - def aggregateAttributes: Seq[Attribute] - def initialInputBufferOffset: Int - def resultExpressions: Seq[NamedExpression] - def child: SparkPlan - def partitioning: Option[Partitioning] - - protected[this] val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } - - override def producedAttributes: AttributeSet = - AttributeSet(aggregateAttributes) ++ - AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ - AttributeSet(aggregateBufferAttributes) ++ - partitioning.flatMap { - case e: Expression => Some(e.references) - case _ => None - }.getOrElse(AttributeSet.empty) - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - - override def outputPartitioning: Partitioning = partitioning.getOrElse(child.outputPartitioning) - - def copy( - requiredChildDistributionExpressions: Option[Seq[Expression]] = - requiredChildDistributionExpressions, - groupingExpressions: Seq[NamedExpression] = groupingExpressions, - aggregateExpressions: Seq[AggregateExpression] = aggregateExpressions, - aggregateAttributes: Seq[Attribute] = aggregateAttributes, - initialInputBufferOffset: Int = initialInputBufferOffset, - resultExpressions: Seq[NamedExpression] = resultExpressions, - child: SparkPlan = child, - partitioning: Option[Partitioning] = partitioning): AggregateExec -} diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8b883abc99329..2cac0cfce28de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -44,9 +44,12 @@ case class HashAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan, - partitioning: Option[Partitioning] = None) - extends AggregateExec with CodegenSupport { + child: SparkPlan) + extends UnaryExecNode with CodegenSupport { + + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) @@ -61,6 +64,23 @@ case class HashAggregateExec( "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"), "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + // This is for testing. 
We force TungstenAggregationIterator to fall back to the unsafe row hash // map and/or the sort-based aggregation once it has processed a given number of input rows. private val testFallbackStartsAt: Option[(Int, Int)] = { @@ -894,27 +914,6 @@ case class HashAggregateExec( """ } - def copy( - requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression], - aggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - child: SparkPlan, - partitioning: Option[Partitioning] = partitioning): AggregateExec = { - new HashAggregateExec( - requiredChildDistributionExpressions, - groupingExpressions, - aggregateExpressions, - aggregateAttributes, - initialInputBufferOffset, - resultExpressions, - child, - partitioning - ) - } - override def verboseString: String = toString(verbose = true) override def simpleString: String = toString(verbose = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 81eee47db909f..fc87de2c52e41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -37,17 +37,37 @@ case class SortAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan, - partitioning: Option[Partitioning] = None) - extends AggregateExec { + child: SparkPlan) + extends UnaryExecNode { + + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + 
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { groupingExpressions.map(SortOrder(_, Ascending)) :: Nil } + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = { groupingExpressions.map(SortOrder(_, Ascending)) } @@ -87,27 +107,6 @@ case class SortAggregateExec( } } - def copy( - requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression], - aggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - child: SparkPlan, - partitioning: Option[Partitioning] = partitioning): AggregateExec = { - new SortAggregateExec( - requiredChildDistributionExpressions, - groupingExpressions, - aggregateExpressions, - aggregateAttributes, - initialInputBufferOffset, - resultExpressions, - child, - partitioning - ) - } - override def simpleString: String = toString(verbose = false) override def verboseString: String = toString(verbose = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index e7ed5af1788fa..9434ceb7cd16c 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -28,21 +28,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ThreadUtils import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} /** Physical plan for Project. */ -case class ProjectExec( - projectList: Seq[NamedExpression], - child: SparkPlan, - partitioning: Option[Partitioning] = None) +case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode with CodegenSupport { - override def producedAttributes: AttributeSet = partitioning.flatMap { - case e: Expression => Some(e.references) - case _ => None - }.getOrElse(AttributeSet.empty) - override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def inputRDDs(): Seq[RDD[InternalRow]] = { @@ -85,10 +77,7 @@ case class ProjectExec( override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputPartitioning: Partitioning = partitioning.getOrElse(child.outputPartitioning) - - override def simpleString: String = - s"Project ${Utils.truncatedString(projectList, "[", ", ", "]")}" + override def outputPartitioning: Partitioning = child.outputPartitioning } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index a8eaedc12e497..6139d6973986b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -24,8 +24,8 @@ import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} -import org.apache.spark.sql.execution.aggregate.AggregateExec +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf /** @@ -138,6 +138,81 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { withCoordinator } + private def resolveOutputPartitioningByAliases( + exprs: Seq[NamedExpression], + partitioning: Partitioning): Partitioning = { + val aliasSeq = exprs.flatMap(_.collectFirst { + case a @ Alias(child, _) => (child, a.toAttribute) + }) + def mayReplaceExprWithAlias(e: Expression): Expression = { + aliasSeq.find { case (c, _) => c.semanticEquals(e) }.map(_._2).getOrElse(e) + } + def mayReplacePartitioningExprsWithAliases(p: Partitioning): Partitioning = p match { + case hash @ HashPartitioning(exprs, _) => + hash.copy(expressions = exprs.map(mayReplaceExprWithAlias)) + case range @ RangePartitioning(ordering, _) => + range.copy(ordering = ordering.map { order => + order.copy( + child = mayReplaceExprWithAlias(order.child), + sameOrderExpressions = order.sameOrderExpressions.map(mayReplaceExprWithAlias) + ) + }) + case _ => p + } + + partitioning match { + case pc @ PartitioningCollection(ps) => + pc.copy(partitionings = ps.map(mayReplacePartitioningExprsWithAliases)) + case _ => + mayReplacePartitioningExprsWithAliases(partitioning) + } + } + + // If projects and aggregates have aliases in output expressions, we should respect + // these aliases so as to check if the operators satisfy their output distribution requirements. 
+ // If we don't respect aliases, this rule wrongly adds shuffle operations, e.g., + // + // spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df1") + // spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df2") + // sql(""" + // SELECT * FROM + // (SELECT key AS k from df1) t1 + // INNER JOIN + // (SELECT key AS k from df2) t2 + // ON t1.k = t2.k + // """).explain + // + // == Physical Plan == + // *SortMergeJoin [k#56L], [k#57L], Inner + // :- *Sort [k#56L ASC NULLS FIRST], false, 0 + // : +- Exchange hashpartitioning(k#56L, 200) // <--- Unnecessary shuffle operation + // : +- *Project [key#39L AS k#56L] + // : +- Exchange hashpartitioning(key#39L, 200) + // : +- *Project [id#36L AS key#39L] + // : +- *Range (0, 10, step=1, splits=Some(4)) + // +- *Sort [k#57L ASC NULLS FIRST], false, 0 + // +- ReusedExchange [k#57L], Exchange hashpartitioning(k#56L, 200) + private def isSatisfiedByAliasedOutputPartitioning( + child: SparkPlan, + distribution: Distribution): Boolean = { + val outputExprs = child match { + case ProjectExec(projectList, _) => projectList + case HashAggregateExec(_, _, _, _, _, resultExprs, _) => resultExprs + case ObjectHashAggregateExec(_, _, _, _, _, resultExprs, _) => resultExprs + case SortAggregateExec(_, _, _, _, _, resultExprs, _) => resultExprs + case _ => Seq.empty + } + def hasAlias(exprs: Seq[NamedExpression]) = + exprs.exists(_.collectFirst { case _: Alias => true }.isDefined) + if (hasAlias(outputExprs)) { + val newOutputPartitioning = + resolveOutputPartitioningByAliases(outputExprs, child.outputPartitioning) + newOutputPartitioning.satisfies(distribution) + } else { + false + } + } + private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering @@ -149,6 +224,8 @@ case class 
EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { children = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child + case (child, distribution) if isSatisfiedByAliasedOutputPartitioning(child, distribution) => + child case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => @@ -293,58 +370,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } } - private def updatePartitioningByAliases(exprs: Seq[NamedExpression], partioning: Partitioning) - : Partitioning = { - val aliasSeq = exprs.flatMap(_.collectFirst { - case a @ Alias(child, _) => (child, a.toAttribute) - }) - - def maybeReplaceExpr(e: Expression): Expression = aliasSeq.find { - case (c, _) => c.semanticEquals(e) - }.map(_._2).getOrElse(e) - - def maybeUpdatePartitioningExprs(p: Partitioning): Partitioning = p match { - case hash @ HashPartitioning(exprs, _) => - hash.copy(expressions = exprs.map(maybeReplaceExpr)) - case range @ RangePartitioning(ordering, _) => - range.copy(ordering = ordering.map { order => - order.copy( - child = maybeReplaceExpr(order.child), - sameOrderExpressions = order.sameOrderExpressions.map(maybeReplaceExpr) - ) - }) - case _ => p - } - - partioning match { - case pc @ PartitioningCollection(ps) => - pc.copy(partitionings = ps.map(maybeUpdatePartitioningExprs)) - case _ => - maybeUpdatePartitioningExprs(partioning) - } - } - - private def hasAlias(exprs: Seq[NamedExpression]): Boolean = { - exprs.exists(_.collectFirst { case _: Alias => true }.isDefined) - } - - // If children have alias names, `outputPartitioning`s ignore the alias expressions and - // `ensureDistributionAndOrdering` inserts unnecessary shuffles. - // To solve this, we update `outputPartitioning`s by using aliases here. 
- private def maybeUpdateChildrenOutputPartitioning(operator: SparkPlan): SparkPlan = { - val newChildren = operator.children.map { - case proj @ ProjectExec(projectList, _, _) if hasAlias(projectList) => - val newOutputPartitionig = updatePartitioningByAliases(projectList, proj.outputPartitioning) - proj.copy(partitioning = Some(newOutputPartitionig)) - case agg: AggregateExec if hasAlias(agg.resultExpressions) => - val newOutputPartitionig = - updatePartitioningByAliases(agg.resultExpressions, agg.outputPartitioning) - agg.copy(partitioning = Some(newOutputPartitionig)) - case p => p - } - operator.withNewChildren(newChildren) - } - def apply(plan: SparkPlan): SparkPlan = plan.transformUp { // TODO: remove this after we create a physical operator for `RepartitionByExpression`. case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) => @@ -353,6 +378,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case _ => operator } case operator: SparkPlan => - ensureDistributionAndOrdering(maybeUpdateChildrenOutputPartitioning(operator)) + ensureDistributionAndOrdering(reorderJoinPredicates(operator)) } } From 5482b1be6308ddf7e77dc25c0bdfca3ede2d61a7 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 22 Aug 2018 16:44:55 +0900 Subject: [PATCH 3/3] Make AliasAwareOutputPartitioning trait --- .../AliasAwareOutputPartitioning.scala | 89 +++++++++++++++++++ .../aggregate/HashAggregateExec.scala | 6 +- .../aggregate/ObjectHashAggregateExec.scala | 6 +- .../aggregate/SortAggregateExec.scala | 8 +- .../execution/basicPhysicalOperators.scala | 6 +- .../exchange/EnsureRequirements.scala | 78 ---------------- 6 files changed, 102 insertions(+), 91 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala new file mode 100644 index 0000000000000..b7d501725705c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.plans.physical._ + +trait AliasAwareOutputPartitioning extends UnaryExecNode { + + protected def outputExpressions: Seq[NamedExpression] + + // If projects and aggregates have aliases in output expressions, we should respect + // these aliases so as to check if the operators satisfy their output distribution requirements. 
+ // If we don't respect aliases, `EnsureRequirements` wrongly adds shuffle operations, e.g., + // + // spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df1") + // spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df2") + // sql(""" + // SELECT * FROM + // (SELECT key AS k from df1) t1 + // INNER JOIN + // (SELECT key AS k from df2) t2 + // ON t1.k = t2.k + // """).explain + // + // == Physical Plan == + // *SortMergeJoin [k#56L], [k#57L], Inner + // :- *Sort [k#56L ASC NULLS FIRST], false, 0 + // : +- Exchange hashpartitioning(k#56L, 200) // <--- Unnecessary shuffle operation + // : +- *Project [key#39L AS k#56L] + // : +- Exchange hashpartitioning(key#39L, 200) + // : +- *Project [id#36L AS key#39L] + // : +- *Range (0, 10, step=1, splits=Some(4)) + // +- *Sort [k#57L ASC NULLS FIRST], false, 0 + // +- ReusedExchange [k#57L], Exchange hashpartitioning(k#56L, 200) + final override def outputPartitioning: Partitioning = if (hasAlias(outputExpressions)) { + resolveOutputPartitioningByAliases(outputExpressions, child.outputPartitioning) + } else { + child.outputPartitioning + } + + private def hasAlias(exprs: Seq[NamedExpression]): Boolean = + exprs.exists(_.collectFirst { case _: Alias => true }.isDefined) + + private def resolveOutputPartitioningByAliases( + exprs: Seq[NamedExpression], + partitioning: Partitioning): Partitioning = { + val aliasSeq = exprs.flatMap(_.collectFirst { + case a @ Alias(child, _) => (child, a.toAttribute) + }) + def mayReplaceExprWithAlias(e: Expression): Expression = { + aliasSeq.find { case (c, _) => c.semanticEquals(e) }.map(_._2).getOrElse(e) + } + def mayReplacePartitioningExprsWithAliases(p: Partitioning): Partitioning = p match { + case hash @ HashPartitioning(exprs, _) => + hash.copy(expressions = exprs.map(mayReplaceExprWithAlias)) + case range @ RangePartitioning(ordering, _) => + range.copy(ordering = ordering.map { order => + order.copy( + child = 
mayReplaceExprWithAlias(order.child), + sameOrderExpressions = order.sameOrderExpressions.map(mayReplaceExprWithAlias) + ) + }) + case _ => p + } + + partitioning match { + case pc @ PartitioningCollection(ps) => + pc.copy(partitionings = ps.map(mayReplacePartitioningExprsWithAliases)) + case _ => + mayReplacePartitioningExprsWithAliases(partitioning) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 2cac0cfce28de..885f28ae6d65c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -45,7 +45,9 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with CodegenSupport { + extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning { + + override protected def outputExpressions: Seq[NamedExpression] = resultExpressions private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) @@ -66,8 +68,6 @@ case class HashAggregateExec( override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - override def outputPartitioning: Partitioning = child.outputPartitioning - override def producedAttributes: AttributeSet = AttributeSet(aggregateAttributes) ++ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 66955b8ef723c..aefb8e428ebda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -65,7 +65,9 @@ case class ObjectHashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with AliasAwareOutputPartitioning { + + override protected def outputExpressions: Seq[NamedExpression] = resultExpressions private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) @@ -95,8 +97,6 @@ case class ObjectHashAggregateExec( } } - override def outputPartitioning: Partitioning = child.outputPartitioning - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numOutputRows = longMetric("numOutputRows") val aggTime = longMetric("aggTime") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index fc87de2c52e41..3d3425d413ed7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.Utils @@ -38,7 +38,9 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with AliasAwareOutputPartitioning { + + override protected def 
outputExpressions: Seq[NamedExpression] = resultExpressions private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) @@ -66,8 +68,6 @@ case class SortAggregateExec( groupingExpressions.map(SortOrder(_, Ascending)) :: Nil } - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = { groupingExpressions.map(SortOrder(_, Ascending)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 9434ceb7cd16c..0f9a137717754 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -33,10 +33,12 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} /** Physical plan for Project. */ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with CodegenSupport { + extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning { override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override protected def outputExpressions: Seq[NamedExpression] = projectList + override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() } @@ -76,8 +78,6 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputPartitioning: Partitioning = child.outputPartitioning } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 6139d6973986b..fb7edc18c1046 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf @@ -138,81 +137,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { withCoordinator } - private def resolveOutputPartitioningByAliases( - exprs: Seq[NamedExpression], - partitioning: Partitioning): Partitioning = { - val aliasSeq = exprs.flatMap(_.collectFirst { - case a @ Alias(child, _) => (child, a.toAttribute) - }) - def mayReplaceExprWithAlias(e: Expression): Expression = { - aliasSeq.find { case (c, _) => c.semanticEquals(e) }.map(_._2).getOrElse(e) - } - def mayReplacePartitioningExprsWithAliases(p: Partitioning): Partitioning = p match { - case hash @ HashPartitioning(exprs, _) => - hash.copy(expressions = exprs.map(mayReplaceExprWithAlias)) - case range @ RangePartitioning(ordering, _) => - range.copy(ordering = ordering.map { order => - order.copy( - child = mayReplaceExprWithAlias(order.child), - sameOrderExpressions = order.sameOrderExpressions.map(mayReplaceExprWithAlias) - ) - }) - case _ => p - } - - partitioning match { - case pc @ PartitioningCollection(ps) => - pc.copy(partitionings = ps.map(mayReplacePartitioningExprsWithAliases)) - case _ => - mayReplacePartitioningExprsWithAliases(partitioning) - } - } - - // If projects and aggregates have aliases in output expressions, we should respect - // these aliases so as to check if the operators satisfy their output distribution requirements. 
- // If we don't respect aliases, this rule wrongly adds shuffle operations, e.g., - // - // spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df1") - // spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df2") - // sql(""" - // SELECT * FROM - // (SELECT key AS k from df1) t1 - // INNER JOIN - // (SELECT key AS k from df2) t2 - // ON t1.k = t2.k - // """).explain - // - // == Physical Plan == - // *SortMergeJoin [k#56L], [k#57L], Inner - // :- *Sort [k#56L ASC NULLS FIRST], false, 0 - // : +- Exchange hashpartitioning(k#56L, 200) // <--- Unnecessary shuffle operation - // : +- *Project [key#39L AS k#56L] - // : +- Exchange hashpartitioning(key#39L, 200) - // : +- *Project [id#36L AS key#39L] - // : +- *Range (0, 10, step=1, splits=Some(4)) - // +- *Sort [k#57L ASC NULLS FIRST], false, 0 - // +- ReusedExchange [k#57L], Exchange hashpartitioning(k#56L, 200) - private def isSatisfiedByAliasedOutputPartitioning( - child: SparkPlan, - distribution: Distribution): Boolean = { - val outputExprs = child match { - case ProjectExec(projectList, _) => projectList - case HashAggregateExec(_, _, _, _, _, resultExprs, _) => resultExprs - case ObjectHashAggregateExec(_, _, _, _, _, resultExprs, _) => resultExprs - case SortAggregateExec(_, _, _, _, _, resultExprs, _) => resultExprs - case _ => Seq.empty - } - def hasAlias(exprs: Seq[NamedExpression]) = - exprs.exists(_.collectFirst { case _: Alias => true }.isDefined) - if (hasAlias(outputExprs)) { - val newOutputPartitioning = - resolveOutputPartitioningByAliases(outputExprs, child.outputPartitioning) - newOutputPartitioning.satisfies(distribution) - } else { - false - } - } - private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering @@ -224,8 +148,6 @@ case class 
EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { children = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child - case (child, distribution) if isSatisfiedByAliasedOutputPartitioning(child, distribution) => - child case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) =>