diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 51d78dd1233fe..44952442cce86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -171,6 +171,16 @@ sealed trait Partitioning {
    * produced by `A` could have also been produced by `B`.
    */
   def guarantees(other: Partitioning): Boolean = this == other
+
+  /**
+   * Returns the partitioning scheme that remains valid when restricted to the given set of
+   * output attributes. If this partitioning is an [[Expression]] that references attributes
+   * outside of `outputSet`, it degrades to [[UnknownPartitioning]] so that no attribute leaks.
+   */
+  def restrict(outputSet: AttributeSet): Partitioning = this match {
+    case p: Expression if !p.references.subsetOf(outputSet) => UnknownPartitioning(numPartitions)
+    case _ => this
+  }
 }
 
 object Partitioning {
@@ -356,6 +366,14 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
   override def guarantees(other: Partitioning): Boolean =
     partitionings.exists(_.guarantees(other))
 
+  override def restrict(outputSet: AttributeSet): Partitioning = {
+    partitionings.map(_.restrict(outputSet)).filterNot(_.isInstanceOf[UnknownPartitioning]) match {
+      case Seq() => UnknownPartitioning(numPartitions)
+      case Seq(singlePartitioning) => singlePartitioning
+      case more => PartitioningCollection(more)
+    }
+  }
+
   override def toString: String = {
     partitionings.map(_.toString).mkString("(", " or ", ")")
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
index 5b802ccc637dd..3112be5fe985d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
@@ -18,8 +18,9 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, AttributeSet, InterpretedMutableProjection, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.types.DataTypes
 
 class PartitioningSuite extends SparkFunSuite {
   test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") {
@@ -52,4 +53,41 @@ class PartitioningSuite extends SparkFunSuite {
     assert(partitioningA.guarantees(partitioningA))
     assert(partitioningA.compatibleWith(partitioningA))
   }
+
+  test("restricting a Partitioning to a set of output attributes") {
+    val n = 5
+
+    val a1 = AttributeReference("a1", DataTypes.IntegerType)()
+    val a2 = AttributeReference("a2", DataTypes.IntegerType)()
+    val a3 = AttributeReference("a3", DataTypes.IntegerType)()
+
+    val hashPartitioning = HashPartitioning(Seq(a1, a2), n)
+
+    assert(hashPartitioning.restrict(AttributeSet(Seq())) === UnknownPartitioning(n))
+    assert(hashPartitioning.restrict(AttributeSet(Seq(a1))) === UnknownPartitioning(n))
+    assert(hashPartitioning.restrict(AttributeSet(Seq(a1, a2))) === hashPartitioning)
+    assert(hashPartitioning.restrict(AttributeSet(Seq(a1, a2, a3))) === hashPartitioning)
+
+    val so1 = SortOrder(a1, Ascending)
+    val so2 = SortOrder(a2, Ascending)
+
+    val rangePartitioning1 = RangePartitioning(Seq(so1), n)
+    val rangePartitioning2 = RangePartitioning(Seq(so1, so2), n)
+
+    assert(rangePartitioning2.restrict(AttributeSet(Seq())) == UnknownPartitioning(n))
+    assert(rangePartitioning2.restrict(AttributeSet(Seq(a1))) == UnknownPartitioning(n))
+    assert(rangePartitioning2.restrict(AttributeSet(Seq(a1, a2))) === rangePartitioning2)
+    assert(rangePartitioning2.restrict(AttributeSet(Seq(a1, a2, a3))) === rangePartitioning2)
+
+    assert(SinglePartition.restrict(AttributeSet(a1)) === SinglePartition)
+
+    val all = Seq(hashPartitioning, rangePartitioning1, rangePartitioning2)
+    val partitioningCollection = PartitioningCollection(all)
+
+    assert(partitioningCollection.restrict(AttributeSet(Seq())) == UnknownPartitioning(n))
+    assert(partitioningCollection.restrict(AttributeSet(Seq(a1))) == rangePartitioning1)
+    assert(partitioningCollection.restrict(AttributeSet(Seq(a1, a2))) == partitioningCollection)
+    assert(partitioningCollection.restrict(AttributeSet(Seq(a1, a2, a3))) == partitioningCollection)
+
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index c7277c21cebb2..3b285063bfa35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -65,6 +65,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     false
   }
 
+  override def verboseStringWithSuffix: String = {
+    s"$verboseString $outputPartitioning"
+  }
+
   /** Overridden make copy also propagates sqlContext to copied plan. */
   override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
     SparkSession.setActiveSession(sqlContext.sparkSession)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index bacb7090a70ab..6858501ca4d9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -273,7 +273,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
       verbose: Boolean,
       prefix: String = "",
       addSuffix: Boolean = false): StringBuilder = {
-    child.generateTreeString(depth, lastChildren, builder, verbose, "")
+    child.generateTreeString(depth, lastChildren, builder, verbose, "", addSuffix)
   }
 }
 
@@ -456,7 +456,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
       verbose: Boolean,
       prefix: String = "",
       addSuffix: Boolean = false): StringBuilder = {
-    child.generateTreeString(depth, lastChildren, builder, verbose, "*")
+    child.generateTreeString(depth, lastChildren, builder, verbose, "*", addSuffix)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index d77405c559c58..d6386d0ea1e4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -64,7 +64,7 @@ case class HashAggregateExec(
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputPartitioning: Partitioning = child.outputPartitioning.restrict(outputSet)
 
   override def producedAttributes: AttributeSet =
     AttributeSet(aggregateAttributes) ++
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 2151c339b9b87..4c44d4606c2f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -80,7 +80,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputPartitioning: Partitioning = child.outputPartitioning.restrict(outputSet)
 }
 
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 453052a8ce191..8c7f66ff46446 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -24,6 +24,10 @@ import scala.language.existentials
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -38,6 +42,70 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     df.queryExecution.optimizedPlan.stats.sizeInBytes
   }
 
+  test("SPARK-16683 Repeated joins to the same table can leak attributes via partitioning") {
+    val hier = spark.sparkContext.parallelize(Seq(
+      ("A10", "A1"),
+      ("A11", "A1"),
+      ("A20", "A2"),
+      ("A21", "A2"),
+      ("B10", "B1"),
+      ("B11", "B1"),
+      ("B20", "B2"),
+      ("B21", "B2"),
+      ("A1", "A"),
+      ("A2", "A"),
+      ("B1", "B"),
+      ("B2", "B")
+    )).toDF("son", "parent").cache() // passes without this cache(), unless dist1 is counted below
+    hier.createOrReplaceTempView("hier")
+    hier.count() // passes if this count() is removed
+
+    val base = spark.sparkContext.parallelize(Seq(
+      Tuple1("A10"),
+      Tuple1("A11"),
+      Tuple1("A20"),
+      Tuple1("A21"),
+      Tuple1("B10"),
+      Tuple1("B11"),
+      Tuple1("B20"),
+      Tuple1("B21")
+    )).toDF("id")
+    base.createOrReplaceTempView("base")
+
+    val dist1 = spark.sql("""
+      SELECT parent level1
+      FROM base INNER JOIN hier h1 ON base.id = h1.son
+      GROUP BY parent""")
+
+    dist1.createOrReplaceTempView("dist1")
+    // dist1.count() // uncomment to reproduce the failure even without the cache() above
+
+    val dist2 = spark.sql("""
+      SELECT parent level2
+      FROM dist1 INNER JOIN hier h2 ON dist1.level1 = h2.son
+      GROUP BY parent""")
+
+    val plan = dist2.queryExecution.executedPlan
+    // For debugging, print the tree string with the partitioning suffixes:
+    // println(plan.treeString(verbose = true, addSuffix = true))
+
+    dist2.createOrReplaceTempView("dist2")
+    checkAnswer(dist2, Row("A") :: Row("B") :: Nil)
+
+    assert(plan.isInstanceOf[WholeStageCodegenExec])
+    assert(plan.outputPartitioning === UnknownPartitioning(5))
+
+    val agg = plan.children.head
+
+    assert(agg.isInstanceOf[HashAggregateExec])
+    assert(agg.outputPartitioning === UnknownPartitioning(5))
+
+    // Skip the InputAdapter between the aggregate and the exchange
+    val exchange = agg.children.head.children.head
+    assert(exchange.isInstanceOf[Exchange])
+    assert(exchange.outputPartitioning.isInstanceOf[HashPartitioning])
+  }
+
   test("equi-join is hash-join") {
     val x = testData2.as("x")
     val y = testData2.as("y")
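
Reviewer note (not part of the patch): a minimal sketch of the intended Partitioning.restrict semantics, assuming the patch above is applied; the attribute names and the partition count are illustrative only, mirroring the PartitioningSuite test added here.

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
    import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
    import org.apache.spark.sql.types.DataTypes

    // Two illustrative attributes and a partitioning hashed on both of them.
    val a = AttributeReference("a", DataTypes.IntegerType)()
    val b = AttributeReference("b", DataTypes.IntegerType)()
    val hashed = HashPartitioning(Seq(a, b), 5)

    // Both referenced attributes survive the projection, so the partitioning is kept as-is.
    assert(hashed.restrict(AttributeSet(Seq(a, b))) == hashed)

    // `b` is projected away, so keeping HashPartitioning(a, b) would leak it; restrict
    // degrades the scheme to UnknownPartitioning with the same number of partitions.
    assert(hashed.restrict(AttributeSet(Seq(a))) == UnknownPartitioning(5))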