From 880d8e980535b9a283147fd922b79c22069cafcc Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 26 Mar 2015 03:03:58 -0700 Subject: [PATCH 01/32] sort merge join for spark sql --- .../plans/physical/partitioning.scala | 49 ++++++ .../scala/org/apache/spark/sql/SQLConf.scala | 7 + .../apache/spark/sql/execution/Exchange.scala | 25 ++- .../spark/sql/execution/SparkStrategies.scala | 7 + .../sql/execution/joins/SortMergeJoin.scala | 145 ++++++++++++++++++ .../org/apache/spark/sql/JoinSuite.scala | 7 +- 6 files changed, 235 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 288c11f69fe22..147b48047f530 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -75,6 +75,21 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { def clustering: Set[Expression] = ordering.map(_.child).toSet } +/** + * Represents data where tuples have been ordered according to the `clustering` + * [[Expression Expressions]]. This is a strictly stronger guarantee than + * [[ClusteredDistribution]] as this will ensure that tuples in a single partition are sorted + * by the expressions. + */ +case class ClusteredOrderedDistribution(clustering: Seq[Expression]) + extends Distribution { + require( + clustering != Nil, + "The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") +} + sealed trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int @@ -162,6 +177,40 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } +/** + * Represents a partitioning where rows are split up across partitions based on the hash + * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be + * in the same partition. And rows within the same partition are sorted by the expressions. + */ +case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int) + extends Expression + with Partitioning { + + override def children = expressions + override def nullable = false + override def dataType = IntegerType + + private[this] lazy val clusteringSet = expressions.toSet + + override def satisfies(required: Distribution): Boolean = required match { + case UnspecifiedDistribution => true + case ClusteredOrderedDistribution(requiredClustering) => + clusteringSet.subsetOf(requiredClustering.toSet) + case ClusteredDistribution(requiredClustering) => + clusteringSet.subsetOf(requiredClustering.toSet) + case _ => false + } + + override def compatibleWith(other: Partitioning) = other match { + case BroadcastPartitioning => true + case h: HashSortedPartitioning if h == this => true + case _ => false + } + + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") +} + /** * Represents a partitioning where rows are split across partitions based on some total ordering of * the expressions specified in `ordering`. When data is partitioned in this manner the following diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4815620c6fe57..e2e07c1b804a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -27,6 +27,7 @@ private[spark] object SQLConf { val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" + val AUTO_SORTMERGEJOIN = "spark.sql.autoSortMergeJoin" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" @@ -143,6 +144,12 @@ private[sql] class SQLConf extends Serializable { private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt + /** + * By default it will choose sort merge join. + */ + private[spark] def autoSortMergeJoin: Boolean = + getConf(AUTO_SORTMERGEJOIN, true.toString).toBoolean + /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, * it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 437408d30bfd2..6c01dee9a8969 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -73,6 +72,26 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) + case HashSortedPartitioning(expressions, numPartitions) => + val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) { + child.execute().mapPartitions { iter => + val hashExpressions = newMutableProjection(expressions, child.output)() + iter.map(r => (hashExpressions(r).copy(), r.copy())) + } + } else { + child.execute().mapPartitions { iter => + val hashExpressions = newMutableProjection(expressions, child.output)() + val mutablePair = new MutablePair[Row, Row]() + iter.map(r => mutablePair.update(hashExpressions(r), r)) + } + } + val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending)) + implicit val ordering = new RowOrdering(sortingExpressions, child.output) + val part 
= new HashPartitioner(numPartitions) + val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering) + shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.map(_._2) + case RangePartitioning(sortingExpressions, numPartitions) => val rdd = if (sortBasedShuffleOn) { child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} @@ -173,6 +192,8 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl addExchangeIfNecessary(SinglePartition, child) case (ClusteredDistribution(clustering), child) => addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) + case (ClusteredOrderedDistribution(clustering), child) => + addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child) case (OrderedDistribution(ordering), child) => addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) case (UnspecifiedDistribution, child) => child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f754fa770d1b5..72f41e4bd7685 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -90,6 +90,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + // for now let's support inner join first, then add outer join + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.autoSortMergeJoin => + val mergeJoin = + joins.SortMergeJoin(leftKeys, rightKeys, Inner, planLater(left), planLater(right)) + condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala new file mode 100644 index 0000000000000..3c0ab080e7f4d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredOrderedDistribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs an sort merge join of two child relations. + */ +@DeveloperApi +case class SortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = left.output ++ right.output + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[ClusteredOrderedDistribution] = + ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil + + private val orders: Seq[SortOrder] = leftKeys.map(s => SortOrder(s, Ascending)) + private val ordering: RowOrdering = new RowOrdering(orders, left.output) + + @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + + override def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + new Iterator[Row] { + // Mutable per row objects. + private[this] val joinRow = new JoinedRow5 + private[this] var leftElement: Row = _ + private[this] var rightElement: Row = _ + private[this] var leftKey: Row = _ + private[this] var rightKey: Row = _ + private[this] var read: Boolean = false + private[this] var currentlMatches: CompactBuffer[Row] = _ + private[this] var currentrMatches: CompactBuffer[Row] = _ + private[this] var currentlPosition: Int = -1 + private[this] var currentrPosition: Int = -1 + + override final def hasNext: Boolean = + (currentlPosition != -1 && currentlPosition < currentlMatches.size) || + (leftIter.hasNext && rightIter.hasNext && nextMatchingPair) + + override final def next(): Row = { + val joinedRow = + joinRow(currentlMatches(currentlPosition), currentrMatches(currentrPosition)) + currentrPosition += 1 + if (currentrPosition >= currentrMatches.size) { + currentlPosition += 1 + currentrPosition = 0 + } + joinedRow + } + + /** + * Searches the left/right iterator for the next rows that matches. + * + * @return true if the search is successful, and false if the left/right iterator runs out + * of tuples. 
+ */ + private def nextMatchingPair(): Boolean = { + currentlPosition = -1 + currentlMatches = null + if (rightElement == null) { + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) + } + while (currentlMatches == null && leftIter.hasNext) { + if (!read) { + leftElement = leftIter.next() + leftKey = leftKeyGenerator(leftElement) + } + while (ordering.compare(leftKey, rightKey) > 0 && rightIter.hasNext) { + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) + } + currentrMatches = new CompactBuffer[Row]() + while (ordering.compare(leftKey, rightKey) == 0 && rightIter.hasNext) { + currentrMatches += rightElement + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) + } + if (ordering.compare(leftKey, rightKey) == 0) { + currentrMatches += rightElement + } + if (currentrMatches.size > 0) { + // there exists rows match in right table, should search left table + currentlMatches = new CompactBuffer[Row]() + val leftMatch = leftKey.copy() + while (ordering.compare(leftKey, leftMatch) == 0 && leftIter.hasNext) { + currentlMatches += leftElement + leftElement = leftIter.next() + leftKey = leftKeyGenerator(leftElement) + } + if (ordering.compare(leftKey, leftMatch) == 0) { + currentlMatches += leftElement + } else { + read = true + } + } + } + + if (currentlMatches == null) { + false + } else { + currentlPosition = 0 + currentrPosition = 0 + true + } + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e4dee87849fd4..bba2f223c55dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j + case j: SortMergeJoin => j } assert(operators.size === 1) @@ -75,9 +76,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[HashOuterJoin]), From 4464f16a0d1a776fe1b64b3bdb02c2b4ebbc501b Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 29 Mar 2015 21:16:11 -0700 Subject: [PATCH 02/32] fix error --- .../sql/execution/joins/SortMergeJoin.scala | 90 +++++++++++-------- 1 file changed, 53 insertions(+), 37 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala 
index 3c0ab080e7f4d..1bf3baa75ace0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredOrderedDistribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.collection.CompactBuffer @@ -41,7 +41,7 @@ case class SortMergeJoin( override def outputPartitioning: Partitioning = left.outputPartitioning - override def requiredChildDistribution: Seq[ClusteredOrderedDistribution] = + override def requiredChildDistribution: Seq[Distribution] = ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil private val orders: Seq[SortOrder] = leftKeys.map(s => SortOrder(s, Ascending)) @@ -62,7 +62,6 @@ case class SortMergeJoin( private[this] var rightElement: Row = _ private[this] var leftKey: Row = _ private[this] var rightKey: Row = _ - private[this] var read: Boolean = false private[this] var currentlMatches: CompactBuffer[Row] = _ private[this] var currentrMatches: CompactBuffer[Row] = _ private[this] var currentlPosition: Int = -1 @@ -70,7 +69,7 @@ case class SortMergeJoin( override final def hasNext: Boolean = (currentlPosition != -1 && currentlPosition < currentlMatches.size) || - (leftIter.hasNext && rightIter.hasNext && nextMatchingPair) + nextMatchingPair override final def next(): Row = { val joinedRow = @@ -83,6 +82,32 @@ case class SortMergeJoin( joinedRow } + private def fetchLeft() = { + if (leftIter.hasNext) { + leftElement = leftIter.next() + leftKey = leftKeyGenerator(leftElement) + } else { + leftElement = null + } + } + + private def fetchRight() = { + if (rightIter.hasNext) { + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) + } else { + rightElement = null + } + } + + // initialize iterator + private def initialize() = { + fetchLeft() + fetchRight() + } + + initialize() + /** * Searches the left/right iterator for the next rows that matches. 
* @@ -92,42 +117,33 @@ case class SortMergeJoin( private def nextMatchingPair(): Boolean = { currentlPosition = -1 currentlMatches = null - if (rightElement == null) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) + var stop: Boolean = false + while (!stop && leftElement != null && rightElement != null) { + if (ordering.compare(leftKey, rightKey) > 0) + fetchRight() + else if (ordering.compare(leftKey, rightKey) < 0) + fetchLeft() + else + stop = true } - while (currentlMatches == null && leftIter.hasNext) { - if (!read) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) - } - while (ordering.compare(leftKey, rightKey) > 0 && rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) - } - currentrMatches = new CompactBuffer[Row]() - while (ordering.compare(leftKey, rightKey) == 0 && rightIter.hasNext) { + currentrMatches = new CompactBuffer[Row]() + while (stop && rightElement != null) { + if (!rightKey.anyNull) currentrMatches += rightElement - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) - } - if (ordering.compare(leftKey, rightKey) == 0) { - currentrMatches += rightElement - } - if (currentrMatches.size > 0) { - // there exists rows match in right table, should search left table - currentlMatches = new CompactBuffer[Row]() - val leftMatch = leftKey.copy() - while (ordering.compare(leftKey, leftMatch) == 0 && leftIter.hasNext) { - currentlMatches += leftElement - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) - } - if (ordering.compare(leftKey, leftMatch) == 0) { + fetchRight() + if (ordering.compare(leftKey, rightKey) != 0) + stop = false + } + if (currentrMatches.size > 0) { + stop = false + currentlMatches = new CompactBuffer[Row]() + val leftMatch = leftKey.copy() + while (!stop && leftElement != null) { + if (!leftKey.anyNull) currentlMatches += leftElement - } else { - read = true - } + fetchLeft() + if (ordering.compare(leftKey, leftMatch) != 0) + stop = true } } From 95db7ad8f914e47045b53a39f5bad50aee511be9 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 29 Mar 2015 21:24:30 -0700 Subject: [PATCH 03/32] fix brackets for if-statement --- .../sql/execution/joins/SortMergeJoin.scala | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 1bf3baa75ace0..c241a7ae69cde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -119,31 +119,36 @@ case class SortMergeJoin( currentlMatches = null var stop: Boolean = false while (!stop && leftElement != null && rightElement != null) { - if (ordering.compare(leftKey, rightKey) > 0) + if (ordering.compare(leftKey, rightKey) > 0) { fetchRight() - else if (ordering.compare(leftKey, rightKey) < 0) + } else if (ordering.compare(leftKey, rightKey) < 0) { fetchLeft() - else + } else { stop = true + } } currentrMatches = new CompactBuffer[Row]() while (stop && rightElement != null) { - if (!rightKey.anyNull) + if (!rightKey.anyNull) { currentrMatches += rightElement + } fetchRight() - if (ordering.compare(leftKey, rightKey) != 0) + if (ordering.compare(leftKey, rightKey) != 0) { stop = false + } } if (currentrMatches.size > 0) { stop = false currentlMatches = new 
CompactBuffer[Row]() val leftMatch = leftKey.copy() while (!stop && leftElement != null) { - if (!leftKey.anyNull) + if (!leftKey.anyNull) { currentlMatches += leftElement + } fetchLeft() - if (ordering.compare(leftKey, leftMatch) != 0) + if (ordering.compare(leftKey, leftMatch) != 0) { stop = true + } } } From 303b6da2ef427d08e367ab85e29fd18a6572c90a Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 31 Mar 2015 00:48:35 -0700 Subject: [PATCH 04/32] fix several errors --- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../sql/execution/joins/SortMergeJoin.scala | 88 ++++++++++--------- .../spark/sql/hive/StatisticsSuite.scala | 2 + 3 files changed, 49 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6c01dee9a8969..58c62997a843e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -86,7 +86,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending)) - implicit val ordering = new RowOrdering(sortingExpressions, child.output) + val ordering = new RowOrdering(sortingExpressions, child.output) val part = new HashPartitioner(numPartitions) val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index c241a7ae69cde..7048f91f80eab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -67,17 +67,21 @@ case class SortMergeJoin( private[this] var currentlPosition: Int = -1 private[this] var currentrPosition: Int = -1 - override final def hasNext: Boolean = - (currentlPosition != -1 && currentlPosition < currentlMatches.size) || - nextMatchingPair + override final def hasNext: Boolean = currentlPosition != -1 || nextMatchingPair override final def next(): Row = { + if (!hasNext) { + return null + } val joinedRow = joinRow(currentlMatches(currentlPosition), currentrMatches(currentrPosition)) currentrPosition += 1 if (currentrPosition >= currentrMatches.size) { currentlPosition += 1 currentrPosition = 0 + if (currentlPosition >= currentlMatches.size) { + currentlPosition = -1 + } } joinedRow } @@ -100,13 +104,13 @@ case class SortMergeJoin( } } - // initialize iterator - private def initialize() = { + private def fetchFirst() = { fetchLeft() fetchRight() + currentrPosition = 0 } - - initialize() + // initialize iterator + fetchFirst() /** * Searches the left/right iterator for the next rows that matches. @@ -115,49 +119,49 @@ case class SortMergeJoin( * of tuples. 
*/ private def nextMatchingPair(): Boolean = { - currentlPosition = -1 - currentlMatches = null - var stop: Boolean = false - while (!stop && leftElement != null && rightElement != null) { - if (ordering.compare(leftKey, rightKey) > 0) { - fetchRight() - } else if (ordering.compare(leftKey, rightKey) < 0) { - fetchLeft() - } else { - stop = true + if (currentlPosition > -1) { + true + } else { + currentlPosition = -1 + currentlMatches = null + var stop: Boolean = false + while (!stop && leftElement != null && rightElement != null) { + if (ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull) { + stop = true + } else if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) { + fetchRight() + } else { //if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) + fetchLeft() + } } - } - currentrMatches = new CompactBuffer[Row]() - while (stop && rightElement != null) { - if (!rightKey.anyNull) { + currentrMatches = new CompactBuffer[Row]() + while (stop && rightElement != null) { currentrMatches += rightElement + fetchRight() + if (ordering.compare(leftKey, rightKey) != 0) { + stop = false + } } - fetchRight() - if (ordering.compare(leftKey, rightKey) != 0) { + if (currentrMatches.size > 0) { stop = false - } - } - if (currentrMatches.size > 0) { - stop = false - currentlMatches = new CompactBuffer[Row]() - val leftMatch = leftKey.copy() - while (!stop && leftElement != null) { - if (!leftKey.anyNull) { + currentlMatches = new CompactBuffer[Row]() + val leftMatch = leftKey.copy() + while (!stop && leftElement != null) { currentlMatches += leftElement - } - fetchLeft() - if (ordering.compare(leftKey, leftMatch) != 0) { - stop = true + fetchLeft() + if (ordering.compare(leftKey, leftMatch) != 0) { + stop = true + } } } - } - if (currentlMatches == null) { - false - } else { - currentlPosition = 0 - currentrPosition = 0 - true + if (currentlMatches == null) { + false + } else { + currentlPosition = 0 + currentrPosition = 0 + true + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ccd0e5aa51f95..dc1d9fbd299e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -144,6 +144,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { expectedAnswer: Seq[Row], ct: ClassTag[_]) = { before() + conf.setConf("spark.sql.autoSortMergeJoin", "false") var df = sql(query) @@ -178,6 +179,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") } + conf.setConf("spark.sql.autoSortMergeJoin", "true") after() } From 57baa4033585aef8b760586c10e7e86c1eefa469 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 1 Apr 2015 00:08:30 -0700 Subject: [PATCH 05/32] fix sort eval bug --- .../main/scala/org/apache/spark/sql/execution/Exchange.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 58c62997a843e..63d2e92697c04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -85,7 +85,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(r => 
mutablePair.update(hashExpressions(r), r)) } } - val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending)) + val sortingExpressions = expressions.zipWithIndex.map { + case (exp, index) => + new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending) + } val ordering = new RowOrdering(sortingExpressions, child.output) val part = new HashPartitioner(numPartitions) val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering) From 2edd2351254aeef8ad2ff7238fba509fb997c217 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 1 Apr 2015 00:29:11 -0700 Subject: [PATCH 06/32] fix outputpartitioning --- .../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 7048f91f80eab..03cbb680020fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -39,7 +39,7 @@ case class SortMergeJoin( override def output: Seq[Attribute] = left.output ++ right.output - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = HashSortedPartitioning(leftKeys, 0) override def requiredChildDistribution: Seq[Distribution] = ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil From e3ec0964c3e0f69320804a2917c9d932a12406e8 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 1 Apr 2015 00:49:32 -0700 Subject: [PATCH 07/32] fix comment style.. --- .../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 03cbb680020fd..a2719c5060496 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -130,7 +130,7 @@ case class SortMergeJoin( stop = true } else if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) { fetchRight() - } else { //if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) + } else { // if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) fetchLeft() } } From 42fca0eafd187693bde3b89fa8e8d16ed22d4e79 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 1 Apr 2015 01:34:33 -0700 Subject: [PATCH 08/32] code clean --- .../sql/execution/joins/SortMergeJoin.scala | 84 +++++++++---------- 1 file changed, 38 insertions(+), 46 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index a2719c5060496..ec6a05542659d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -62,28 +62,29 @@ case class SortMergeJoin( private[this] var rightElement: Row = _ private[this] var leftKey: Row = _ private[this] var rightKey: Row = _ - private[this] var currentlMatches: CompactBuffer[Row] = _ - private[this] var currentrMatches: CompactBuffer[Row] = _ - private[this] var 
currentlPosition: Int = -1 - private[this] var currentrPosition: Int = -1 + private[this] var leftMatches: CompactBuffer[Row] = _ + private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var leftPosition: Int = -1 + private[this] var rightPosition: Int = -1 - override final def hasNext: Boolean = currentlPosition != -1 || nextMatchingPair + override final def hasNext: Boolean = leftPosition != -1 || nextMatchingPair override final def next(): Row = { - if (!hasNext) { - return null - } - val joinedRow = - joinRow(currentlMatches(currentlPosition), currentrMatches(currentrPosition)) - currentrPosition += 1 - if (currentrPosition >= currentrMatches.size) { - currentlPosition += 1 - currentrPosition = 0 - if (currentlPosition >= currentlMatches.size) { - currentlPosition = -1 + if (hasNext) { + val joinedRow = joinRow(leftMatches(leftPosition), rightMatches(rightPosition)) + rightPosition += 1 + if (rightPosition >= rightMatches.size) { + leftPosition += 1 + rightPosition = 0 + if (leftPosition >= leftMatches.size) { + leftPosition = -1 + } } + joinedRow + } else { + // according to Scala doc, this is undefined + null } - joinedRow } private def fetchLeft() = { @@ -104,13 +105,12 @@ case class SortMergeJoin( } } - private def fetchFirst() = { + private def initialize() = { fetchLeft() fetchRight() - currentrPosition = 0 } // initialize iterator - fetchFirst() + initialize() /** * Searches the left/right iterator for the next rows that matches. @@ -119,50 +119,42 @@ case class SortMergeJoin( * of tuples. */ private def nextMatchingPair(): Boolean = { - if (currentlPosition > -1) { - true - } else { - currentlPosition = -1 - currentlMatches = null + if (leftPosition == -1) { + leftMatches = null var stop: Boolean = false while (!stop && leftElement != null && rightElement != null) { - if (ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull) { - stop = true - } else if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) { + stop = ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull + if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) { fetchRight() - } else { // if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) + } else if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) { fetchLeft() } } - currentrMatches = new CompactBuffer[Row]() + rightMatches = new CompactBuffer[Row]() while (stop && rightElement != null) { - currentrMatches += rightElement + rightMatches += rightElement fetchRight() - if (ordering.compare(leftKey, rightKey) != 0) { - stop = false - } + // exit loop when run out of right matches + stop = ordering.compare(leftKey, rightKey) == 0 } - if (currentrMatches.size > 0) { + if (rightMatches.size > 0) { stop = false - currentlMatches = new CompactBuffer[Row]() + leftMatches = new CompactBuffer[Row]() val leftMatch = leftKey.copy() while (!stop && leftElement != null) { - currentlMatches += leftElement + leftMatches += leftElement fetchLeft() - if (ordering.compare(leftKey, leftMatch) != 0) { - stop = true - } + // exit loop when run out of left matches + stop = ordering.compare(leftKey, leftMatch) != 0 } } - if (currentlMatches == null) { - false - } else { - currentlPosition = 0 - currentrPosition = 0 - true + if (leftMatches != null) { + leftPosition = 0 + rightPosition = 0 } } + leftPosition > -1 } } } From 07ce92f4e65e11b1f0cfa5d1688f68fc4b2b0f40 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 1 Apr 2015 20:16:37 -0700 Subject: [PATCH 09/32] fix ArrayIndexOutOfBound --- 
.../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index ec6a05542659d..ddddd8c9a14ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -44,7 +44,9 @@ case class SortMergeJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil - private val orders: Seq[SortOrder] = leftKeys.map(s => SortOrder(s, Ascending)) + private val orders: Seq[SortOrder] = leftKeys.zipWithIndex.map { + case(expr, index) => SortOrder(BoundReference(index, expr.dataType, expr.nullable), Ascending) + } private val ordering: RowOrdering = new RowOrdering(orders, left.output) @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) From 925203b5c1419b3b14f9138ebd283565fa9adc65 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 2 Apr 2015 00:49:43 -0700 Subject: [PATCH 10/32] address comments --- .../scala/org/apache/spark/sql/SQLConf.scala | 4 ++-- .../sql/execution/joins/SortMergeJoin.scala | 5 +---- .../scala/org/apache/spark/sql/JoinSuite.scala | 16 +++++++++++++--- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index e2e07c1b804a3..842cd1b7d58fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -145,10 +145,10 @@ private[sql] class SQLConf extends Serializable { getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt /** - * By default it will choose sort merge join. + * By default not choose sort merge join. */ private[spark] def autoSortMergeJoin: Boolean = - getConf(AUTO_SORTMERGEJOIN, true.toString).toBoolean + getConf(AUTO_SORTMERGEJOIN, false.toString).toBoolean /** * The default size in bytes to assign to a logical operator's estimation statistics. 
By default, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index ddddd8c9a14ca..32689f5ca2591 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -140,14 +140,11 @@ case class SortMergeJoin( stop = ordering.compare(leftKey, rightKey) == 0 } if (rightMatches.size > 0) { - stop = false leftMatches = new CompactBuffer[Row]() val leftMatch = leftKey.copy() - while (!stop && leftElement != null) { + while (ordering.compare(leftKey, leftMatch) == 0 && leftElement != null) { leftMatches += leftElement fetchLeft() - // exit loop when run out of left matches - stop = ordering.compare(leftKey, leftMatch) != 0 } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index bba2f223c55dc..a2c9778883389 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j + case j: ShuffledHashJoin => j case j: SortMergeJoin => j } @@ -63,6 +64,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { cacheManager.clearCache() + val AUTO_SORTMERGEJOIN: Boolean = conf.autoSortMergeJoin + conf.setConf("spark.sql.autoSortMergeJoin", "false") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -76,9 +79,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[HashOuterJoin]), @@ -92,6 +95,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + conf.setConf("spark.sql.autoSortMergeJoin", "true") + Seq( + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ).foreach { case (query, joinClass) => 
assertJoin(query, joinClass) } + conf.setConf("spark.sql.autoSortMergeJoin", AUTO_SORTMERGEJOIN.toString) } test("broadcasted hash join operator selection") { From 068c35d08121696b3c71026985a951613bdcedec Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 9 Apr 2015 22:15:48 -0700 Subject: [PATCH 11/32] fix new style and add some tests --- .../plans/physical/partitioning.scala | 6 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../sql/execution/joins/SortMergeJoin.scala | 4 +- .../SortMergeCompatibilitySuite.scala | 132 ++++++++++++++++++ .../spark/sql/hive/StatisticsSuite.scala | 2 - 5 files changed, 138 insertions(+), 8 deletions(-) create mode 100644 sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 147b48047f530..e0f981ef37960 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -186,9 +186,9 @@ case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: I extends Expression with Partitioning { - override def children = expressions - override def nullable = false - override def dataType = IntegerType + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = expressions.toSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 72f41e4bd7685..c777af9da21bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -94,7 +94,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.autoSortMergeJoin => val mergeJoin = - joins.SortMergeJoin(leftKeys, rightKeys, Inner, planLater(left), planLater(right)) + joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 32689f5ca2591..259e7ab264e29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -33,7 +34,6 @@ import org.apache.spark.util.collection.CompactBuffer case class SortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], - joinType: JoinType, left: SparkPlan, right: SparkPlan) extends BinaryNode { @@ -52,7 +52,7 @@ case class SortMergeJoin( @transient protected lazy val 
leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) - override def execute() = { + override def execute(): RDD[Row] = { val leftResults = left.execute().map(_.copy()) val rightResults = right.execute().map(_.copy()) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala new file mode 100644 index 0000000000000..f49555c9142b1 --- /dev/null +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -0,0 +1,132 @@ +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive + +/** + * Runs the test cases that are included in the hive distribution with sort merge join is true. + */ +class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.AUTO_SORTMERGEJOIN, "true") + } + + override def afterAll() { + TestHive.setConf(SQLConf.AUTO_SORTMERGEJOIN, "false") + super.afterAll() + } + + override def whiteList = Seq( + "auto_join0", + "auto_join1", + "auto_join10", + "auto_join11", + "auto_join12", + "auto_join13", + "auto_join14", + "auto_join14_hadoop20", + "auto_join15", + "auto_join17", + "auto_join18", + "auto_join19", + "auto_join2", + "auto_join20", + "auto_join21", + "auto_join22", + "auto_join23", + "auto_join24", + "auto_join25", + "auto_join26", + "auto_join27", + "auto_join28", + "auto_join3", + "auto_join30", + "auto_join31", + "auto_join32", + "auto_join4", + "auto_join5", + "auto_join6", + "auto_join7", + "auto_join8", + "auto_join9", + "auto_join_filters", + "auto_join_nulls", + "auto_join_reordering_values", + "auto_smb_mapjoin_14", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", + "join0", + "join1", + "join10", + "join11", + "join12", + "join13", + "join14", + "join14_hadoop20", + "join15", + "join16", + "join17", + "join18", + "join19", + "join2", + "join20", + "join21", + "join22", + "join23", + "join24", + "join25", + "join26", + "join27", + "join28", + "join29", + "join3", + "join30", + "join31", + "join32", + "join32_lessSize", + "join33", + "join34", + "join35", + "join36", + "join37", + "join38", + "join39", + "join4", + "join40", + "join41", + "join5", + "join6", + "join7", + "join8", + "join9", + "join_1to1", + "join_array", + "join_casesensitive", + "join_empty", + "join_filters", + "join_hive_626", + "join_map_ppr", + "join_nulls", + "join_nullsafe", + "join_rc", + "join_reorder2", + "join_reorder3", + "join_reorder4", + "join_star" + ) +} \ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index dc1d9fbd299e5..ccd0e5aa51f95 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -144,7 +144,6 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { expectedAnswer: Seq[Row], ct: ClassTag[_]) = { before() - conf.setConf("spark.sql.autoSortMergeJoin", "false") var df = sql(query) @@ -179,7 +178,6 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") } - conf.setConf("spark.sql.autoSortMergeJoin", "true") after() } From 645c70b24762d14d88e19d708e53a384a3daa2a1 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 10 Apr 2015 03:44:52 -0700 Subject: [PATCH 12/32] address comments using sort --- .../plans/physical/partitioning.scala | 49 ------------------- .../apache/spark/sql/execution/Exchange.scala | 38 +++++--------- .../spark/sql/execution/SparkPlan.scala | 6 +++ .../sql/execution/joins/SortMergeJoin.scala | 13 ++++- .../SortMergeCompatibilitySuite.scala | 13 +++++ 5 files changed, 42 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index e0f981ef37960..288c11f69fe22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -75,21 +75,6 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { def clustering: Set[Expression] = ordering.map(_.child).toSet } -/** - * Represents data where tuples have been ordered according to the `clustering` - * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[ClusteredDistribution]] as this will ensure that tuples in a single partition are sorted - * by the expressions. - */ -case class ClusteredOrderedDistribution(clustering: Seq[Expression]) - extends Distribution { - require( - clustering != Nil, - "The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " + - "An AllTuples should be used to represent a distribution that only has " + - "a single partition.") -} - sealed trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int @@ -177,40 +162,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } -/** - * Represents a partitioning where rows are split up across partitions based on the hash - * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be - * in the same partition. And rows within the same partition are sorted by the expressions. 
- */ -case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression - with Partitioning { - - override def children: Seq[Expression] = expressions - override def nullable: Boolean = false - override def dataType: DataType = IntegerType - - private[this] lazy val clusteringSet = expressions.toSet - - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case ClusteredOrderedDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) - case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) - case _ => false - } - - override def compatibleWith(other: Partitioning) = other match { - case BroadcastPartitioning => true - case h: HashSortedPartitioning if h == this => true - case _ => false - } - - override def eval(input: Row = null): EvaluatedType = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") -} - /** * Represents a partitioning where rows are split across partitions based on some total ordering of * the expressions specified in `ordering`. When data is partitioned in this manner the following diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 63d2e92697c04..c89b2a068351b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -72,29 +72,6 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) - case HashSortedPartitioning(expressions, numPartitions) => - val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) { - child.execute().mapPartitions { iter => - val hashExpressions = newMutableProjection(expressions, child.output)() - iter.map(r => (hashExpressions(r).copy(), r.copy())) - } - } else { - child.execute().mapPartitions { iter => - val hashExpressions = newMutableProjection(expressions, child.output)() - val mutablePair = new MutablePair[Row, Row]() - iter.map(r => mutablePair.update(hashExpressions(r), r)) - } - } - val sortingExpressions = expressions.zipWithIndex.map { - case (exp, index) => - new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending) - } - val ordering = new RowOrdering(sortingExpressions, child.output) - val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) - shuffled.map(_._2) - case RangePartitioning(sortingExpressions, numPartitions) => val rdd = if (sortBasedShuffleOn) { child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} @@ -184,6 +161,11 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan = if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child + // Check if the partitioning we want to ensure is the same as the child's output + // partitioning. If so, we do not need to add the Exchange operator. 
+ def addSortIfNecessary(ordering: Seq[SortOrder], child: SparkPlan): SparkPlan = + if (child.outputOrdering != ordering) Sort(ordering, global = false, child) else child + if (meetsRequirements && compatible) { operator } else { @@ -195,14 +177,18 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl addExchangeIfNecessary(SinglePartition, child) case (ClusteredDistribution(clustering), child) => addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) - case (ClusteredOrderedDistribution(clustering), child) => - addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child) case (OrderedDistribution(ordering), child) => addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) case (UnspecifiedDistribution, child) => child case (dist, _) => sys.error(s"Don't know how to ensure $dist") } - operator.withNewChildren(repartitionedChildren) + val reorderedChildren = operator.requiredInPartitionOrdering.zip(repartitionedChildren).map { + case (Nil, child) => + child + case (ordering, child) => + addSortIfNecessary(ordering, child) + } + operator.withNewChildren(reorderedChildren) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index d239637cd4b4e..41d80969e376b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) + /** Specifies how data is ordered in each partition. */ + def outputOrdering: Seq[SortOrder] = Nil + + /** Specifies sort order for each partition requirements on the input data for this operator. */ + def requiredInPartitionOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + /** * Runs this query returning the result as an RDD. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 259e7ab264e29..fd65320d55139 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -39,16 +39,25 @@ case class SortMergeJoin( override def output: Seq[Attribute] = left.output ++ right.output - override def outputPartitioning: Partitioning = HashSortedPartitioning(leftKeys, 0) + override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = - ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil private val orders: Seq[SortOrder] = leftKeys.zipWithIndex.map { case(expr, index) => SortOrder(BoundReference(index, expr.dataType, expr.nullable), Ascending) } private val ordering: RowOrdering = new RowOrdering(orders, left.output) + private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Seq[SortOrder] = keys.map { + k => SortOrder(BindReferences.bindReference(k, side.output, allowFailures = false), Ascending) + } + + override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys, left) + + override def requiredInPartitionOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys, left) :: requiredOrders(rightKeys, right) :: Nil + @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index f49555c9142b1..3e08a0ce8c003 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -70,6 +70,19 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { "auto_sortmerge_join_7", "auto_sortmerge_join_8", "auto_sortmerge_join_9", + "correlationoptimizer1", + "correlationoptimizer10", + "correlationoptimizer11", + "correlationoptimizer13", + "correlationoptimizer14", + "correlationoptimizer15", + "correlationoptimizer2", + "correlationoptimizer3", + "correlationoptimizer4", + "correlationoptimizer6", + "correlationoptimizer7", + "correlationoptimizer8", + "correlationoptimizer9", "join0", "join1", "join10", From a28277ff1e8096964f2ec92e288b28475b870699 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 10 Apr 2015 03:51:11 -0700 Subject: [PATCH 13/32] fix style --- .../org/apache/spark/sql/execution/Exchange.scala | 12 ++++++------ .../hive/execution/SortMergeCompatibilitySuite.scala | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index c89b2a068351b..dc7e3a5c41070 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -182,12 +182,12 @@ private[sql] case class AddExchange(sqlContext: SQLContext) 
extends Rule[SparkPl case (UnspecifiedDistribution, child) => child case (dist, _) => sys.error(s"Don't know how to ensure $dist") } - val reorderedChildren = operator.requiredInPartitionOrdering.zip(repartitionedChildren).map { - case (Nil, child) => - child - case (ordering, child) => - addSortIfNecessary(ordering, child) - } + val reorderedChildren = + operator.requiredInPartitionOrdering.zip(repartitionedChildren).map { + case (Nil, child) => child + case (ordering, child) => + addSortIfNecessary(ordering, child) + } operator.withNewChildren(reorderedChildren) } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index 3e08a0ce8c003..bf01fdf1b92e5 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -142,4 +142,4 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { "join_reorder4", "join_star" ) -} \ No newline at end of file +} From 47455c93daf749578309f21de88e058b3d8214d2 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 10 Apr 2015 03:53:31 -0700 Subject: [PATCH 14/32] add apache license ... --- .../execution/SortMergeCompatibilitySuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index bf01fdf1b92e5..32543f2f3a8b2 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.sql.hive.execution import org.apache.spark.sql.SQLConf From 171001fea71669a009e440dfb6caf927aecd924b Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sat, 11 Apr 2015 19:26:01 +0800 Subject: [PATCH 15/32] change default outputordering --- .../apache/spark/sql/execution/Aggregate.scala | 2 ++ .../apache/spark/sql/execution/Exchange.scala | 4 ++-- .../apache/spark/sql/execution/SparkPlan.scala | 1 + .../spark/sql/execution/basicOperators.scala | 8 ++++++++ .../scala/org/apache/spark/sql/JoinSuite.scala | 17 ++++++++++------- 5 files changed, 23 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 18b1ba4c5c4b9..296c71df6a11e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -60,6 +60,8 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) + override def outputOrdering: Seq[SortOrder] = Nil + /** * An aggregate that needs to be computed for each row in a group. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index dc7e3a5c41070..e6ac2926320f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -161,8 +161,8 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan = if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child - // Check if the partitioning we want to ensure is the same as the child's output - // partitioning. If so, we do not need to add the Exchange operator. + // Check if the ordering we want to ensure is the same as the child's output + // ordering. If so, we do not need to add the Sort operator. 
def addSortIfNecessary(ordering: Seq[SortOrder], child: SparkPlan): SparkPlan = if (child.outputOrdering != ordering) Sort(ordering, global = false, child) else child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 41d80969e376b..b3252da2df201 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -183,6 +183,7 @@ private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] { self: Product => override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 1f5251a20376f..e13a3699318cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -70,6 +70,8 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: override def execute(): RDD[Row] = { child.execute().map(_.copy()).sample(withReplacement, fraction, seed) } + + override def outputOrdering: Seq[SortOrder] = Nil } /** @@ -146,6 +148,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. 
override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -171,6 +175,8 @@ case class Sort( } override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -201,6 +207,8 @@ case class ExternalSort( } override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a2c9778883389..826db143a9211 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -95,13 +95,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - conf.setConf("spark.sql.autoSortMergeJoin", "true") - Seq( - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - conf.setConf("spark.sql.autoSortMergeJoin", AUTO_SORTMERGEJOIN.toString) + try { + conf.setConf("spark.sql.autoSortMergeJoin", "true") + Seq( + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.autoSortMergeJoin", AUTO_SORTMERGEJOIN.toString) + } } test("broadcasted hash join operator selection") { From 3af6ba546b016d8bd2fa8d9a729a7bc9993e8e50 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sat, 11 Apr 2015 21:15:48 +0800 Subject: [PATCH 16/32] use buffer for only one side --- .../sql/execution/joins/SortMergeJoin.scala | 49 ++++++++----------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index fd65320d55139..7e7b692d401cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -73,22 +73,23 @@ case class SortMergeJoin( private[this] var rightElement: Row = _ private[this] var leftKey: Row = _ private[this] var rightKey: Row = _ - private[this] var leftMatches: CompactBuffer[Row] = _ private[this] var rightMatches: CompactBuffer[Row] = _ - private[this] var leftPosition: Int = -1 private[this] var rightPosition: Int = -1 + private[this] var stop: Boolean = false + private[this] var matchKey: Row = _ - override final def hasNext: Boolean = leftPosition != -1 || nextMatchingPair + override final def hasNext: Boolean = nextMatchingPair() override final def next(): Row = { if (hasNext) { - val joinedRow = joinRow(leftMatches(leftPosition), rightMatches(rightPosition)) + val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) rightPosition += 1 if 
(rightPosition >= rightMatches.size) { - leftPosition += 1 rightPosition = 0 - if (leftPosition >= leftMatches.size) { - leftPosition = -1 + fetchLeft() + if (leftElement == null || ordering.compare(leftKey, matchKey) != 0) { + stop = false + rightMatches = null } } joinedRow @@ -130,9 +131,7 @@ case class SortMergeJoin( * of tuples. */ private def nextMatchingPair(): Boolean = { - if (leftPosition == -1) { - leftMatches = null - var stop: Boolean = false + if (!stop && rightElement != null) { while (!stop && leftElement != null && rightElement != null) { stop = ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) { @@ -142,27 +141,21 @@ case class SortMergeJoin( } } rightMatches = new CompactBuffer[Row]() - while (stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - // exit loop when run out of right matches - stop = ordering.compare(leftKey, rightKey) == 0 - } - if (rightMatches.size > 0) { - leftMatches = new CompactBuffer[Row]() - val leftMatch = leftKey.copy() - while (ordering.compare(leftKey, leftMatch) == 0 && leftElement != null) { - leftMatches += leftElement - fetchLeft() + if (stop) { + stop = false + while (!stop && rightElement != null) { + rightMatches += rightElement + fetchRight() + // exit loop when run out of right matches + stop = ordering.compare(leftKey, rightKey) != 0 + } + if (rightMatches.size > 0) { + rightPosition = 0 + matchKey = leftKey } - } - - if (leftMatches != null) { - leftPosition = 0 - rightPosition = 0 } } - leftPosition > -1 + rightMatches != null && rightMatches.size > 0 } } } From 078d69b1894ad6c895ed9bb101b9596a03d479fa Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 12 Apr 2015 15:25:21 +0800 Subject: [PATCH 17/32] address comments: add comments, do sort in shuffle, and others --- .../scala/org/apache/spark/sql/SQLConf.scala | 15 ++--- .../spark/sql/execution/Aggregate.scala | 2 - .../apache/spark/sql/execution/Exchange.scala | 55 +++++++++++-------- .../spark/sql/execution/SparkPlan.scala | 5 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../spark/sql/execution/basicOperators.scala | 21 +++++-- .../sql/execution/joins/SortMergeJoin.scala | 49 ++++++++++------- .../org/apache/spark/sql/JoinSuite.scala | 4 +- .../SortMergeCompatibilitySuite.scala | 4 +- 9 files changed, 93 insertions(+), 64 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 842cd1b7d58fb..02867b19a2db9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -27,7 +27,6 @@ private[spark] object SQLConf { val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" - val AUTO_SORTMERGEJOIN = "spark.sql.autoSortMergeJoin" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" @@ -46,6 +45,7 @@ private[spark] object SQLConf { // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. 
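The matching strategy SortMergeJoin now uses -- both inputs sorted by the join key, only the current run of equal keys on the right side buffered, the left side streamed past that buffer -- is a classic merge join. A minimal standalone sketch of the idea, using plain Scala collections and illustrative names rather than Spark's Row and iterator machinery:

object SortMergeSketch {
  // Inner-joins two key-sorted sequences. Only the run of equal keys on the
  // right side is buffered; the left side is walked exactly once.
  def innerJoin[K, L, R](left: Seq[(K, L)], right: Seq[(K, R)])
                        (implicit ord: Ordering[K]): Seq[(L, R)] = {
    val out = Seq.newBuilder[(L, R)]
    var i = 0
    var j = 0
    while (i < left.length && j < right.length) {
      val c = ord.compare(left(i)._1, right(j)._1)
      if (c < 0) i += 1            // left key too small, advance left
      else if (c > 0) j += 1       // right key too small, advance right
      else {
        val key = right(j)._1
        var end = j
        while (end < right.length && ord.equiv(right(end)._1, key)) end += 1
        val buffered = right.slice(j, end)          // buffer the right-side run
        while (i < left.length && ord.equiv(left(i)._1, key)) {
          buffered.foreach { case (_, r) => out += ((left(i)._2, r)) }
          i += 1                                    // stream the left side
        }
        j = end
      }
    }
    out.result()
  }
}

For example, innerJoin(Seq(1 -> "a", 1 -> "b", 3 -> "c"), Seq(1 -> "x", 2 -> "y")) yields Seq(("a", "x"), ("b", "x")); the null-key filtering and CompactBuffer reuse of the real operator are omitted here.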
val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" @@ -123,6 +123,13 @@ private[sql] class SQLConf extends Serializable { /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean + /** + * Sort merge join would sort the two side of join first, and then iterate both sides together + * only once to get all matches. Using sort merge join can save a lot of memory usage compared + * to HashJoin. + */ + private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean + /** * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster @@ -144,12 +151,6 @@ private[sql] class SQLConf extends Serializable { private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt - /** - * By default not choose sort merge join. - */ - private[spark] def autoSortMergeJoin: Boolean = - getConf(AUTO_SORTMERGEJOIN, false.toString).toBoolean - /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, * it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 296c71df6a11e..18b1ba4c5c4b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -60,8 +60,6 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - override def outputOrdering: Seq[SortOrder] = Nil - /** * An aggregate that needs to be computed for each row in a group. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index e6ac2926320f5..d5b75f796d320 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -32,7 +32,11 @@ import org.apache.spark.util.MutablePair * :: DeveloperApi :: */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange( + newPartitioning: Partitioning, + child: SparkPlan, + sort: Boolean = false) + extends UnaryNode { override def outputPartitioning: Partitioning = newPartitioning @@ -68,7 +72,16 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) + val shuffled = sort match { + case false => new ShuffledRDD[Row, Row, Row](rdd, part) + case true => + val sortingExpressions = expressions.zipWithIndex.map { + case (exp, index) => + new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending) + } + val ordering = new RowOrdering(sortingExpressions, child.output) + new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering) + } shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) @@ -158,13 +171,15 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl // Check if the partitioning we want to ensure is the same as the child's output // partitioning. If so, we do not need to add the Exchange operator. - def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan = - if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child - - // Check if the ordering we want to ensure is the same as the child's output - // ordering. If so, we do not need to add the Sort operator. - def addSortIfNecessary(ordering: Seq[SortOrder], child: SparkPlan): SparkPlan = - if (child.outputOrdering != ordering) Sort(ordering, global = false, child) else child + def addExchangeIfNecessary( + partitioning: Partitioning, + child: SparkPlan, + rowOrdering: Option[Ordering[Row]] = None): SparkPlan = + if (child.outputPartitioning != partitioning) { + Exchange(partitioning, child, sort = child.outputOrdering != rowOrdering) + } else { + child + } if (meetsRequirements && compatible) { operator @@ -172,23 +187,19 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl // At least one child does not satisfies its required data distribution or // at least one child's outputPartitioning is not compatible with another child's // outputPartitioning. In this case, we need to add Exchange operators. 
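The sort-during-shuffle path above leans on ShuffledRDD.setKeyOrdering, which asks the shuffle machinery to deliver each post-shuffle partition already sorted by key. A rough, hypothetical usage outside of Exchange (assuming an existing SparkContext named sc) might look like:

import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.ShuffledRDD

val pairs = sc.parallelize(Seq((3, "c"), (1, "a"), (2, "b")))
val sortedWithinPartitions =
  new ShuffledRDD[Int, String, String](pairs, new HashPartitioner(2))
    .setKeyOrdering(implicitly[Ordering[Int]])
// rows in each partition of sortedWithinPartitions arrive ordered by key,
// so no separate Sort operator is needed after the exchange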
- val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map { - case (AllTuples, child) => + val repartitionedChildren = operator.requiredChildDistribution.zip( + operator.children.zip(operator.requiredChildOrdering) + ).map { + case (AllTuples, (child, _)) => addExchangeIfNecessary(SinglePartition, child) - case (ClusteredDistribution(clustering), child) => - addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) - case (OrderedDistribution(ordering), child) => + case (ClusteredDistribution(clustering), (child, rowOrdering)) => + addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering) + case (OrderedDistribution(ordering), (child, _)) => addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) - case (UnspecifiedDistribution, child) => child + case (UnspecifiedDistribution, (child, _)) => child case (dist, _) => sys.error(s"Don't know how to ensure $dist") } - val reorderedChildren = - operator.requiredInPartitionOrdering.zip(repartitionedChildren).map { - case (Nil, child) => child - case (ordering, child) => - addSortIfNecessary(ordering, child) - } - operator.withNewChildren(reorderedChildren) + operator.withNewChildren(repartitionedChildren) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b3252da2df201..748c478a5a839 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -73,10 +73,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ Seq.fill(children.size)(UnspecifiedDistribution) /** Specifies how data is ordered in each partition. */ - def outputOrdering: Seq[SortOrder] = Nil + def outputOrdering: Option[Ordering[Row]] = None /** Specifies sort order for each partition requirements on the input data for this operator. */ - def requiredInPartitionOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + def requiredChildOrdering: Seq[Option[Ordering[Row]]] = Seq.fill(children.size)(None) /** * Runs this query returning the result as an RDD. 
@@ -183,7 +183,6 @@ private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] { self: Product => override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering } private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c777af9da21bd..83ec39a4b7f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -92,7 +92,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.autoSortMergeJoin => + if sqlContext.conf.sortMergeJoinEnabled => val mergeJoin = joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e13a3699318cb..c723c502ab90b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -41,6 +41,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends val resuableProjection = buildProjection() iter.map(resuableProjection) } + + /** + * outputOrdering of Project is not always same with child's outputOrdering if the certain + * key is pruned, however, if the key is pruned then we must not require child using this + * ordering from upper layer, only if the ordering would not be changed by a negative, there + * would be a way to keep the ordering. + * TODO: we may utilize this feature later to avoid some unnecessary sorting. + */ + override def outputOrdering: Option[Ordering[Row]] = None } /** @@ -55,6 +64,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def execute(): RDD[Row] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } + + override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering } /** @@ -70,8 +81,6 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: override def execute(): RDD[Row] = { child.execute().map(_.copy()).sample(withReplacement, fraction, seed) } - - override def outputOrdering: Seq[SortOrder] = Nil } /** @@ -104,6 +113,8 @@ case class Limit(limit: Int, child: SparkPlan) override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = SinglePartition + override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering + override def executeCollect(): Array[Row] = child.executeTake(limit) override def execute(): RDD[Row] = { @@ -149,7 +160,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Pick num splits based on |limit|. 
override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) - override def outputOrdering: Seq[SortOrder] = sortOrder + override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder)) } /** @@ -176,7 +187,7 @@ case class Sort( override def output: Seq[Attribute] = child.output - override def outputOrdering: Seq[SortOrder] = sortOrder + override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder)) } /** @@ -208,7 +219,7 @@ case class ExternalSort( override def output: Seq[Attribute] = child.output - override def outputOrdering: Seq[SortOrder] = sortOrder + override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 7e7b692d401cd..a1f0805a0ab92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.joins +import java.util.NoSuchElementException + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -47,16 +49,16 @@ case class SortMergeJoin( private val orders: Seq[SortOrder] = leftKeys.zipWithIndex.map { case(expr, index) => SortOrder(BoundReference(index, expr.dataType, expr.nullable), Ascending) } - private val ordering: RowOrdering = new RowOrdering(orders, left.output) + // this is to manually construct an ordering that can be used to compare keys from both sides + private val keyOrdering: RowOrdering = new RowOrdering(orders) - private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Seq[SortOrder] = keys.map { - k => SortOrder(BindReferences.bindReference(k, side.output, allowFailures = false), Ascending) - } + private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Ordering[Row] = + newOrdering(keys.map(SortOrder(_, Ascending)), side.output) - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys, left) + override def outputOrdering: Option[Ordering[Row]] = Some(requiredOrders(leftKeys, left)) - override def requiredInPartitionOrdering: Seq[Seq[SortOrder]] = - requiredOrders(leftKeys, left) :: requiredOrders(rightKeys, right) :: Nil + override def requiredChildOrdering: Seq[Option[Ordering[Row]]] = + Some(requiredOrders(leftKeys, left)) :: Some(requiredOrders(rightKeys, right)) :: Nil @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) @@ -78,24 +80,28 @@ case class SortMergeJoin( private[this] var stop: Boolean = false private[this] var matchKey: Row = _ + // initialize iterator + initialize() + override final def hasNext: Boolean = nextMatchingPair() override final def next(): Row = { if (hasNext) { + // we are using the buffered right rows and run down left iterator val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) rightPosition += 1 if (rightPosition >= rightMatches.size) { rightPosition = 0 fetchLeft() - if (leftElement == null || ordering.compare(leftKey, matchKey) != 0) { + if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { stop = false rightMatches = null } } joinedRow } else { - // according to Scala doc, this is undefined - null + // no more result + throw new 
NoSuchElementException } } @@ -121,33 +127,36 @@ case class SortMergeJoin( fetchLeft() fetchRight() } - // initialize iterator - initialize() /** - * Searches the left/right iterator for the next rows that matches. + * Searches the right iterator for the next rows that have matches in left side, and store + * them in a buffer. * - * @return true if the search is successful, and false if the left/right iterator runs out - * of tuples. + * @return true if the search is successful, and false if the right iterator runs out of + * tuples. */ private def nextMatchingPair(): Boolean = { if (!stop && rightElement != null) { + // run both side to get the first match pair while (!stop && leftElement != null && rightElement != null) { - stop = ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull - if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) { + val comparing = keyOrdering.compare(leftKey, rightKey) + // for inner join, we need to filter those null keys + stop = comparing == 0 && !leftKey.anyNull + if (comparing > 0 || rightKey.anyNull) { fetchRight() - } else if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) { + } else if (comparing < 0 || leftKey.anyNull) { fetchLeft() } } rightMatches = new CompactBuffer[Row]() if (stop) { stop = false + // iterate the right side to buffer all rows that matches + // as the records should be ordered, exit when we meet the first that not match while (!stop && rightElement != null) { rightMatches += rightElement fetchRight() - // exit loop when run out of right matches - stop = ordering.compare(leftKey, rightKey) != 0 + stop = keyOrdering.compare(leftKey, rightKey) != 0 } if (rightMatches.size > 0) { rightPosition = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 826db143a9211..0524085edb98c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -64,7 +64,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { cacheManager.clearCache() - val AUTO_SORTMERGEJOIN: Boolean = conf.autoSortMergeJoin + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled conf.setConf("spark.sql.autoSortMergeJoin", "false") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), @@ -103,7 +103,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.autoSortMergeJoin", AUTO_SORTMERGEJOIN.toString) + conf.setConf("spark.sql.autoSortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index 32543f2f3a8b2..65d070bd3cbde 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -26,11 +26,11 @@ import org.apache.spark.sql.hive.test.TestHive class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { override def beforeAll() { super.beforeAll() - 
TestHive.setConf(SQLConf.AUTO_SORTMERGEJOIN, "true") + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true") } override def afterAll() { - TestHive.setConf(SQLConf.AUTO_SORTMERGEJOIN, "false") + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false") super.afterAll() } From 00a443073debcdc37acf88477a30e0f9f6212434 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 12 Apr 2015 15:34:36 +0800 Subject: [PATCH 18/32] fix bug --- .../scala/org/apache/spark/sql/execution/Exchange.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index d5b75f796d320..4c92aad54a62a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -174,12 +174,14 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl def addExchangeIfNecessary( partitioning: Partitioning, child: SparkPlan, - rowOrdering: Option[Ordering[Row]] = None): SparkPlan = - if (child.outputPartitioning != partitioning) { - Exchange(partitioning, child, sort = child.outputOrdering != rowOrdering) + rowOrdering: Option[Ordering[Row]] = None): SparkPlan = { + val needSort = child.outputOrdering != rowOrdering + if (child.outputPartitioning != partitioning || needSort) { + Exchange(partitioning, child, sort = needSort) } else { child } + } if (meetsRequirements && compatible) { operator From 61d7f4961b35cc8e32d6f5b3c14ef89d672ca8cf Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 12 Apr 2015 15:40:24 +0800 Subject: [PATCH 19/32] add omitted comment --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 83ec39a4b7f38..519ef5d93154c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -90,6 +90,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + // If the sort merge join option is set, we want to use sort merge join prior to hashjoin // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled => From 2875ef25129780455835ddf516d776ae53aba91c Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 12 Apr 2015 22:04:17 +0800 Subject: [PATCH 20/32] fix changed configuration --- .../main/scala/org/apache/spark/sql/execution/Exchange.scala | 1 + sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 4c92aad54a62a..7bdfbb8ec4e7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -177,6 +177,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl 
rowOrdering: Option[Ordering[Row]] = None): SparkPlan = { val needSort = child.outputOrdering != rowOrdering if (child.outputPartitioning != partitioning || needSort) { + // TODO: if only needSort, we need only sort each partition instead of an Exchange Exchange(partitioning, child, sort = needSort) } else { child diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 0524085edb98c..2429e0a8e3655 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -65,7 +65,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { cacheManager.clearCache() val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled - conf.setConf("spark.sql.autoSortMergeJoin", "false") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -96,14 +95,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.autoSortMergeJoin", "true") + conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.autoSortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } } From 8681d73bb53d7747d3122ae0f17ba5acdb4ea9a4 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 13 Apr 2015 22:45:48 -0700 Subject: [PATCH 21/32] refactor Exchange and fix copy for sorting --- .../org/apache/spark/sql/execution/Exchange.scala | 10 +++++----- .../apache/spark/sql/execution/SparkStrategies.scala | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 7bdfbb8ec4e7a..ba866357f8bf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -34,8 +34,8 @@ import org.apache.spark.util.MutablePair @DeveloperApi case class Exchange( newPartitioning: Partitioning, - child: SparkPlan, - sort: Boolean = false) + sort: Boolean, + child: SparkPlan) extends UnaryNode { override def outputPartitioning: Partitioning = newPartitioning @@ -59,7 +59,7 @@ case class Exchange( // we can avoid the defensive copies to improve performance. In the long run, we probably // want to include information in shuffle dependencies to indicate whether elements in the // source RDD should be copied. 
- val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) { + val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) { child.execute().mapPartitions { iter => val hashExpressions = newMutableProjection(expressions, child.output)() iter.map(r => (hashExpressions(r).copy(), r.copy())) @@ -178,7 +178,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl val needSort = child.outputOrdering != rowOrdering if (child.outputPartitioning != partitioning || needSort) { // TODO: if only needSort, we need only sort each partition instead of an Exchange - Exchange(partitioning, child, sort = needSort) + Exchange(partitioning, sort = needSort, child) } else { child } @@ -197,7 +197,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl addExchangeIfNecessary(SinglePartition, child) case (ClusteredDistribution(clustering), (child, rowOrdering)) => addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering) - case (OrderedDistribution(ordering), (child, _)) => + case (OrderedDistribution(ordering), (child, None)) => addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) case (UnspecifiedDistribution, (child, _)) => child case (dist, _) => sys.error(s"Don't know how to ensure $dist") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 519ef5d93154c..c6ff8c30c24e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -307,7 +307,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => - execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.Exchange( + HashPartitioning(expressions, numPartitions), sort = false, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil From 6e897dd7244e5c35e37368703d25e242093522cf Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 14 Apr 2015 01:16:35 -0700 Subject: [PATCH 22/32] hide boundReference from manually construct RowOrdering for key compare in smj --- .../org/apache/spark/sql/catalyst/expressions/rows.scala | 9 ++++++++- .../apache/spark/sql/execution/joins/SortMergeJoin.scala | 5 +---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index a8983df208318..6a32244bd03b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{StructType, NativeType} +import org.apache.spark.sql.types.{DataType, StructType, NativeType} /** @@ -232,3 +232,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { return 0 } } + +object RowOrdering { + def getOrderingFromDataTypes(dataTypes: Seq[DataType]): RowOrdering = + new 
RowOrdering(dataTypes.zipWithIndex.map { + case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index a1f0805a0ab92..048251c4c1f91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -46,11 +46,8 @@ case class SortMergeJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - private val orders: Seq[SortOrder] = leftKeys.zipWithIndex.map { - case(expr, index) => SortOrder(BoundReference(index, expr.dataType, expr.nullable), Ascending) - } // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = new RowOrdering(orders) + private val keyOrdering: RowOrdering = RowOrdering.getOrderingFromDataTypes(leftKeys.map(_.dataType)) private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Ordering[Row] = newOrdering(keys.map(SortOrder(_, Ascending)), side.output) From c8e82a36782377ff322281be2528ffe74f2de655 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 14 Apr 2015 01:35:01 -0700 Subject: [PATCH 23/32] fix style --- .../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 048251c4c1f91..8262a3c14b0b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -47,7 +47,8 @@ case class SortMergeJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = RowOrdering.getOrderingFromDataTypes(leftKeys.map(_.dataType)) + private val keyOrdering: RowOrdering = + RowOrdering.getOrderingFromDataTypes(leftKeys.map(_.dataType)) private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Ordering[Row] = newOrdering(keys.map(SortOrder(_, Ascending)), side.output) From b1982789be082d2f75e858feae94c80cce93b552 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 14 Apr 2015 18:20:45 -0700 Subject: [PATCH 24/32] inherit ordering in project --- .../scala/org/apache/spark/sql/execution/Exchange.scala | 5 ++++- .../org/apache/spark/sql/execution/basicOperators.scala | 6 ++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index ba866357f8bf0..d67e5c40a4727 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -29,7 +29,10 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair /** - * :: DeveloperApi :: + * Shuffle data according to a new partition rule, and sort inside each partition if necessary. 
+ * @param newPartitioning The new partitioning way that required by parent + * @param sort Whether we will sort inside each partition + * @param child Child operator */ @DeveloperApi case class Exchange( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index c723c502ab90b..cbfcca1ea1546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -45,11 +45,9 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends /** * outputOrdering of Project is not always same with child's outputOrdering if the certain * key is pruned, however, if the key is pruned then we must not require child using this - * ordering from upper layer, only if the ordering would not be changed by a negative, there - * would be a way to keep the ordering. - * TODO: we may utilize this feature later to avoid some unnecessary sorting. + * ordering from upper layer, so it is fine to keep it to avoid some unnecessary sorting. */ - override def outputOrdering: Option[Ordering[Row]] = None + override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering } /** From 7ddd65627f4bbb40e2f7c95d55162be619d26166 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 14 Apr 2015 19:08:16 -0700 Subject: [PATCH 25/32] Cleanup addition of ordering requirements --- .../spark/sql/catalyst/expressions/rows.scala | 2 +- .../plans/physical/partitioning.scala | 13 ++ .../org/apache/spark/sql/SQLContext.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 147 ++++++++++++------ .../spark/sql/execution/SparkPlan.scala | 4 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../spark/sql/execution/basicOperators.scala | 17 +- .../sql/execution/joins/SortMergeJoin.scala | 14 +- 8 files changed, 132 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 6a32244bd03b3..d45230a02e501 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -234,7 +234,7 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { } object RowOrdering { - def getOrderingFromDataTypes(dataTypes: Seq[DataType]): RowOrdering = + def forSchema(dataTypes: Seq[DataType]): RowOrdering = new RowOrdering(dataTypes.zipWithIndex.map { case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 288c11f69fe22..fb4217a44807b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -94,6 +94,9 @@ sealed trait Partitioning { * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean + + /** Returns the expressions that are used to key the partitioning. 
*/ + def keyExpressions: Seq[Expression] } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case UnknownPartitioning(_) => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object SinglePartition extends Partitioning { @@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object BroadcastPartitioning extends Partitioning { @@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } /** @@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + override def keyExpressions: Seq[Expression] = expressions + override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } + override def keyExpressions: Seq[Expression] = ordering.map(_.child) + override def eval(input: Row): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 39dd14e796f06..f35b1cfb9ac43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1080,7 +1080,7 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: Nil + Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil } protected[sql] def openSession(): SQLSession = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index d67e5c40a4727..1901fde5f075e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -28,21 +28,30 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair +object Exchange { + /** Returns true when the ordering expressions are a subset of the key. */ + def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = { + desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) + } +} + /** - * Shuffle data according to a new partition rule, and sort inside each partition if necessary. - * @param newPartitioning The new partitioning way that required by parent - * @param sort Whether we will sort inside each partition - * @param child Child operator + * :: DeveloperApi :: + * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each + * resulting partition based on expressions from the partition key. It is invalid to construct an + * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key. 
*/ @DeveloperApi case class Exchange( newPartitioning: Partitioning, - sort: Boolean, + newOrdering: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def outputPartitioning: Partitioning = newPartitioning + override def outputOrdering = newOrdering + override def output: Seq[Attribute] = child.output /** We must copy rows when sort based shuffle is on */ @@ -51,6 +60,20 @@ case class Exchange( private val bypassMergeThreshold = child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + private val keyOrdering = { + if (newOrdering.nonEmpty) { + val key = newPartitioning.keyExpressions + val boundOrdering = newOrdering.map { o => + val ordinal = key.indexOf(o.child) + if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning") + o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable)) + } + new RowOrdering(boundOrdering) + } else { + null // Ordering will not be used + } + } + override def execute(): RDD[Row] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => @@ -62,7 +85,9 @@ case class Exchange( // we can avoid the defensive copies to improve performance. In the long run, we probably // want to include information in shuffle dependencies to indicate whether elements in the // source RDD should be copied. - val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) { + val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold + + val rdd = if (willMergeSort || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => val hashExpressions = newMutableProjection(expressions, child.output)() iter.map(r => (hashExpressions(r).copy(), r.copy())) @@ -75,16 +100,12 @@ case class Exchange( } } val part = new HashPartitioner(numPartitions) - val shuffled = sort match { - case false => new ShuffledRDD[Row, Row, Row](rdd, part) - case true => - val sortingExpressions = expressions.zipWithIndex.map { - case (exp, index) => - new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending) - } - val ordering = new RowOrdering(sortingExpressions, child.output) - new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering) - } + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Row, Row](rdd, part) + } shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) @@ -102,7 +123,12 @@ case class Exchange( implicit val ordering = new RowOrdering(sortingExpressions, child.output) val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Null, Null](rdd, part) + } shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._1) @@ -135,27 +161,35 @@ case class Exchange( * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[Exchange]] Operators where required. + * each operator by inserting [[Exchange]] Operators where required. Also ensure that the + * required input partition ordering requirements are met. 
*/ -private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { +private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. def numPartitions: Int = sqlContext.conf.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // Check if every child's outputPartitioning satisfies the corresponding + // True iff every child's outputPartitioning satisfies the corresponding // required data distribution. def meetsRequirements: Boolean = - !operator.requiredChildDistribution.zip(operator.children).map { + operator.requiredChildDistribution.zip(operator.children).forall { case (required, child) => val valid = child.outputPartitioning.satisfies(required) logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid - }.exists(!_) + } - // Check if outputPartitionings of children are compatible with each other. + // True iff any of the children are incorrectly sorted. + def needsAnySort: Boolean = + operator.requiredChildOrdering.zip(operator.children).exists { + case (required, child) => required.nonEmpty && required != child + } + + + // True iff outputPartitionings of children are compatible with each other. // It is possible that every child satisfies its required data distribution // but two children have incompatible outputPartitionings. For example, // A dataset is range partitioned by "a.asc" (RangePartitioning) and another @@ -172,40 +206,61 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl case Seq(a,b) => a compatibleWith b }.exists(!_) - // Check if the partitioning we want to ensure is the same as the child's output - // partitioning. If so, we do not need to add the Exchange operator. - def addExchangeIfNecessary( + // Adds Exchange or Sort operators as required + def addOperatorsIfNecessary( partitioning: Partitioning, - child: SparkPlan, - rowOrdering: Option[Ordering[Row]] = None): SparkPlan = { - val needSort = child.outputOrdering != rowOrdering - if (child.outputPartitioning != partitioning || needSort) { - // TODO: if only needSort, we need only sort each partition instead of an Exchange - Exchange(partitioning, sort = needSort, child) + rowOrdering: Seq[SortOrder], + child: SparkPlan): SparkPlan = { + val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering + val needsShuffle = child.outputPartitioning != partitioning + val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering) + + if (needSort && needsShuffle && canSortWithShuffle) { + Exchange(partitioning, rowOrdering, child) } else { - child + val withShuffle = if (needsShuffle) { + Exchange(partitioning, Nil, child) + } else { + child + } + + val withSort = if (needSort) { + Sort(rowOrdering, global = false, withShuffle) + } else { + withShuffle + } + + withSort } } - if (meetsRequirements && compatible) { + if (meetsRequirements && compatible && !needsAnySort) { operator } else { // At least one child does not satisfies its required data distribution or // at least one child's outputPartitioning is not compatible with another child's // outputPartitioning. In this case, we need to add Exchange operators. 
- val repartitionedChildren = operator.requiredChildDistribution.zip( - operator.children.zip(operator.requiredChildOrdering) - ).map { - case (AllTuples, (child, _)) => - addExchangeIfNecessary(SinglePartition, child) - case (ClusteredDistribution(clustering), (child, rowOrdering)) => - addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering) - case (OrderedDistribution(ordering), (child, None)) => - addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) - case (UnspecifiedDistribution, (child, _)) => child - case (dist, _) => sys.error(s"Don't know how to ensure $dist") + val requirements = + (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) + + val fixedChildren = requirements.zipped.map { + case (AllTuples, rowOrdering, child) => + addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (OrderedDistribution(ordering), rowOrdering, child) => + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), Nil, child) + + case (UnspecifiedDistribution, Seq(), child) => + child + case (UnspecifiedDistribution, rowOrdering, child) => + Sort(rowOrdering, global = false, child) + + case (dist, ordering, _) => + sys.error(s"Don't know how to ensure $dist with ordering $ordering") } - operator.withNewChildren(repartitionedChildren) + + operator.withNewChildren(fixedChildren) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 748c478a5a839..9c5ecb984245b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -73,10 +73,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ Seq.fill(children.size)(UnspecifiedDistribution) /** Specifies how data is ordered in each partition. */ - def outputOrdering: Option[Ordering[Row]] = None + def outputOrdering: Seq[SortOrder] = Nil /** Specifies sort order for each partition requirements on the input data for this operator. */ - def requiredChildOrdering: Seq[Option[Ordering[Row]]] = Seq.fill(children.size)(None) + def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) /** * Runs this query returning the result as an RDD. 
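To make the new SparkPlan contract above concrete: an operator advertises what it needs through requiredChildDistribution and requiredChildOrdering, and EnsureRequirements inserts Exchange and/or sort operators to satisfy those requirements. The following minimal sketch is illustrative only and is not part of the patch; MyKeyedMergeOperator and its body are assumptions (it also assumes it is compiled inside the org.apache.spark.sql.execution package, since BinaryNode is private[sql]), and only the overridden members mirror the API shown in the diff above.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}

// Hypothetical operator (not part of this patch) showing how a physical plan node
// declares both a distribution and a per-partition ordering requirement.
case class MyKeyedMergeOperator(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    left: SparkPlan,
    right: SparkPlan) extends BinaryNode {

  override def output: Seq[Attribute] = left.output ++ right.output

  // Rows with equal keys must land in the same partition on both sides.
  override def requiredChildDistribution: Seq[Distribution] =
    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

  // Within each partition, both sides must be sorted ascending on the join keys;
  // EnsureRequirements is expected to add Exchange and/or (External)Sort as needed.
  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
    leftKeys.map(SortOrder(_, Ascending)) :: rightKeys.map(SortOrder(_, Ascending)) :: Nil

  override def execute(): RDD[Row] =
    sys.error("illustrative sketch only; a real operator would merge the sorted inputs here")
}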
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c6ff8c30c24e7..e680b169c5975 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -308,7 +308,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange( - HashPartitioning(expressions, numPartitions), sort = false, planLater(child)) :: Nil + HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index cbfcca1ea1546..a13c7f7c8611b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -42,12 +42,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends iter.map(resuableProjection) } - /** - * outputOrdering of Project is not always same with child's outputOrdering if the certain - * key is pruned, however, if the key is pruned then we must not require child using this - * ordering from upper layer, so it is fine to keep it to avoid some unnecessary sorting. - */ - override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -63,7 +58,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { iter.filter(conditionEvaluator) } - override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -111,7 +106,7 @@ case class Limit(limit: Int, child: SparkPlan) override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = SinglePartition - override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def executeCollect(): Array[Row] = child.executeTake(limit) @@ -158,7 +153,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Pick num splits based on |limit|. 
override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) - override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder)) + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -185,7 +180,7 @@ case class Sort( override def output: Seq[Attribute] = child.output - override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder)) + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -217,7 +212,7 @@ case class ExternalSort( override def output: Seq[Attribute] = child.output - override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder)) + override def outputOrdering: Seq[SortOrder] = sortOrder } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 8262a3c14b0b3..e2393ba3bb8f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -48,19 +48,19 @@ case class SortMergeJoin( // this is to manually construct an ordering that can be used to compare keys from both sides private val keyOrdering: RowOrdering = - RowOrdering.getOrderingFromDataTypes(leftKeys.map(_.dataType)) + RowOrdering.forSchema(leftKeys.map(_.dataType)) - private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Ordering[Row] = - newOrdering(keys.map(SortOrder(_, Ascending)), side.output) + override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) - override def outputOrdering: Option[Ordering[Row]] = Some(requiredOrders(leftKeys, left)) - - override def requiredChildOrdering: Seq[Option[Ordering[Row]]] = - Some(requiredOrders(leftKeys, left)) :: Some(requiredOrders(rightKeys, right)) :: Nil + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + keys.map(SortOrder(_, Ascending)) + override def execute(): RDD[Row] = { val leftResults = left.execute().map(_.copy()) val rightResults = right.execute().map(_.copy()) From 54928840e96e6d503defcbf989d74d587fcd38ff Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 14 Apr 2015 19:29:36 -0700 Subject: [PATCH 26/32] copy when ordering --- .../main/scala/org/apache/spark/sql/execution/Exchange.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 1901fde5f075e..9842848c5cc7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -110,7 +110,7 @@ case class Exchange( shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => - val rdd = if (sortBasedShuffleOn) { + val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} } else { child.execute().mapPartitions { iter => From 952168a3dffca5ae8780f872ca50a73e88bf4c86 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 14 Apr 2015 19:34:09 -0700 Subject: [PATCH 27/32] 
add type --- .../main/scala/org/apache/spark/sql/execution/Exchange.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 9842848c5cc7e..a56f6a100a515 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -50,7 +50,7 @@ case class Exchange( override def outputPartitioning: Partitioning = newPartitioning - override def outputOrdering = newOrdering + override def outputOrdering: Seq[SortOrder] = newOrdering override def output: Seq[Attribute] = child.output From ec8061b7f36b87c883af111438ac9ff0304050d7 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 14 Apr 2015 22:20:23 -0700 Subject: [PATCH 28/32] minor change --- .../main/scala/org/apache/spark/sql/execution/Exchange.scala | 1 - .../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index a56f6a100a515..85587befb8758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -188,7 +188,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ case (required, child) => required.nonEmpty && required != child } - // True iff outputPartitionings of children are compatible with each other. // It is possible that every child satisfies its required data distribution // but two children have incompatible outputPartitionings. 
For example, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index e2393ba3bb8f6..b5123668ba11e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -47,8 +47,7 @@ case class SortMergeJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = - RowOrdering.forSchema(leftKeys.map(_.dataType)) + private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) From f515cd29bbe7765eefbb185ad26b5dbb9e2d7380 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 14 Apr 2015 22:48:00 -0700 Subject: [PATCH 29/32] yin's comment: outputOrdering, join suite refine --- .../org/apache/spark/sql/execution/Exchange.scala | 4 ++-- .../test/scala/org/apache/spark/sql/JoinSuite.scala | 12 +++++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 85587befb8758..dad8c3de4e1c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -185,7 +185,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // True iff any of the children are incorrectly sorted. def needsAnySort: Boolean = operator.requiredChildOrdering.zip(operator.children).exists { - case (required, child) => required.nonEmpty && required != child + case (required, child) => required.nonEmpty && required != child.outputOrdering } // True iff outputPartitionings of children are compatible with each other. 
@@ -233,7 +233,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } } - if (meetsRequirements && compatible && !needsAnySort) { + if (meetsRequirements && compatible && !needsAnySort) { operator } else { // At least one child does not satisfies its required data distribution or diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2429e0a8e3655..e6ead984a435f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -51,7 +51,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j - case j: ShuffledHashJoin => j case j: SortMergeJoin => j } @@ -110,11 +109,22 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { cacheManager.clearCache() sql("CACHE TABLE testData") + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } sql("UNCACHE TABLE testData") } From f91a2aecf795b2a2b2b834bf69b21875ef6f0b6f Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 14 Apr 2015 23:09:00 -0700 Subject: [PATCH 30/32] yin's comment: use external sort if option is enabled, add comments --- .../apache/spark/sql/execution/Exchange.scala | 17 ++++++++++++++--- .../spark/sql/execution/basicOperators.scala | 2 -- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index dad8c3de4e1c2..2599765f24aa5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -29,7 +29,10 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair object Exchange { - /** Returns true when the ordering expressions are a subset of the key. */ + /** + * Returns true when the ordering expressions are a subset of the key. + * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]]. 
+ */ def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = { desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) } @@ -224,7 +227,11 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } val withSort = if (needSort) { - Sort(rowOrdering, global = false, withShuffle) + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, withShuffle) + } else { + Sort(rowOrdering, global = false, withShuffle) + } } else { withShuffle } @@ -253,7 +260,11 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ case (UnspecifiedDistribution, Seq(), child) => child case (UnspecifiedDistribution, rowOrdering, child) => - Sort(rowOrdering, global = false, child) + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, child) + } else { + Sort(rowOrdering, global = false, child) + } case (dist, ordering, _) => sys.error(s"Don't know how to ensure $dist with ordering $ordering") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index a13c7f7c8611b..8cc1b527bd682 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -106,8 +106,6 @@ case class Limit(limit: Int, child: SparkPlan) override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = SinglePartition - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def executeCollect(): Array[Row] = child.executeTake(limit) override def execute(): RDD[Row] = { From 5049d882fbfcf9b7c63e95ec20d3a15310068752 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 14 Apr 2015 23:33:58 -0700 Subject: [PATCH 31/32] propagate rowOrdering for RangePartitioning --- .../main/scala/org/apache/spark/sql/execution/Exchange.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 2599765f24aa5..518fc9e57c708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -255,7 +255,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ case (ClusteredDistribution(clustering), rowOrdering, child) => addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) case (OrderedDistribution(ordering), rowOrdering, child) => - addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), Nil, child) + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) case (UnspecifiedDistribution, Seq(), child) => child From 2493b9f9548c4a63a3d31dc600588ac65968b611 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 14 Apr 2015 23:41:25 -0700 Subject: [PATCH 32/32] fix style --- .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e6ead984a435f..037d392c1f929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -113,14 +113,17 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
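The save/set/restore pattern used in the test above generalizes to any planner flag. A minimal sketch follows; it assumes `conf` is the SQLConf of the active SQLContext (as in JoinSuite) and takes the key name "spark.sql.planner.sortMergeJoin" from the test above, while the default value and the queries to run are assumptions.

// Flip the sort-merge-join planner flag around a block of assertions, then restore it.
val previousValue = conf.getConf("spark.sql.planner.sortMergeJoin", "true") // default is an assumption
try {
  conf.setConf("spark.sql.planner.sortMergeJoin", "false")
  // ... run queries here that should now be planned without SortMergeJoin ...
} finally {
  // Always restore the original setting so later tests are unaffected.
  conf.setConf("spark.sql.planner.sortMergeJoin", previousValue)
}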