From d102da370babba06cfa1a349a98bbe56dda3d056 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 18 Jun 2018 16:55:47 -0700
Subject: [PATCH 1/4] StreamingSymmetricHashJoinExec should require
 HashClusteredDistribution from children

---
 .../plans/physical/partitioning.scala         | 69 +++++++++----------
 .../exchange/EnsureRequirements.scala         |  2 +-
 .../StreamingSymmetricHashJoinExec.scala      |  4 +-
 .../sql/streaming/StreamingJoinSuite.scala    | 14 ++++
 4 files changed, 49 insertions(+), 40 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 4d9a9925fe3ff..49396d47494a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -68,50 +68,42 @@ case object AllTuples extends Distribution {
   }
 }
 
-/**
- * Represents data where tuples that share the same values for the `clustering`
- * [[Expression Expressions]] will be co-located. Based on the context, this
- * can mean such tuples are either co-located in the same partition or they will be contiguous
- * within a single partition.
- */
-case class ClusteredDistribution(
-    clustering: Seq[Expression],
-    requiredNumPartitions: Option[Int] = None) extends Distribution {
+abstract class ClusteredDistributionBase(exprs: Seq[Expression]) extends Distribution {
   require(
-    clustering != Nil,
-    "The clustering expressions of a ClusteredDistribution should not be Nil. " +
+    exprs.nonEmpty,
+    s"The clustering expressions of a ${getClass.getSimpleName} should not be empty. " +
       "An AllTuples should be used to represent a distribution that only has " +
       "a single partition.")
 
   override def createPartitioning(numPartitions: Int): Partitioning = {
     assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
-      s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
+      s"This ${getClass.getSimpleName} requires ${requiredNumPartitions.get} partitions, but " +
         s"the actual number of partitions is $numPartitions.")
-    HashPartitioning(clustering, numPartitions)
+    HashPartitioning(exprs, numPartitions)
   }
 }
 
 /**
- * Represents data where tuples have been clustered according to the hash of the given
- * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
+ * Represents data where tuples that share the same values for the `clustering`
+ * [[Expression Expressions]] will be co-located. Based on the context, this
+ * can mean such tuples are either co-located in the same partition or they will be contiguous
+ * within a single partition.
+ */
+case class ClusteredDistribution(
+    clustering: Seq[Expression],
+    requiredNumPartitions: Option[Int] = None) extends ClusteredDistributionBase(clustering)
+
+/**
+ * Represents data where tuples have been clustered according to the hash of the given expressions.
+ * The hash function is defined as [[HashPartitioning.partitionIdExpression]], so only
  * [[HashPartitioning]] can satisfy this distribution.
  *
  * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
  * number of partitions, this distribution strictly requires which partition the tuple should be in.
*/ -case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { - require( - expressions != Nil, - "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + - "An AllTuples should be used to represent a distribution that only has " + - "a single partition.") - - override def requiredNumPartitions: Option[Int] = None - - override def createPartitioning(numPartitions: Int): Partitioning = { - HashPartitioning(expressions, numPartitions) - } -} +case class HashClusteredDistribution( + hashExprs: Seq[Expression], + requiredNumPartitions: Option[Int] = None) extends ClusteredDistributionBase(hashExprs) /** * Represents data where tuples have been ordered according to the `ordering` @@ -207,15 +199,18 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies(required: Distribution): Boolean = { super.satisfies(required) || { - required match { - case h: HashClusteredDistribution => - expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { - case (l, r) => l.semanticEquals(r) - } - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) - case _ => false + val satisfyNumPartitions = required.requiredNumPartitions.isEmpty || + required.requiredNumPartitions.get == numPartitions + satisfyNumPartitions && { + required match { + case h: HashClusteredDistribution => + expressions.length == h.hashExprs.length && expressions.zip(h.hashExprs).forall { + case (l, r) => l.semanticEquals(r) + } + case c: ClusteredDistribution => + expressions.forall(x => c.clustering.exists(_.semanticEquals(x))) + case _ => false + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index ad95879d86f42..f84cb2b91b292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -73,7 +73,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // these children may not be partitioned in the same way. // Please see the comment in withCoordinator for more details. 
val supportsDistribution = requiredChildDistributions.forall { dist => - dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution] + dist.isInstanceOf[ClusteredDistributionBase] } children.length > 1 && supportsDistribution } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index afa664eb76525..50cf971e4ec3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,8 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil + HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 1f62357e6d09e..c5cc8df4356a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -404,6 +404,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input3, 5, 10), CheckNewAnswer((5, 10, 5, 15, 5, 25))) } + + test("streaming join should require HashClusteredDistribution from children") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b) + val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b) + val joined = df1.join(df2, Seq("a", "b")).select('a) + + testStream(joined)( + AddData(input1, 1.to(1000): _*), + AddData(input2, 1.to(1000): _*), + CheckAnswer(1.to(1000): _*)) + } } From 6fc7913a77c050e40f879d6d8611c855e488ed11 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Jun 2018 10:45:33 -0700 Subject: [PATCH 2/4] address comment --- .../plans/physical/partitioning.scala | 48 ++++++++----------- .../v2/DataSourcePartitioning.scala | 24 +++++----- 2 files changed, 31 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 49396d47494a4..2c9bf598dfb6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -162,8 +162,10 @@ trait Partitioning { def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case AllTuples => numPartitions == 1 - case _ => false + case _ => required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required) } + + protected def satisfies0(required: Distribution): Boolean = false } case class 
UnknownPartitioning(numPartitions: Int) extends Partitioning @@ -178,9 +180,8 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } } @@ -197,21 +198,15 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { - val satisfyNumPartitions = required.requiredNumPartitions.isEmpty || - required.requiredNumPartitions.get == numPartitions - satisfyNumPartitions && { - required match { - case h: HashClusteredDistribution => - expressions.length == h.hashExprs.length && expressions.zip(h.hashExprs).forall { - case (l, r) => l.semanticEquals(r) - } - case c: ClusteredDistribution => - expressions.forall(x => c.clustering.exists(_.semanticEquals(x))) - case _ => false + override def satisfies0(required: Distribution): Boolean = { + required match { + case h: HashClusteredDistribution => + expressions.length == h.hashExprs.length && expressions.zip(h.hashExprs).forall { + case (l, r) => l.semanticEquals(r) } - } + case c: ClusteredDistribution => + expressions.forall(x => c.clustering.exists(_.semanticEquals(x))) + case _ => false } } @@ -241,17 +236,14 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { - required match { - case OrderedDistribution(requiredOrdering) => - val minSize = Seq(requiredOrdering.size, ordering.size).min - requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) - case _ => false - } + override def satisfies0(required: Distribution): Boolean = { + required match { + case OrderedDistribution(requiredOrdering) => + val minSize = Seq(requiredOrdering.size, ordering.size).min + requiredOrdering.take(minSize) == ordering.take(minSize) + case ClusteredDistribution(requiredClustering, _) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) + case _ => false } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 017a6737161a6..efedae9ff1fb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -30,20 +30,18 @@ class DataSourcePartitioning( override val numPartitions: Int = partitioning.numPartitions() - override def satisfies(required: physical.Distribution): Boolean = { - super.satisfies(required) || { - required match { - case d: physical.ClusteredDistribution if 
isCandidate(d.clustering) =>
-          val attrs = d.clustering.map(_.asInstanceOf[Attribute])
-          partitioning.satisfy(
-            new ClusteredDistribution(attrs.map { a =>
-              val name = colNames.get(a)
-              assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output")
-              name.get
-            }.toArray))
+  override def satisfies0(required: physical.Distribution): Boolean = {
+    required match {
+      case d: physical.ClusteredDistribution if isCandidate(d.clustering) =>
+        val attrs = d.clustering.map(_.asInstanceOf[Attribute])
+        partitioning.satisfy(
+          new ClusteredDistribution(attrs.map { a =>
+            val name = colNames.get(a)
+            assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output")
+            name.get
+          }.toArray))
 
-        case _ => false
-      }
+      case _ => false
     }
   }

From 08da2e6717ce4693fcaf33d021513f478f33d2a4 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 19 Jun 2018 22:27:14 -0700
Subject: [PATCH 3/4] address comment

---
 .../plans/physical/partitioning.scala         |  59 +++--
 .../sql/catalyst/DistributionSuite.scala      | 201 +++++++++++++++---
 .../v2/DataSourcePartitioning.scala           |  22 +-
 3 files changed, 215 insertions(+), 67 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 2c9bf598dfb6e..65f995a36e6ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -155,17 +155,26 @@ trait Partitioning {
    * i.e. the current dataset does not need to be re-partitioned for the `required`
    * Distribution (it is possible that tuples within a partition need to be reorganized).
    *
+   * A [[Partitioning]] can never satisfy a [[Distribution]] if its `numPartitions` doesn't match
+   * [[Distribution.requiredNumPartitions]].
+   */
+  final def satisfies(required: Distribution): Boolean = {
+    required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required)
+  }
+
+  /**
+   * The actual method that defines whether this [[Partitioning]] can satisfy the given
+   * [[Distribution]], after the `numPartitions` check.
+   *
    * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if
-   * the [[Partitioning]] only have one partition. Implementations can overwrite this method with
-   * special logic.
+   * the [[Partitioning]] only has one partition. Implementations can also override this method
+   * with special logic.
*/ - def satisfies(required: Distribution): Boolean = required match { + protected def satisfies0(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case AllTuples => numPartitions == 1 - case _ => required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required) + case _ => false } - - protected def satisfies0(required: Distribution): Boolean = false } case class UnknownPartitioning(numPartitions: Int) extends Partitioning @@ -199,14 +208,16 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def dataType: DataType = IntegerType override def satisfies0(required: Distribution): Boolean = { - required match { - case h: HashClusteredDistribution => - expressions.length == h.hashExprs.length && expressions.zip(h.hashExprs).forall { - case (l, r) => l.semanticEquals(r) - } - case c: ClusteredDistribution => - expressions.forall(x => c.clustering.exists(_.semanticEquals(x))) - case _ => false + super.satisfies0(required) || { + required match { + case d: HashClusteredDistribution => + expressions.length == d.hashExprs.length && expressions.zip(d.hashExprs).forall { + case (l, r) => l.semanticEquals(r) + } + case ClusteredDistribution(requiredClustering, _) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + case _ => false + } } } @@ -237,13 +248,15 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def dataType: DataType = IntegerType override def satisfies0(required: Distribution): Boolean = { - required match { - case OrderedDistribution(requiredOrdering) => - val minSize = Seq(requiredOrdering.size, ordering.size).min - requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, _) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) - case _ => false + super.satisfies0(required) || { + required match { + case OrderedDistribution(requiredOrdering) => + val minSize = Seq(requiredOrdering.size, ordering.size).min + requiredOrdering.take(minSize) == ordering.take(minSize) + case ClusteredDistribution(requiredClustering, _) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) + case _ => false + } } } } @@ -282,7 +295,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) * Returns true if any `partitioning` of this collection satisfies the given * [[Distribution]]. 
*/
-  override def satisfies(required: Distribution): Boolean =
+  override def satisfies0(required: Distribution): Boolean =
     partitionings.exists(_.satisfies(required))
 
   override def toString: String = {
@@ -297,7 +310,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
 case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
   override val numPartitions: Int = 1
 
-  override def satisfies(required: Distribution): Boolean = required match {
+  override def satisfies0(required: Distribution): Boolean = required match {
     case BroadcastDistribution(m) if m == mode => true
     case _ => false
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index b47b8adfe5d55..39228102682b9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -41,34 +41,127 @@ class DistributionSuite extends SparkFunSuite {
     }
   }
 
-  test("HashPartitioning (with nullSafe = true) is the output partitioning") {
-    // Cases which do not need an exchange between two data properties.
+  test("UnspecifiedDistribution and AllTuples") {
+    // Except `BroadcastPartitioning`, all other partitionings can satisfy UnspecifiedDistribution.
     checkSatisfied(
-      HashPartitioning(Seq('a, 'b, 'c), 10),
+      UnknownPartitioning(-1),
       UnspecifiedDistribution,
       true)
 
     checkSatisfied(
-      HashPartitioning(Seq('a, 'b, 'c), 10),
-      ClusteredDistribution(Seq('a, 'b, 'c)),
+      RoundRobinPartitioning(10),
+      UnspecifiedDistribution,
       true)
 
     checkSatisfied(
-      HashPartitioning(Seq('b, 'c), 10),
-      ClusteredDistribution(Seq('a, 'b, 'c)),
+      SinglePartition,
+      UnspecifiedDistribution,
+      true)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a), 10),
+      UnspecifiedDistribution,
+      true)
+
+    checkSatisfied(
+      RangePartitioning(Seq('a.asc), 10),
+      UnspecifiedDistribution,
+      true)
+
+    checkSatisfied(
+      BroadcastPartitioning(IdentityBroadcastMode),
+      UnspecifiedDistribution,
+      false)
+
+    // Except `BroadcastPartitioning`, all other partitionings can satisfy AllTuples if they have
+    // only one partition.
+    checkSatisfied(
+      UnknownPartitioning(1),
+      AllTuples,
+      true)
+
+    checkSatisfied(
+      UnknownPartitioning(10),
+      AllTuples,
+      false)
+
+    checkSatisfied(
+      RoundRobinPartitioning(1),
+      AllTuples,
+      true)
+
+    checkSatisfied(
+      RoundRobinPartitioning(10),
+      AllTuples,
+      false)
+
+    checkSatisfied(
+      SinglePartition,
+      AllTuples,
+      true)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a), 1),
+      AllTuples,
       true)
 
+    checkSatisfied(
+      HashPartitioning(Seq('a), 10),
+      AllTuples,
+      false)
+
+    checkSatisfied(
+      RangePartitioning(Seq('a.asc), 1),
+      AllTuples,
+      true)
+
+    checkSatisfied(
+      RangePartitioning(Seq('a.asc), 10),
+      AllTuples,
+      false)
+
+    checkSatisfied(
+      BroadcastPartitioning(IdentityBroadcastMode),
+      AllTuples,
+      false)
+  }
+
+  test("SinglePartition is the output partitioning") {
+    // SinglePartition can satisfy all the distributions except `BroadcastDistribution`
     checkSatisfied(
       SinglePartition,
       ClusteredDistribution(Seq('a, 'b, 'c)),
       true)
 
+    checkSatisfied(
+      SinglePartition,
+      HashClusteredDistribution(Seq('a, 'b, 'c)),
+      true)
+
     checkSatisfied(
       SinglePartition,
       OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
       true)
 
-    // Cases which need an exchange between two data properties.
+    checkSatisfied(
+      SinglePartition,
+      BroadcastDistribution(IdentityBroadcastMode),
+      false)
+  }
+
+  test("HashPartitioning is the output partitioning") {
+    // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of
+    // the required clustering expressions.
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10),
+      ClusteredDistribution(Seq('a, 'b, 'c)),
+      true)
+
+    checkSatisfied(
+      HashPartitioning(Seq('b, 'c), 10),
+      ClusteredDistribution(Seq('a, 'b, 'c)),
+      true)
+
     checkSatisfied(
       HashPartitioning(Seq('a, 'b, 'c), 10),
       ClusteredDistribution(Seq('b, 'c)),
@@ -79,37 +172,43 @@ class DistributionSuite extends SparkFunSuite {
       ClusteredDistribution(Seq('d, 'e)),
       false)
 
+    // HashPartitioning can satisfy HashClusteredDistribution iff its hash expressions are exactly
+    // the same as the required hash clustering expressions.
     checkSatisfied(
       HashPartitioning(Seq('a, 'b, 'c), 10),
-      AllTuples,
+      HashClusteredDistribution(Seq('a, 'b, 'c)),
+      true)
+
+    checkSatisfied(
+      HashPartitioning(Seq('c, 'b, 'a), 10),
+      HashClusteredDistribution(Seq('a, 'b, 'c)),
       false)
 
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b), 10),
+      HashClusteredDistribution(Seq('a, 'b, 'c)),
+      false)
+
+    // HashPartitioning cannot satisfy OrderedDistribution
     checkSatisfied(
       HashPartitioning(Seq('a, 'b, 'c), 10),
       OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
       false)
 
     checkSatisfied(
-      HashPartitioning(Seq('b, 'c), 10),
+      HashPartitioning(Seq('a, 'b, 'c), 1),
       OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
-      false)
+      false) // TODO: this can be relaxed.
 
-    // TODO: We should check functional dependencies
-    /*
     checkSatisfied(
-      ClusteredDistribution(Seq('b)),
-      ClusteredDistribution(Seq('b + 1)),
-      true)
-    */
+      HashPartitioning(Seq('b, 'c), 10),
+      OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+      false)
   }
 
   test("RangePartitioning is the output partitioning") {
-    // Cases which do not need an exchange between two data properties.
-    checkSatisfied(
-      RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-      UnspecifiedDistribution,
-      true)
-
+    // RangePartitioning can satisfy OrderedDistribution iff its ordering is a prefix
+    // of the required ordering, or the required ordering is a prefix of its ordering.
     checkSatisfied(
       RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
       OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
@@ -125,6 +224,27 @@ class DistributionSuite extends SparkFunSuite {
       OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)),
       true)
 
+    // TODO: We can have an optimization to first sort the dataset
+    // by a.asc and then sort b and c within each partition. This optimization
+    // should weigh the benefit of fewer Exchange operators
+    // against the reduced parallelism.
+    checkSatisfied(
+      RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+      OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)),
+      false)
+
+    checkSatisfied(
+      RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+      OrderedDistribution(Seq('b.asc, 'a.asc)),
+      false)
+
+    checkSatisfied(
+      RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+      OrderedDistribution(Seq('a.asc, 'b.asc, 'd.desc)),
+      false)
+
+    // RangePartitioning can satisfy ClusteredDistribution iff its ordering expressions are a subset
+    // of the required clustering expressions.
checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), ClusteredDistribution(Seq('a, 'b, 'c)), @@ -140,34 +260,47 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) - // Cases which need an exchange between two data properties. - // TODO: We can have an optimization to first sort the dataset - // by a.asc and then sort b, and c in a partition. This optimization - // should tradeoff the benefit of a less number of Exchange operators - // and the parallelism. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + ClusteredDistribution(Seq('a, 'b)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('b.asc, 'a.asc)), + ClusteredDistribution(Seq('c, 'd)), false) + // RangePartitioning cannot satisfy HashClusteredDistribution checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b)), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + } + test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") { checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'd)), + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - AllTuples, + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index efedae9ff1fb3..33079d5912506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -31,17 +31,19 @@ class DataSourcePartitioning( override val numPartitions: Int = partitioning.numPartitions() override def satisfies0(required: physical.Distribution): Boolean = { - required match { - case d: physical.ClusteredDistribution if isCandidate(d.clustering) => - val attrs = d.clustering.map(_.asInstanceOf[Attribute]) - partitioning.satisfy( - new ClusteredDistribution(attrs.map { a => - val name = colNames.get(a) - assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output") - name.get - }.toArray)) + super.satisfies0(required) || { + required match { + case d: physical.ClusteredDistribution if isCandidate(d.clustering) => + val attrs = d.clustering.map(_.asInstanceOf[Attribute]) + partitioning.satisfy( + new ClusteredDistribution(attrs.map { a => + val name = colNames.get(a) + assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output") + name.get + }.toArray)) - case _ => false + case _ => false + } } } From 72466b0026469379ff1ccce350b538f66c0b384a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 21 Jun 2018 11:16:15 -0700 Subject: [PATCH 4/4] reduce diff --- .../plans/physical/partitioning.scala | 53 +++++++++++-------- 
.../exchange/EnsureRequirements.scala | 2 +- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 65f995a36e6ab..cc1a5e835d9cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -68,42 +68,53 @@ case object AllTuples extends Distribution { } } -abstract class ClusteredDistributionBase(exprs: Seq[Expression]) extends Distribution { +/** + * Represents data where tuples that share the same values for the `clustering` + * [[Expression Expressions]] will be co-located. Based on the context, this + * can mean such tuples are either co-located in the same partition or they will be contiguous + * within a single partition. + */ +case class ClusteredDistribution( + clustering: Seq[Expression], + requiredNumPartitions: Option[Int] = None) extends Distribution { require( - exprs.nonEmpty, - s"The clustering expressions of a ${getClass.getSimpleName} should not be empty. " + + clustering != Nil, + "The clustering expressions of a ClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") override def createPartitioning(numPartitions: Int): Partitioning = { assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, - s"This ${getClass.getSimpleName} requires ${requiredNumPartitions.get} partitions, but " + + s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + s"the actual number of partitions is $numPartitions.") - HashPartitioning(exprs, numPartitions) + HashPartitioning(clustering, numPartitions) } } /** - * Represents data where tuples that share the same values for the `clustering` - * [[Expression Expressions]] will be co-located. Based on the context, this - * can mean such tuples are either co-located in the same partition or they will be contiguous - * within a single partition. - */ -case class ClusteredDistribution( - clustering: Seq[Expression], - requiredNumPartitions: Option[Int] = None) extends ClusteredDistributionBase(clustering) - -/** - * Represents data where tuples have been clustered according to the hash of the given expressions. - * The hash function is defined as [[HashPartitioning.partitionIdExpression]], so only + * Represents data where tuples have been clustered according to the hash of the given + * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only * [[HashPartitioning]] can satisfy this distribution. * * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the * number of partitions, this distribution strictly requires which partition the tuple should be in. */ case class HashClusteredDistribution( - hashExprs: Seq[Expression], - requiredNumPartitions: Option[Int] = None) extends ClusteredDistributionBase(hashExprs) + expressions: Seq[Expression], + requiredNumPartitions: Option[Int] = None) extends Distribution { + require( + expressions != Nil, + "The expressions for hash of a HashClusteredDistribution should not be Nil. 
" + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") + HashPartitioning(expressions, numPartitions) + } +} /** * Represents data where tuples have been ordered according to the `ordering` @@ -210,8 +221,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies0(required: Distribution): Boolean = { super.satisfies0(required) || { required match { - case d: HashClusteredDistribution => - expressions.length == d.hashExprs.length && expressions.zip(d.hashExprs).forall { + case h: HashClusteredDistribution => + expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { case (l, r) => l.semanticEquals(r) } case ClusteredDistribution(requiredClustering, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index f84cb2b91b292..ad95879d86f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -73,7 +73,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // these children may not be partitioned in the same way. // Please see the comment in withCoordinator for more details. val supportsDistribution = requiredChildDistributions.forall { dist => - dist.isInstanceOf[ClusteredDistributionBase] + dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution] } children.length > 1 && supportsDistribution }