From b05e6303a19926327525a9c2ffa399d68fca3911 Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Thu, 29 Jun 2017 10:43:58 -0500
Subject: [PATCH 1/4] fix via partitioning restriction

---
 .../plans/physical/partitioning.scala              | 18 +++++
 .../sql/catalyst/PartitioningSuite.scala           | 43 +++++++++++-
 .../spark/sql/execution/SparkPlan.scala            |  4 ++
 .../sql/execution/WholeStageCodegenExec.scala      |  4 +-
 .../aggregate/HashAggregateExec.scala              |  2 +-
 .../execution/basicPhysicalOperators.scala         |  2 +-
 .../org/apache/spark/sql/JoinSuite.scala           | 68 +++++++++++++++++++
 7 files changed, 135 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 51d78dd1233fe..44952442cce86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -171,6 +171,16 @@ sealed trait Partitioning {
    * produced by `A` could have also been produced by `B`.
    */
   def guarantees(other: Partitioning): Boolean = this == other
+
+  /**
+   * Returns the partitioning scheme that remains valid when restricted to the given set of
+   * output attributes. If this partitioning is an [[Expression]], every attribute it references
+   * must be in `outputSet`; otherwise those attributes would leak past the operator's output,
+   * so the result degrades to [[UnknownPartitioning]].
+   */
+  def restrict(outputSet: AttributeSet): Partitioning = this match {
+    case p: Expression if !p.references.subsetOf(outputSet) => UnknownPartitioning(numPartitions)
+    case _ => this
+  }
 }
 
 object Partitioning {
@@ -356,6 +366,14 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
   override def guarantees(other: Partitioning): Boolean =
     partitionings.exists(_.guarantees(other))
 
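+  // Keeps only the member partitionings that remain valid under the restriction; if none
+  // survive it degrades to UnknownPartitioning, and a single survivor is unwrapped.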
test("restriction of Partitioning works") { + val n = 5 + + val a1 = AttributeReference("a1", DataTypes.IntegerType)() + val a2 = AttributeReference("a2", DataTypes.IntegerType)() + val a3 = AttributeReference("a3", DataTypes.IntegerType)() + + val hashPartitioning = HashPartitioning(Seq(a1, a2), n) + + assert(hashPartitioning.restrict(AttributeSet(Seq())) === UnknownPartitioning(n)) + assert(hashPartitioning.restrict(AttributeSet(Seq(a1))) === UnknownPartitioning(n)) + assert(hashPartitioning.restrict(AttributeSet(Seq(a1, a2))) === hashPartitioning) + assert(hashPartitioning.restrict(AttributeSet(Seq(a1, a2, a3))) === hashPartitioning) + + val so1 = SortOrder(a1, Ascending) + val so2 = SortOrder(a2, Ascending) + + val rangePartitioning1 = RangePartitioning(Seq(so1), n) + val rangePartitioning2 = RangePartitioning(Seq(so1, so2), n) + + assert(rangePartitioning2.restrict(AttributeSet(Seq())) == UnknownPartitioning(n)) + assert(rangePartitioning2.restrict(AttributeSet(Seq(a1))) == UnknownPartitioning(n)) + assert(rangePartitioning2.restrict(AttributeSet(Seq(a1, a2))) === rangePartitioning2) + assert(rangePartitioning2.restrict(AttributeSet(Seq(a1, a2, a3))) === rangePartitioning2) + + assert(SinglePartition.restrict(AttributeSet(a1)) === SinglePartition) + + val all = Seq(hashPartitioning, rangePartitioning1, rangePartitioning2) + val partitioningCollection = PartitioningCollection(all) + + assert(partitioningCollection.restrict(AttributeSet(Seq())) == UnknownPartitioning(n)) + assert(partitioningCollection.restrict(AttributeSet(Seq(a1))) == rangePartitioning1) + assert(partitioningCollection.restrict(AttributeSet(Seq(a1, a2))) == partitioningCollection) + assert(partitioningCollection.restrict(AttributeSet(Seq(a1, a2, a3))) == partitioningCollection) + + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index c7277c21cebb2..3b285063bfa35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -65,6 +65,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ false } + override def verboseStringWithSuffix: String = { + s"$verboseString $outputPartitioning" + } + /** Overridden make copy also propagates sqlContext to copied plan. 
+  override def verboseStringWithSuffix: String = {
+    s"$verboseString $outputPartitioning"
+  }
+
   /** Overridden make copy also propagates sqlContext to copied plan. */
   override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
     SparkSession.setActiveSession(sqlContext.sparkSession)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 1007a7d55691b..1b3559056051d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -273,7 +273,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
       verbose: Boolean,
       prefix: String = "",
       addSuffix: Boolean = false): StringBuilder = {
-    child.generateTreeString(depth, lastChildren, builder, verbose, "")
+    child.generateTreeString(depth, lastChildren, builder, verbose, "", addSuffix)
   }
 }
@@ -448,7 +448,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
       verbose: Boolean,
       prefix: String = "",
       addSuffix: Boolean = false): StringBuilder = {
-    child.generateTreeString(depth, lastChildren, builder, verbose, "*")
+    child.generateTreeString(depth, lastChildren, builder, verbose, "*", addSuffix)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 56f61c30c4a38..da7ec2fd03c06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -64,7 +64,7 @@ case class HashAggregateExec(
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
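+  // The aggregate's output (resultExpressions) need not include the attributes referenced
+  // by the child's partitioning; restrict(outputSet) degrades such a partitioning to
+  // UnknownPartitioning rather than letting stale attributes leak downstream.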
+  override def outputPartitioning: Partitioning = child.outputPartitioning.restrict(outputSet)
 
   override def producedAttributes: AttributeSet =
     AttributeSet(aggregateAttributes) ++
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 2151c339b9b87..4c44d4606c2f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -80,7 +80,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputPartitioning: Partitioning = child.outputPartitioning.restrict(outputSet)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 895ca196a7a51..7ac1c7195a31e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -26,6 +26,10 @@ import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.exchange.Exchange
 
 class JoinSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -36,6 +40,70 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     df.queryExecution.optimizedPlan.stats.sizeInBytes
   }
 
+  test("SPARK-16683 Repeated joins to same table can leak attributes via partitioning") {
+    val hier = sqlContext.sparkSession.sparkContext.parallelize(Seq(
+      ("A10", "A1"),
+      ("A11", "A1"),
+      ("A20", "A2"),
+      ("A21", "A2"),
+      ("B10", "B1"),
+      ("B11", "B1"),
+      ("B20", "B2"),
+      ("B21", "B2"),
+      ("A1", "A"),
+      ("A2", "A"),
+      ("B1", "B"),
+      ("B2", "B")
+    )).toDF("son", "parent").cache() // passes if this cache() is removed, unless a count is put on dist1
+    hier.createOrReplaceTempView("hier")
+    hier.count() // if this count is removed, the test passes even without the fix
+
+    val base = sqlContext.sparkSession.sparkContext.parallelize(Seq(
+      Tuple1("A10"),
+      Tuple1("A11"),
+      Tuple1("A20"),
+      Tuple1("A21"),
+      Tuple1("B10"),
+      Tuple1("B11"),
+      Tuple1("B20"),
+      Tuple1("B21")
+    )).toDF("id")
+    base.createOrReplaceTempView("base")
+
+    val dist1 = spark.sql("""
+    SELECT parent level1
+    FROM base INNER JOIN hier h1 ON base.id = h1.son
+    GROUP BY parent""")
+
+    dist1.createOrReplaceTempView("dist1")
+    // dist1.count() // or put a count here (see the cache() comment above)
+
+    val dist2 = spark.sql("""
+    SELECT parent level2
+    FROM dist1 INNER JOIN hier h2 ON dist1.level1 = h2.son
+    GROUP BY parent""")
+
+    val plan = dist2.queryExecution.executedPlan
+    // For debugging, print the tree string with the partitioning suffix:
+    // println(plan.treeString(verbose = true, addSuffix = true))
+
+    dist2.createOrReplaceTempView("dist2")
+    checkAnswer(dist2, Row("A") :: Row("B") :: Nil)
+
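+    // The crux of SPARK-16683: the aggregates must not advertise a HashPartitioning over
+    // the joined `parent` attribute, which is no longer among their output attributes;
+    // after restrict(...) they report UnknownPartitioning, so the exchange below is kept.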
+    assert(plan.isInstanceOf[WholeStageCodegenExec])
+    assert(plan.outputPartitioning === UnknownPartitioning(5))
+
+    val agg = plan.children.head
+    assert(agg.isInstanceOf[HashAggregateExec])
+    assert(agg.outputPartitioning === UnknownPartitioning(5))
+
+    // Skip the InputAdapter node
+    val exchange = agg.children.head.children.head
+    assert(exchange.isInstanceOf[Exchange])
+    assert(exchange.outputPartitioning.isInstanceOf[HashPartitioning])
+  }
+
   test("equi-join is hash-join") {
     val x = testData2.as("x")
     val y = testData2.as("y")

From cd5aa8095026cfa03aa545a56c0a2690e6304936 Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Fri, 21 Jul 2017 00:23:23 -0500
Subject: [PATCH 2/4] scala style

---
 .../org/apache/spark/sql/catalyst/PartitioningSuite.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
index da709f8568bda..3112be5fe985d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
@@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.SparkFunSuite
-
-import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, AttributeSet, InterpretedMutableProjection, Literal, NullsFirst, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, AttributeSet, InterpretedMutableProjection, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.DataTypes

From f41811f620c6e2bcd23af52e302de1ce06e3230c Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Mon, 31 Jul 2017 11:11:18 -0500
Subject: [PATCH 3/4] indent code

---
 .../test/scala/org/apache/spark/sql/JoinSuite.scala | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 7ac1c7195a31e..a4b4157e9890a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -71,17 +71,17 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     base.createOrReplaceTempView("base")
 
     val dist1 = spark.sql("""
-    SELECT parent level1
-    FROM base INNER JOIN hier h1 ON base.id = h1.son
-    GROUP BY parent""")
+      SELECT parent level1
+      FROM base INNER JOIN hier h1 ON base.id = h1.son
+      GROUP BY parent""")
 
     dist1.createOrReplaceTempView("dist1")
     // dist1.count() // or put a count here (see the cache() comment above)
 
     val dist2 = spark.sql("""
-    SELECT parent level2
-    FROM dist1 INNER JOIN hier h2 ON dist1.level1 = h2.son
-    GROUP BY parent""")
+      SELECT parent level2
+      FROM dist1 INNER JOIN hier h2 ON dist1.level1 = h2.son
+      GROUP BY parent""")
 
     val plan = dist2.queryExecution.executedPlan
     // For debugging, print the tree string with the partitioning suffix:

From 0f21237b61a59bfcbf384866e06323a667154924 Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Thu, 31 Aug 2017 10:33:04 -0500
Subject: [PATCH 4/4] fix style

---
 .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index ec09c14683a22..8c7f66ff46446 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -24,13 +24,13 @@ import scala.language.existentials
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.StructType
 
 class JoinSuite extends QueryTest with SharedSQLContext {