diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 01e40e64a3e8..5e14a0854d21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2616,8 +2616,8 @@ object EliminateUnions extends Rule[LogicalPlan] {
  * rule can't work for those parameters.
  */
 object CleanupAliases extends Rule[LogicalPlan] {
-  private def trimAliases(e: Expression): Expression = {
-    e.transformDown {
+  private[spark] def trimAliases(e: Expression): Expression = {
+    e.transformUp {
       case Alias(child, _) => child
       case MultiAlias(child, _) => child
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 17e1cb416fc8..b740f3d6e835 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.physical
 
+import org.apache.spark.sql.catalyst.analysis.CleanupAliases
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
@@ -184,6 +185,11 @@ trait Partitioning {
     case AllTuples => numPartitions == 1
     case _ => false
   }
+
+  /**
+   * Returns a version of this [[Partitioning]] adjusted for the given invalid [[Attribute]].
+   */
+  private[spark] def pruneInvalidAttribute(invalidAttr: Attribute): Partitioning = this
 }
 
 case class UnknownPartitioning(numPartitions: Int) extends Partitioning
@@ -235,6 +241,21 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
    * than numPartitions) based on hashing expressions.
    */
   def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
+
+  /**
+   * If the HashPartitioning contains an attribute which is not present in the output expressions,
+   * the returned partitioning is `UnknownPartitioning`, instead of a `HashPartitioning` of the
+   * remaining attributes, which would be wrong.
+   * E.g. `HashPartitioning('a, 'b)` with output expressions `'a as 'a1` should produce
+   * `UnknownPartitioning` instead of `HashPartitioning('a1)`.
+   */
+  override private[spark] def pruneInvalidAttribute(invalidAttr: Attribute): Partitioning = {
+    if (this.references.contains(invalidAttr)) {
+      UnknownPartitioning(numPartitions)
+    } else {
+      this
+    }
+  }
 }
 
 /**
@@ -284,6 +305,19 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
       }
     }
   }
+
+  /**
+   * Returns `UnknownPartitioning` if the first ordering expression is not valid anymore;
+   * otherwise it performs no modification, because pruning the invalid expressions may cause
+   * errors when comparing with `ClusteredDistribution`s.
+   */
+  override private[spark] def pruneInvalidAttribute(invalidAttr: Attribute): Partitioning = {
+    if (ordering.headOption.forall(_.references.contains(invalidAttr))) {
+      UnknownPartitioning(numPartitions)
+    } else {
+      this
+    }
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala
new file mode 100644
index 000000000000..505d64bded30
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.analysis.CleanupAliases
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
+
+/**
+ * Trait for plans which can produce an output partitioned by aliased attributes of their child.
+ * It rewrites the partitioning attributes of the child with the corresponding new ones which are
+ * exposed in the output of this plan. It avoids redundant shuffles in queries caused by the
+ * rename of an attribute among the partitioning ones, e.g.
+ *
+ * spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
+ * spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
+ * sql("set spark.sql.autoBroadcastJoinThreshold=-1")
+ * sql("""
+ *   SELECT * FROM
+ *     (SELECT key AS k from df1) t1
+ *   INNER JOIN
+ *     (SELECT key AS k from df2) t2
+ *   ON t1.k = t2.k
+ *   """).explain
+ *
+ * == Physical Plan ==
+ * *SortMergeJoin [k#21L], [k#22L], Inner
+ * :- *Sort [k#21L ASC NULLS FIRST], false, 0
+ * :  +- Exchange hashpartitioning(k#21L, 200) // <--- Unnecessary shuffle operation
+ * :     +- *Project [key#2L AS k#21L]
+ * :        +- Exchange hashpartitioning(key#2L, 200)
+ * :           +- *Project [id#0L AS key#2L]
+ * :              +- *Range (0, 10, step=1, splits=Some(4))
+ * +- *(4) Sort [k#22L ASC NULLS FIRST], false, 0
+ *    +- *(4) Project [key#8L AS k#22L]
+ *       +- ReusedExchange [key#8L], Exchange hashpartitioning(key#2L, 200)
+ */
+trait AliasAwareOutputPartitioning extends UnaryExecNode {
+
+  /**
+   * `Seq` of `Expression`s which define the output of the node.
+   */
+  protected def outputExpressions: Seq[NamedExpression]
+
+  /**
+   * Returns the valid `Partitioning`s for the node w.r.t. its output and its expressions.
+   */
+  final override def outputPartitioning: Partitioning = {
+    child.outputPartitioning match {
+      case partitioning: Expression =>
+        // Creates a sequence of tuples where the first element is an `Attribute` referenced in the
+        // partitioning expression of the child and the second is a sequence of all its aliased
+        // occurrences in the node output. If there is no occurrence of an attribute in the output,
+        // the second element of the tuple for it will be an empty `Seq`. If the attribute,
+        // instead, is only present as is in the output, there will be no entry for it.
+        // E.g. if the partitioning is RangePartitioning('a) and the node output is "'a, 'a as a1,
+        // 'a as a2", then exprToEquiv will contain the tuple ('a, Seq('a, 'a as a1, 'a as a2)).
+        val exprToEquiv = partitioning.references.map { attr =>
+          attr -> outputExpressions.filter(e =>
+            CleanupAliases.trimAliases(e).semanticEquals(attr))
+        }.filterNot { case (attr, exprs) =>
+          exprs.size == 1 && exprs.forall(_ == attr)
+        }
+        val initValue = partitioning match {
+          case PartitioningCollection(partitionings) => partitionings
+          case other => Seq(other)
+        }
+        // Replace all the aliased expressions detected earlier with all their corresponding
+        // occurrences. This may produce many valid partitioning expressions from a single one.
+        // E.g. in the example above, this would produce a `Seq` of 3 `RangePartitioning`s, namely:
+        // `RangePartitioning('a)`, `RangePartitioning('a1)`, `RangePartitioning('a2)`.
+        val validPartitionings = exprToEquiv.foldLeft(initValue) {
+          case (partitionings, (toReplace, equivalents)) =>
+            if (equivalents.isEmpty) {
+              // Remove from the partitioning expression the attribute which is not present in the
+              // node output
+              partitionings.map(_.pruneInvalidAttribute(toReplace))
+            } else {
+              partitionings.flatMap {
+                case p: Expression if p.references.contains(toReplace) =>
+                  equivalents.map { equiv =>
+                    p.transformDown {
+                      case e if e == toReplace => equiv.toAttribute
+                    }.asInstanceOf[Partitioning]
+                  }
+                case other => Seq(other)
+              }
+            }
+        }.distinct
+        if (validPartitionings.size == 1) {
+          validPartitionings.head
+        } else {
+          validPartitionings.filterNot(_.isInstanceOf[UnknownPartitioning]) match {
+            case Seq() => PartitioningCollection(validPartitionings)
+            case Seq(knownPartitioning) => knownPartitioning
+            case knownPartitionings => PartitioningCollection(knownPartitionings)
+          }
+
+        }
+      case other => other
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 25ff6584360e..390569f89c94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -50,7 +50,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with BlockingOperatorWithCodegen {
+  extends UnaryExecNode with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -72,7 +72,7 @@ case class HashAggregateExec(
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputExpressions: Seq[NamedExpression] = resultExpressions
 
   override def producedAttributes: AttributeSet =
     AttributeSet(aggregateAttributes) ++
@@ -90,11 +90,15 @@ case class HashAggregateExec(
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {
-    sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
-      case null | "" => None
-      case fallbackStartsAt =>
-        val splits = fallbackStartsAt.split(",").map(_.trim)
-        Some((splits.head.toInt, splits.last.toInt))
+    if (Utils.isTesting && sqlContext == null) {
+      None
+    } else {
+      sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
+        case null | "" => None
+        case fallbackStartsAt =>
+          val splits = fallbackStartsAt.split(",").map(_.trim)
+          Some((splits.head.toInt, splits.last.toInt))
+      }
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index 151da241144b..1977b7f30c0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -67,7 +67,7 @@ case class ObjectHashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -97,7 +97,7 @@ case class ObjectHashAggregateExec(
     }
   }
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputExpressions: Seq[NamedExpression] = resultExpressions
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 7ab6ecc08a7b..dbd698345d8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
@@ -38,7 +38,7 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -66,7 +66,7 @@ case class SortAggregateExec(
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputExpressions: Seq[NamedExpression] = resultExpressions
 
   override def outputOrdering: Seq[SortOrder] = {
     groupingExpressions.map(SortOrder(_, Ascending))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 7204548181f6..b022523461b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -36,7 +36,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
 /** Physical plan for Project. */
 case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {
 
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
 
@@ -79,7 +79,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputExpressions: Seq[NamedExpression] = projectList
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index dd7c38011bc9..89d04494c0b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -25,7 +25,9 @@ import org.scalatest.exceptions.TestFailedException
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.sql.catalyst.ScroogeLikeExample
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
@@ -1595,6 +1597,38 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Seq(Row("Amsterdam")))
   }
 
+  test("SPARK-25951: avoid redundant shuffle on rename") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val N = 10
+      val t1 = spark.range(N).selectExpr("floor(id/4) as k1")
+      val t2 = spark.range(N).selectExpr("floor(id/4) as k2")
+
+      val agg1 = t1.groupBy("k1").agg(count(lit("1")).as("cnt1"))
+      val agg2 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumnRenamed("k2", "k3")
+      val finalPlan = agg1.join(agg2, $"k1" === $"k3")
+      val exchanges = finalPlan.queryExecution.executedPlan.collect {
+        case se: ShuffleExchangeExec => se
+      }
+      assert(exchanges.size == 2)
+      assert(!exchanges.exists(_.newPartitioning match {
+        case HashPartitioning(Seq(a: AttributeReference), _) => a.name == "k3"
+        case _ => false
+      }))
+
+      // In this case the requirement is not satisfied
+      val agg3 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumn("k3", $"k2" + 1)
+      val finalPlan2 = agg1.join(agg3, $"k1" === $"k3")
+      val exchanges2 = finalPlan2.queryExecution.executedPlan.collect {
+        case se: ShuffleExchangeExec => se
+      }
+      assert(exchanges2.size == 3)
+      assert(exchanges2.exists(_.newPartitioning match {
+        case HashPartitioning(Seq(a: AttributeReference), _) => a.name == "k3"
+        case _ => false
+      }))
+    }
+  }
+
   test("SPARK-24762: Enable top-level Option of Product encoders") {
     val data = Seq(Some((1, "a")), Some((2, "b")), None)
     val ds = data.toDS()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PartitioningSuite.scala
new file mode 100644
index 000000000000..945fce16f863
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PartitioningSuite.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Expression, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.types.IntegerType
+
+class PartitioningSuite extends SparkFunSuite {
+
+  private val attr1 = AttributeReference("attr1", IntegerType)()
+  private val attr2 = AttributeReference("attr2", IntegerType)()
+  private val aliasedAttr1 = Alias(attr1, "alias_attr1")()
+  private val aliasedAttr2 = Alias(attr2, "alias_attr2")()
+  private val aliasedAttr1Twice = Alias(Alias(attr1, "alias_attr1")(), "alias_attr1_2")()
+
+  private val planHashPartitioned1Attr = PartitionedSparkPlan(
+    output = Seq(attr1), outputPartitioning = HashPartitioning(Seq(attr1), 10))
+  private val planHashPartitioned2Attr = PartitionedSparkPlan(
+    output = Seq(attr1, attr2), outputPartitioning = HashPartitioning(Seq(attr1, attr2), 10))
+  private val planRangePartitioned1Attr = PartitionedSparkPlan(
+    output = Seq(attr1), outputPartitioning = simpleRangePartitioning(Seq(attr1), 10))
+  private val planRangePartitioned2Attr = PartitionedSparkPlan(
+    output = Seq(attr1, attr2),
+    outputPartitioning = simpleRangePartitioning(Seq(attr1, attr2), 10))
+
+  def testPartitioning(
+      outputExpressions: Seq[NamedExpression],
+      inputPlan: SparkPlan,
+      expectedPartitioning: Partitioning): Unit = {
+    testProjectPartitioning(outputExpressions, inputPlan, expectedPartitioning)
+    testAggregatePartitioning(outputExpressions, inputPlan, expectedPartitioning)
+  }
+
+  def testProjectPartitioning(
+      projectList: Seq[NamedExpression],
+      inputPlan: SparkPlan,
+      expectedPartitioning: Partitioning): Unit = {
+    assert(ProjectExec(projectList, inputPlan).outputPartitioning == expectedPartitioning)
+  }
+
+  def testAggregatePartitioning(
+      groupingExprs: Seq[NamedExpression],
+      inputPlan: SparkPlan,
+      expectedPartitioning: Partitioning): Unit = {
+    val hashAgg = HashAggregateExec(requiredChildDistributionExpressions = None,
+      groupingExpressions = groupingExprs,
+      aggregateExpressions = Seq.empty,
+      aggregateAttributes = Seq.empty,
+      initialInputBufferOffset = 0,
+      resultExpressions = groupingExprs,
+      child = inputPlan)
+    val sortAgg = SortAggregateExec(requiredChildDistributionExpressions = None,
+      groupingExpressions = groupingExprs,
+      aggregateExpressions = Seq.empty,
+      aggregateAttributes = Seq.empty,
+      initialInputBufferOffset = 0,
+      resultExpressions = groupingExprs,
+      child = inputPlan)
+    val objAgg = ObjectHashAggregateExec(requiredChildDistributionExpressions = None,
+      groupingExpressions = groupingExprs,
+      aggregateExpressions = Seq.empty,
+      aggregateAttributes = Seq.empty,
+      initialInputBufferOffset = 0,
+      resultExpressions = groupingExprs,
+      child = inputPlan)
+    assert(hashAgg.outputPartitioning == expectedPartitioning)
+    assert(sortAgg.outputPartitioning == expectedPartitioning)
+    assert(objAgg.outputPartitioning == expectedPartitioning)
+  }
+
+  def simpleRangePartitioning(exprs: Seq[Expression], numPartitions: Int): RangePartitioning = {
+    RangePartitioning(exprs.map(e => SortOrder(e, Ascending)), numPartitions)
+  }
+
+  test("HashPartitioning with simple attribute rename") {
+    testPartitioning(
+      Seq(aliasedAttr1),
+      planHashPartitioned1Attr,
+      HashPartitioning(Seq(aliasedAttr1.toAttribute), 10))
+    testPartitioning(
+      Seq(aliasedAttr1Twice),
+      planHashPartitioned1Attr,
+      HashPartitioning(Seq(aliasedAttr1Twice.toAttribute), 10))
+
+    testPartitioning(
+      Seq(aliasedAttr1, attr2),
+      planHashPartitioned2Attr,
+      HashPartitioning(Seq(aliasedAttr1.toAttribute, attr2), 10))
+    testPartitioning(
+      Seq(aliasedAttr1Twice, attr2),
+      planHashPartitioned2Attr,
+      HashPartitioning(Seq(aliasedAttr1Twice.toAttribute, attr2), 10))
+
+    testPartitioning(
+      Seq(aliasedAttr1, aliasedAttr2),
+      planHashPartitioned2Attr,
+      HashPartitioning(Seq(aliasedAttr1.toAttribute, aliasedAttr2.toAttribute), 10))
+    testPartitioning(
+      Seq(aliasedAttr1Twice, aliasedAttr2),
+      planHashPartitioned2Attr,
+      HashPartitioning(Seq(aliasedAttr1Twice.toAttribute, aliasedAttr2.toAttribute), 10))
+  }
+
+  test("HashPartitioning with double attribute rename") {
+    testPartitioning(
+      Seq(aliasedAttr1, aliasedAttr1Twice),
+      planHashPartitioned1Attr,
+      PartitioningCollection(Seq(
+        HashPartitioning(Seq(aliasedAttr1.toAttribute), 10),
+        HashPartitioning(Seq(aliasedAttr1Twice.toAttribute), 10))))
+    testPartitioning(
+      Seq(aliasedAttr1, aliasedAttr1Twice, attr2),
+      planHashPartitioned2Attr,
+      PartitioningCollection(Seq(
+        HashPartitioning(Seq(aliasedAttr1.toAttribute, attr2), 10),
+        HashPartitioning(Seq(aliasedAttr1Twice.toAttribute, attr2), 10))))
+    testPartitioning(
+      Seq(aliasedAttr1, aliasedAttr1Twice, attr2, aliasedAttr2),
+      planHashPartitioned2Attr,
+      PartitioningCollection(Seq(
+        HashPartitioning(Seq(aliasedAttr1.toAttribute, attr2), 10),
+        HashPartitioning(Seq(aliasedAttr1.toAttribute, aliasedAttr2.toAttribute), 10),
+        HashPartitioning(Seq(aliasedAttr1Twice.toAttribute, attr2), 10),
+        HashPartitioning(Seq(aliasedAttr1Twice.toAttribute, aliasedAttr2.toAttribute), 10))))
+  }
+
+  test("HashPartitioning without attribute in output") {
+    testPartitioning(
+      Seq(attr2),
+      planHashPartitioned1Attr,
+      UnknownPartitioning(10))
+    testPartitioning(
+      Seq(attr1),
+      planHashPartitioned2Attr,
+      UnknownPartitioning(10))
+  }
+
+  test("HashPartitioning without renaming") {
+    testPartitioning(
+      Seq(attr1),
+      planHashPartitioned1Attr,
+      HashPartitioning(Seq(attr1), 10))
+    testPartitioning(
+      Seq(attr1, attr2),
+      planHashPartitioned2Attr,
+      HashPartitioning(Seq(attr1, attr2), 10))
+  }
+
+  test("RangePartitioning with simple attribute rename") {
+    testPartitioning(
+      Seq(aliasedAttr1),
+      planRangePartitioned1Attr,
+      simpleRangePartitioning(Seq(aliasedAttr1.toAttribute), 10))
+    testPartitioning(
+      Seq(aliasedAttr1Twice),
+      planRangePartitioned1Attr,
+      simpleRangePartitioning(Seq(aliasedAttr1Twice.toAttribute), 10))
+
+    testPartitioning(
+      Seq(aliasedAttr1, attr2),
+      planRangePartitioned2Attr,
+      simpleRangePartitioning(Seq(aliasedAttr1.toAttribute, attr2), 10))
+    testPartitioning(
+      Seq(aliasedAttr1Twice, attr2),
+      planRangePartitioned2Attr,
+      simpleRangePartitioning(Seq(aliasedAttr1Twice.toAttribute, attr2), 10))
+
+    testPartitioning(
+      Seq(aliasedAttr1, aliasedAttr2),
+      planRangePartitioned2Attr,
+      simpleRangePartitioning(Seq(aliasedAttr1.toAttribute, aliasedAttr2.toAttribute), 10))
+    testPartitioning(
+      Seq(aliasedAttr1Twice, aliasedAttr2),
+      planRangePartitioned2Attr,
+      simpleRangePartitioning(Seq(aliasedAttr1Twice.toAttribute, aliasedAttr2.toAttribute), 10))
+  }
+
+  test("RangePartitioning with double attribute rename") {
+    testPartitioning(
+      Seq(aliasedAttr1, aliasedAttr1Twice),
+      planRangePartitioned1Attr,
+      PartitioningCollection(Seq(
+        simpleRangePartitioning(Seq(aliasedAttr1.toAttribute), 10),
+        simpleRangePartitioning(Seq(aliasedAttr1Twice.toAttribute), 10))))
+    testPartitioning(
+      Seq(aliasedAttr1, aliasedAttr1Twice, attr2),
+      planRangePartitioned2Attr,
+      PartitioningCollection(Seq(
+        simpleRangePartitioning(Seq(aliasedAttr1.toAttribute, attr2), 10),
+        simpleRangePartitioning(Seq(aliasedAttr1Twice.toAttribute, attr2), 10))))
+    testPartitioning(
+      Seq(aliasedAttr1, aliasedAttr1Twice, attr2, aliasedAttr2),
+      planRangePartitioned2Attr,
+      PartitioningCollection(Seq(
+        simpleRangePartitioning(Seq(aliasedAttr1.toAttribute, attr2), 10),
+        simpleRangePartitioning(Seq(aliasedAttr1.toAttribute, aliasedAttr2.toAttribute), 10),
+        simpleRangePartitioning(Seq(aliasedAttr1Twice.toAttribute, attr2), 10),
+        simpleRangePartitioning(Seq(aliasedAttr1Twice.toAttribute, aliasedAttr2.toAttribute), 10))))
+  }
+
+  test("RangePartitioning without attribute in output") {
+    testPartitioning(
+      Seq(attr2),
+      planRangePartitioned2Attr,
+      UnknownPartitioning(10))
+    testPartitioning(
+      Seq(attr1),
+      planRangePartitioned2Attr,
+      simpleRangePartitioning(Seq(attr1, attr2), 10))
+  }
+
+  test("RangePartitioning without renaming") {
+    testPartitioning(
+      Seq(attr1),
+      planRangePartitioned1Attr,
+      simpleRangePartitioning(Seq(attr1), 10))
+    testPartitioning(
+      Seq(attr1, attr2),
+      planRangePartitioned2Attr,
+      simpleRangePartitioning(Seq(attr1, attr2), 10))
+  }
+}
+
+private case class PartitionedSparkPlan(
+    override val output: Seq[Attribute] = Seq.empty,
+    override val outputPartitioning: Partitioning = UnknownPartitioning(0),
+    override val children: Seq[SparkPlan] = Nil) extends SparkPlan {
+  override protected def doExecute() = throw new UnsupportedOperationException
+}
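
A minimal sketch of the behaviour this patch targets, not part of the patch itself: it mirrors the scenario from the new DatasetSuite test in a spark-shell-style session, and the value names (t1, t2, agg1, agg2) are illustrative only.

// Hedged example based on the SPARK-25951 test above; assumes a running SparkSession `spark`.
import org.apache.spark.sql.functions.{count, lit}
import spark.implicits._

// Disable broadcast joins so the join below has to rely on hash partitioning.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
val t2 = spark.range(10).selectExpr("floor(id/4) as k2")

// Each aggregation shuffles once on its grouping key.
val agg1 = t1.groupBy("k1").agg(count(lit("1")).as("cnt1"))
// Renaming the grouping key after the aggregation previously hid the fact that
// the data was already hash-partitioned by it.
val agg2 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumnRenamed("k2", "k3")

// With the alias-aware output partitioning introduced here, the physical plan of this
// join should contain only the two aggregation exchanges; the rename to "k3" should no
// longer introduce a third shuffle.
agg1.join(agg2, $"k1" === $"k3").explain()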