@@ -171,6 +171,16 @@ sealed trait Partitioning {
* produced by `A` could have also been produced by `B`.
*/
def guarantees(other: Partitioning): Boolean = this == other

/**
 * Returns the partitioning scheme that remains valid when restricted to the given set of
 * output attributes. If the partitioning is an [[Expression]], every attribute it depends on
 * must be in the outputSet; otherwise the attribute would leak out of the operator's output.
*/
def restrict(outputSet: AttributeSet): Partitioning = this match {
@gatorsmile (Member) commented on Aug 31, 2017:

We are refactoring the concept of distribution and partitioning in PR #19080.

Could you provide your inputs in that PR first? Thanks!

case p: Expression if !p.references.subsetOf(outputSet) => UnknownPartitioning(numPartitions)
case _ => this
}
}

object Partitioning {
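
Editor's sketch (not part of the patch) of the contract above: restricting a hash partitioning to an output set that is missing one of its referenced attributes must degrade it to UnknownPartitioning, while a superset leaves it untouched. The attribute names and numPartitions = 5 are arbitrary.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.types.IntegerType

val a1 = AttributeReference("a1", IntegerType)()
val a2 = AttributeReference("a2", IntegerType)()
val byBoth = HashPartitioning(Seq(a1, a2), 5)

// a2 is missing from the output set, so reporting byBoth would leak it:
assert(byBoth.restrict(AttributeSet(a1 :: Nil)) == UnknownPartitioning(5))
// All referenced attributes are present, so the partitioning survives:
assert(byBoth.restrict(AttributeSet(a1 :: a2 :: Nil)) == byBoth)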
@@ -356,6 +366,14 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
override def guarantees(other: Partitioning): Boolean =
partitionings.exists(_.guarantees(other))

override def restrict(outputSet: AttributeSet): Partitioning = {
partitionings.map(_.restrict(outputSet)).filter(!_.isInstanceOf[UnknownPartitioning]) match {
case Nil => UnknownPartitioning(numPartitions)
case singlePartitioning :: Nil => singlePartitioning
case more => PartitioningCollection(more)
}
}

override def toString: String = {
partitionings.map(_.toString).mkString("(", " or ", ")")
}
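
Editor's sketch of the collection case, reusing a1, a2, and byBoth from the sketch above: members that survive restriction are kept, a single survivor is returned unwrapped, and an empty result collapses to UnknownPartitioning.

import org.apache.spark.sql.catalyst.plans.physical.PartitioningCollection

val byA1 = HashPartitioning(Seq(a1), 5)
val coll = PartitioningCollection(Seq(byA1, byBoth))

// Only byA1 survives once a2 is dropped, so it is returned unwrapped:
assert(coll.restrict(AttributeSet(a1 :: Nil)) == byA1)
// No member survives an empty output set:
assert(coll.restrict(AttributeSet(Nil)) == UnknownPartitioning(5))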
@@ -18,8 +18,9 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, AttributeSet, InterpretedMutableProjection, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.types.DataTypes

class PartitioningSuite extends SparkFunSuite {
test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") {
@@ -52,4 +53,41 @@ class PartitioningSuite extends SparkFunSuite {
assert(partitioningA.guarantees(partitioningA))
assert(partitioningA.compatibleWith(partitioningA))
}

test("restriction of Partitioning works") {
val n = 5

val a1 = AttributeReference("a1", DataTypes.IntegerType)()
val a2 = AttributeReference("a2", DataTypes.IntegerType)()
val a3 = AttributeReference("a3", DataTypes.IntegerType)()

val hashPartitioning = HashPartitioning(Seq(a1, a2), n)

assert(hashPartitioning.restrict(AttributeSet(Seq())) === UnknownPartitioning(n))
assert(hashPartitioning.restrict(AttributeSet(Seq(a1))) === UnknownPartitioning(n))
assert(hashPartitioning.restrict(AttributeSet(Seq(a1, a2))) === hashPartitioning)
assert(hashPartitioning.restrict(AttributeSet(Seq(a1, a2, a3))) === hashPartitioning)

val so1 = SortOrder(a1, Ascending)
val so2 = SortOrder(a2, Ascending)

val rangePartitioning1 = RangePartitioning(Seq(so1), n)
val rangePartitioning2 = RangePartitioning(Seq(so1, so2), n)

assert(rangePartitioning2.restrict(AttributeSet(Seq())) === UnknownPartitioning(n))
assert(rangePartitioning2.restrict(AttributeSet(Seq(a1))) === UnknownPartitioning(n))
assert(rangePartitioning2.restrict(AttributeSet(Seq(a1, a2))) === rangePartitioning2)
assert(rangePartitioning2.restrict(AttributeSet(Seq(a1, a2, a3))) === rangePartitioning2)

assert(SinglePartition.restrict(AttributeSet(a1)) === SinglePartition)

val all = Seq(hashPartitioning, rangePartitioning1, rangePartitioning2)
val partitioningCollection = PartitioningCollection(all)

assert(partitioningCollection.restrict(AttributeSet(Seq())) === UnknownPartitioning(n))
assert(partitioningCollection.restrict(AttributeSet(Seq(a1))) === rangePartitioning1)
assert(partitioningCollection.restrict(AttributeSet(Seq(a1, a2))) === partitioningCollection)
assert(partitioningCollection.restrict(AttributeSet(Seq(a1, a2, a3))) === partitioningCollection)
}
}
@@ -65,6 +65,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable
false
}

override def verboseStringWithSuffix: String = {
s"$verboseString $outputPartitioning"
}
A reviewer (Member) commented:

Except for debugging, do we really need to always print out the output partitioning?

The contributor (author) replied:

This doesn't change anything in common use; one has to call plan.treeString(verbose = true, addSuffix = true) to see it. I would argue for keeping it for future debugging.
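
Editor's usage sketch of the opt-in behavior the author describes, for a spark-shell session (the query is arbitrary):

import spark.implicits._

val df = spark.range(100).groupBy($"id" % 10).count()
val plan = df.queryExecution.executedPlan

// Default verbose output is unchanged by this patch:
println(plan.treeString(verbose = true))
// Opting in appends each node's outputPartitioning as a suffix:
println(plan.treeString(verbose = true, addSuffix = true))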


/** Overridden make copy also propagates sqlContext to copied plan. */
override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
SparkSession.setActiveSession(sqlContext.sparkSession)
@@ -273,7 +273,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport
verbose: Boolean,
prefix: String = "",
addSuffix: Boolean = false): StringBuilder = {
-child.generateTreeString(depth, lastChildren, builder, verbose, "")
+child.generateTreeString(depth, lastChildren, builder, verbose, "", addSuffix)
}
}

@@ -456,7 +456,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport
verbose: Boolean,
prefix: String = "",
addSuffix: Boolean = false): StringBuilder = {
-child.generateTreeString(depth, lastChildren, builder, verbose, "*")
+child.generateTreeString(depth, lastChildren, builder, verbose, "*", addSuffix)
}
}

@@ -64,7 +64,7 @@ case class HashAggregateExec(

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

-override def outputPartitioning: Partitioning = child.outputPartitioning
+override def outputPartitioning: Partitioning = child.outputPartitioning.restrict(outputSet)

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
@@ -80,7 +80,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)

override def outputOrdering: Seq[SortOrder] = child.outputOrdering

-override def outputPartitioning: Partitioning = child.outputPartitioning
+override def outputPartitioning: Partitioning = child.outputPartitioning.restrict(outputSet)
}
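
Editor's repro sketch for the projection case, for a spark-shell session (column names are arbitrary): a project that drops an attribute referenced by the child's partitioning must stop advertising that partitioning.

import org.apache.spark.sql.execution.ProjectExec
import spark.implicits._

val df = spark.range(100)
  .selectExpr("id", "id % 10 AS bucket")
  .repartition($"bucket") // hash-partitioned on 'bucket'
  .select($"id")          // 'bucket' is projected away

val project = df.queryExecution.executedPlan.collectFirst {
  case p: ProjectExec => p
}.get
// With this patch the project reports UnknownPartitioning instead of
// the child's HashPartitioning('bucket'), which would leak 'bucket':
println(project.outputPartitioning)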


68 changes: 68 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -24,6 +24,10 @@ import scala.language.existentials
import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -38,6 +42,70 @@ class JoinSuite extends QueryTest with SharedSQLContext {
df.queryExecution.optimizedPlan.stats.sizeInBytes
}

test("SPARK-16683 Repeated joins to same table can leak attributes via partitioning") {
val hier = sqlContext.sparkSession.sparkContext.parallelize(Seq(
("A10", "A1"),
("A11", "A1"),
("A20", "A2"),
("A21", "A2"),
("B10", "B1"),
("B11", "B1"),
("B20", "B2"),
("B21", "B2"),
("A1", "A"),
("A2", "A"),
("B1", "B"),
("B2", "B")
)).toDF("son", "parent").cache() // without this cache the bug only reproduces with a count on dist1 (below)
hier.createOrReplaceTempView("hier")
hier.count() // the bug does not reproduce if this count is removed

val base = sqlContext.sparkSession.sparkContext.parallelize(Seq(
Tuple1("A10"),
Tuple1("A11"),
Tuple1("A20"),
Tuple1("A21"),
Tuple1("B10"),
Tuple1("B11"),
Tuple1("B20"),
Tuple1("B21")
)).toDF("id")
base.createOrReplaceTempView("base")

val dist1 = spark.sql("""
SELECT parent level1
FROM base INNER JOIN hier h1 ON base.id = h1.son
GROUP BY parent""")

dist1.createOrReplaceTempView("dist1")
// dist1.count() // alternatively, a count here also reproduces the bug

val dist2 = spark.sql("""
SELECT parent level2
FROM dist1 INNER JOIN hier h2 ON dist1.level1 = h2.son
GROUP BY parent""")

val plan = dist2.queryExecution.executedPlan
// For debugging, print the tree string with the partitioning suffix:
// println(plan.treeString(verbose = true, addSuffix = true))

dist2.createOrReplaceTempView("dist2")
checkAnswer(dist2, Row("A") :: Row("B") :: Nil)

assert(plan.isInstanceOf[WholeStageCodegenExec])
assert(plan.outputPartitioning === UnknownPartitioning(5))

val agg = plan.children.head

assert(agg.isInstanceOf[HashAggregateExec])
assert(agg.outputPartitioning === UnknownPartitioning(5))

// Skip the InputAdapter node
val exchange = agg.children.head.children.head
assert(exchange.isInstanceOf[Exchange])
assert(exchange.outputPartitioning.isInstanceOf[HashPartitioning])
}

test("equi-join is hash-join") {
val x = testData2.as("x")
val y = testData2.as("y")