From 7ff44e18f575f907cc6cdf7d667bc985b8834d1c Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 10 May 2016 16:06:46 +0800 Subject: [PATCH 1/7] init commit --- .../spark/sql/execution/QueryExecution.scala | 19 +- .../execution/adaptive/FragmentInput.scala | 46 +++ .../execution/adaptive/QueryFragment.scala | 345 ++++++++++++++++++ .../adaptive/QueryFragmentTransformer.scala | 91 +++++ .../spark/sql/execution/adaptive/utils.scala | 153 ++++++++ .../apache/spark/sql/internal/SQLConf.scala | 7 + .../spark/sql/internal/SessionState.scala | 9 + 7 files changed, 669 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 3e772286e0e55..74ed3338b50b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.adaptive.QueryFragmentTransformer import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker @@ -79,7 +80,12 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) + lazy val executedPlan: SparkPlan = + if (sparkSession.sessionState.conf.adaptiveExecution2Enabled) { + prepareForAdaptiveExecution(sparkPlan) + } else { + prepareForExecution(sparkPlan) + } /** Internal version of the RDD. 
Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() @@ -100,6 +106,17 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { CollapseCodegenStages(sparkSession.sessionState.conf), ReuseExchange(sparkSession.sessionState.conf)) + protected def prepareForAdaptiveExecution(plan: SparkPlan): SparkPlan = { + preparationsForAdaptive.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + } + + protected def preparationsForAdaptive: Seq[Rule[SparkPlan]] = Seq( + python.ExtractPythonUDFs, + PlanSubqueries(sparkSession), + EnsureRequirements(sparkSession.sessionState.conf), + ReuseExchange(sparkSession.sessionState.conf), + QueryFragmentTransformer(sparkSession.sessionState.conf)) + protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala new file mode 100644 index 0000000000000..14c852acc2d26 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala @@ -0,0 +1,46 @@ +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.{LeafExecNode, ShuffledRowRDD, SparkPlan} + +/** + * FragmentInput is the leaf node of parent fragment that connect with an child fragment. + */ +case class FragmentInput(@transient childFragment: QueryFragment) extends LeafExecNode { + + private[this] var optimized: Boolean = false + + private[this] var inputPlan: SparkPlan = null + + private[this] var shuffledRdd: ShuffledRowRDD = null + + override def output: Seq[Attribute] = inputPlan.output + + private[sql] def setOptimized() = { + this.optimized = true + } + + private[sql] def isOptimized(): Boolean = this.optimized + + private[sql] def setShuffleRdd(shuffledRdd: ShuffledRowRDD) = { + this.shuffledRdd = shuffledRdd + } + + private[sql] def setInputPlan(inputPlan: SparkPlan) = { + this.inputPlan = inputPlan + } + + override protected def doExecute(): RDD[InternalRow] = { + if (shuffledRdd != null) { + shuffledRdd + } else { + inputPlan.execute() + } + } + + override def simpleString: String = "FragmentInput" + + override def innerChildren: Seq[SparkPlan] = inputPlan :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala new file mode 100644 index 0000000000000..86580be735eb9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala @@ -0,0 +1,345 @@ +package org.apache.spark.sql.execution.adaptive + +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{LinkedBlockingDeque, BlockingQueue} +import java.util.{HashMap => JHashMap, Map => JMap} + +import scala.concurrent.ExecutionContext.Implicits.global +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{MapOutputStatistics, SimpleFutureAction, ShuffleDependency} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{SortExec, SparkPlan} +import 
org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.types.LongType + +/** + * A physical plan tree is divided into a DAG tree of QueryFragment. + * An QueryFragment is a basic execution unit that could be optimized + * According to statistics of children fragments. + */ +trait QueryFragment extends SparkPlan { + + def children: Seq[QueryFragment] + def isRoot: Boolean + def id: Long + + private[this] var exchange: ShuffleExchange = null + + protected[this] var rootPlan: SparkPlan = null + + private[this] var fragmentInput: FragmentInput = null + + private[this] var parentFragment: QueryFragment = null + + var nextChildIndex: Int = 0 + + private[this] val fragmentsIndex: JMap[QueryFragment, Integer] = + new JHashMap[QueryFragment, Integer](children.size) + + private[this] val shuffleDependencies = + new Array[ShuffleDependency[Int, InternalRow, InternalRow]](children.size) + + private[this] val mapOutputStatistics = new Array[MapOutputStatistics](children.size) + + private[this] val advisoryTargetPostShuffleInputSize: Long = + sqlContext.conf.targetPostShuffleInputSize + + override def output: Seq[Attribute] = executedPlan.output + + private[this] def executedPlan: SparkPlan = if (isRoot) { + rootPlan + } else { + exchange + } + + private[sql] def setParentFragment(fragment: QueryFragment) = { + this.parentFragment = fragment + } + + private[sql] def getParentFragment() = parentFragment + + private[sql] def setFragmentInput(fragmentInput: FragmentInput) = { + this.fragmentInput = fragmentInput + } + + private[sql] def getFragmentInput() = fragmentInput + + private[sql] def setExchange(exchange: ShuffleExchange) = { + this.exchange = exchange + } + + private[sql] def getExchange(): ShuffleExchange = exchange + + private[sql] def setRootPlan(root: SparkPlan) = { + this.rootPlan = root + } + + protected[sql] def isAvailable: Boolean = nextChildIndex >= children.size + + + protected def doExecute(): RDD[InternalRow] = null + + protected[sql] def adaptiveExecute(): (ShuffleDependency[Int, InternalRow, InternalRow], + SimpleFutureAction[MapOutputStatistics]) = synchronized { + val executedPlan = sqlContext.sparkSession.sessionState.codegenForExecution(exchange) + .asInstanceOf[ShuffleExchange] + logInfo(s"== Submit Query Fragment ${id} Physical plan ==") + logInfo(stringOrError(executedPlan.toString)) + val shuffleDependency = executedPlan.prepareShuffleDependency() + if (shuffleDependency.rdd.partitions.length != 0) { + val futureAction: SimpleFutureAction[MapOutputStatistics] = + sqlContext.sparkContext.submitMapStage[Int, InternalRow, InternalRow](shuffleDependency) + (shuffleDependency, futureAction) + } else { + (shuffleDependency, null) + } + } + + protected[sql] def stageFailed(exception: Throwable): Unit = synchronized { + this.parentFragment.stageFailed(exception) + } + + protected def stringOrError[A](f: => A): String = + try f.toString catch { case e: Throwable => e.toString } + + protected[sql] def setChildCompleted( + child: QueryFragment, + shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], + statistics: MapOutputStatistics): Unit = synchronized { + fragmentsIndex.put(child, this.nextChildIndex) + shuffleDependencies(this.nextChildIndex) = shuffleDependency + mapOutputStatistics(this.nextChildIndex) = statistics + this.nextChildIndex += 1 + } + + protected[sql] def optimizeOperator(): Unit = 
synchronized { + val executedPlan = if (isRoot) { + rootPlan + } else { + exchange + } + // Optimize plan + val optimizedPlan = executedPlan.transformDown { + case operator @ SortMergeJoinExec(leftKeys, rightKeys, _, _, + left@SortExec(_, _, _, _), right@SortExec(_, _, _, _)) => { + logInfo("Begin optimize join, operator =\n" + operator.toString) + val newOperator = optimizeJoin(operator, left, right) + logInfo("After optimize join, operator =\n" + newOperator.toString) + newOperator + } + + case agg @ TungstenAggregate(_, _, _, _, _, _, input @ FragmentInput(_)) + if (!input.isOptimized())=> { + optimizeAggregate(agg, input) + } + + case operator: SparkPlan => operator + } + + if (isRoot) { + rootPlan = optimizedPlan + } else { + exchange = optimizedPlan.asInstanceOf[ShuffleExchange] + } + } + + private[this] def minNumPostShufflePartitions: Option[Int] = { + val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions + if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None + } + + private[this] def optimizeAggregate(agg: SparkPlan, input: FragmentInput): SparkPlan = { + val childFragments = Seq(input.childFragment) + val aggStatistics = new ArrayBuffer[MapOutputStatistics]() + var i = 0 + while(i < childFragments.length) { + val statistics = mapOutputStatistics(fragmentsIndex.get(childFragments(i))) + if (statistics != null) { + aggStatistics += statistics + } + i += 1 + } + val partitionStartIndices = + if (aggStatistics.length == 0) { + None + } else { + Utils.estimatePartitionStartIndices(aggStatistics.toArray, minNumPostShufflePartitions, + advisoryTargetPostShuffleInputSize) + } + val shuffledRowRdd= childFragments(0).getExchange().preparePostShuffleRDD( + shuffleDependencies(fragmentsIndex.get(childFragments(0))), partitionStartIndices) + childFragments(0).getFragmentInput().setShuffleRdd(shuffledRowRdd) + childFragments(0).getFragmentInput().setOptimized() + agg + } + + private[this] def optimizeJoin(joinPlan: SortMergeJoinExec, left: SortExec, right: SortExec) + : SparkPlan = { + // TODO Optimize skew join + val childFragments = Utils.findChildFragment(joinPlan) + assert(childFragments.length == 2) + val joinStatistics = new ArrayBuffer[MapOutputStatistics]() + val childSizeInBytes = new Array[Long](childFragments.length) + var i = 0 + while(i < childFragments.length) { + val statistics = mapOutputStatistics(fragmentsIndex.get(childFragments(i))) + if (statistics != null) { + joinStatistics += statistics + childSizeInBytes(i) = statistics.bytesByPartitionId.sum + } else { + childSizeInBytes(i) = 0 + } + i += 1 + } + val partitionStartIndices = + if (joinStatistics.length == 0) { + None + } else { + Utils.estimatePartitionStartIndices(joinStatistics.toArray, minNumPostShufflePartitions, + advisoryTargetPostShuffleInputSize) + } + + val leftFragment = childFragments(0) + val rightFragment = childFragments(1) + val leftShuffledRowRdd= leftFragment.getExchange().preparePostShuffleRDD( + shuffleDependencies(fragmentsIndex.get(leftFragment)), partitionStartIndices) + val rightShuffledRowRdd = rightFragment.getExchange().preparePostShuffleRDD( + shuffleDependencies(fragmentsIndex.get(rightFragment)), partitionStartIndices) + + leftFragment.getFragmentInput().setShuffleRdd(leftShuffledRowRdd) + leftFragment.getFragmentInput().setOptimized() + rightFragment.getFragmentInput().setShuffleRdd(rightShuffledRowRdd) + rightFragment.getFragmentInput().setOptimized() + + var newOperator: SparkPlan = joinPlan + + if 
(sqlContext.conf.autoBroadcastJoinThreshold > 0) { + val leftSizeInBytes = childSizeInBytes(0) + val rightSizeInBytes = childSizeInBytes(1) + if (leftSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { + val keys = Utils.rewriteKeyExpr(joinPlan.leftKeys).map( + BindReferences.bindReference(_, left.child.output)) + + newOperator = BroadcastHashJoinExec( + joinPlan.leftKeys, joinPlan.rightKeys, joinPlan.joinType, BuildLeft, joinPlan.condition, + BroadcastExchangeExec(HashedRelationBroadcastMode(keys), + left.child), + right.child) + } else if (rightSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { + val keys = Utils.rewriteKeyExpr(joinPlan.rightKeys).map( + BindReferences.bindReference(_, right.child.output)) + newOperator = BroadcastHashJoinExec( + joinPlan.leftKeys, joinPlan.rightKeys, joinPlan.joinType, BuildRight, joinPlan.condition, + left.child, + BroadcastExchangeExec(HashedRelationBroadcastMode(keys), + right.child)) + } + } + newOperator + } + + /** Returns a string representation of the nodes in this tree */ + override def treeString: String = + executedPlan.generateTreeString(0, Nil, new StringBuilder).toString + + override def simpleString: String = "QueryFragment" +} + +case class RootQueryFragment ( + children: Seq[QueryFragment], + id: Long, + isRoot: Boolean = false) extends QueryFragment { + + private[this] var isThrowException = false + + private[this] var exception: Throwable = null + + private[this] val stopped = new AtomicBoolean(false) + + protected[sql] override def stageFailed(exception: Throwable): Unit = { + isThrowException = true + this.exception = exception + stopped.set(true) + } + + private val eventQueue: BlockingQueue[QueryFragment] = new LinkedBlockingDeque[QueryFragment]() + + protected def executeFragment(child: QueryFragment) = { + val (shuffleDependency, futureAction) = child.adaptiveExecute() + val parent = child.getParentFragment() + if (futureAction != null) { + futureAction.onComplete { + case scala.util.Success(statistics) => + logInfo(s"Query Fragment ${id} finished") + parent.setChildCompleted(child, shuffleDependency, statistics) + if (parent.isAvailable) { + eventQueue.add(parent) + } + case scala.util.Failure(exception) => + logInfo(s"Query Fragment ${id} failed, exception is ${exception}") + parent.stageFailed(exception) + } + } else { + parent.setChildCompleted(child, shuffleDependency, null) + if (parent.isAvailable) { + eventQueue.add(parent) + } + } + } + + + protected override def doExecute(): RDD[InternalRow] = { + assert(isRoot == true) + isThrowException = false + stopped.set(false) + val children = Utils.findLeafFragment(this) + if (!children.isEmpty) { + children.foreach { child => executeFragment(child) } + } else { + stopped.set(true) + } + + val executeThread = new Thread("Fragment execute") { + setDaemon(true) + + override def run(): Unit = { + while (!stopped.get) { + val fragment = eventQueue.take() + fragment.optimizeOperator() + if (fragment.isInstanceOf[RootQueryFragment]) { + stopped.set(true) + } else { + executeFragment(fragment) + } + + } + } + } + executeThread.start() + executeThread.join() + if (isThrowException) { + assert(this.exception != null) + throw exception + } else { + rootPlan.execute() + } + } + + /** Returns a string representation of the nodes in this tree */ + override def treeString: String = + rootPlan.generateTreeString(0, Nil, new StringBuilder).toString + + override def simpleString: String = "QueryFragment" + +} + +case class UnaryQueryFragment ( + children: Seq[QueryFragment], 
+ id: Long, + isRoot: Boolean = false) extends QueryFragment {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala new file mode 100644 index 0000000000000..82f006c6ba5aa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala @@ -0,0 +1,91 @@ +package org.apache.spark.sql.execution.adaptive + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.ExecutedCommandExec +import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.internal.SQLConf + +/** + * Transform a physical plan tree into an query fragment tree + */ +case class QueryFragmentTransformer(conf: SQLConf, maxIterations: Int = 100) + extends Rule[SparkPlan] { + + private val nextFragmentId = new AtomicLong(0) + + def apply(plan: SparkPlan): SparkPlan = { + val newPlan = plan.transformUp { + case operator: SparkPlan => withQueryFragment(operator) + } + if (newPlan.isInstanceOf[ExecutedCommandExec]) { + newPlan + } else { + val childFragments = Utils.findChildFragment(newPlan) + val newFragment = new RootQueryFragment(childFragments, + nextFragmentId.getAndIncrement(), true) + childFragments.foreach(child => child.setParentFragment(newFragment)) + newFragment.setRootPlan(newPlan) + newFragment + } + } + + private[this] def withQueryFragment(operator: SparkPlan): SparkPlan = { + val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution + val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering + val children: Seq[SparkPlan] = operator.children + assert(requiredChildDistributions.length == children.length) + assert(requiredChildOrderings.length == children.length) + + val supportsAdaptiveExecution = + if (children.exists(_.isInstanceOf[ShuffleExchange])) { + // Right now, Adaptive execution only support HashPartitionings. + children.forall { + case e @ ShuffleExchange(hash: HashPartitioning, _, _) => true + case child => + child.outputPartitioning match { + case hash: HashPartitioning => true + case collection: PartitioningCollection => + collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) + case _ => false + } + } + } else { + // In this case, although we do not have Exchange operators, we may still need to + // shuffle data when we have more than one children because data generated by + // these children may not be partitioned in the same way. + // Please see the comment in withCoordinator for more details. + val supportsDistribution = + requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + children.length > 1 && supportsDistribution + } + + val withFragments = + if (supportsAdaptiveExecution) { + children.zip(requiredChildDistributions).map { + case (e: ShuffleExchange, _) => + // This child is an Exchange, we need to add the fragment. 
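+            // The lines below cut the plan at this ShuffleExchange: the exchange (plus any
+            // fragments already created beneath it) becomes a new UnaryQueryFragment, and the
+            // parent operator is rewired to a FragmentInput leaf. Until the fragment's shuffle
+            // statistics are available the FragmentInput simply delegates to the exchange;
+            // afterwards it serves the optimized post-shuffle RDD.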
+ val childFragments = Utils.findChildFragment(e) + val newFragment = new UnaryQueryFragment(childFragments, + nextFragmentId.getAndIncrement(), false) + childFragments.foreach(child => child.setParentFragment(newFragment)) + val fragmentInput = FragmentInput(newFragment) + fragmentInput.setInputPlan(e) + newFragment.setExchange(e) + newFragment.setFragmentInput(fragmentInput) + fragmentInput + + case (child, distribution) => + child + } + } else { + children + } + + operator.withNewChildren(withFragments) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala new file mode 100644 index 0000000000000..605f5cb6defd8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala @@ -0,0 +1,153 @@ +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{IntegerType, LongType, IntegralType} +import org.apache.spark.MapOutputStatistics +import org.apache.spark.sql.execution.SparkPlan + +import scala.collection.mutable.{Queue, ArrayBuffer} + +/** + * Utility functions used by the query fragment. + */ +private[sql] object Utils extends Logging { + + private[sql] def findChildFragment(root: SparkPlan): Seq[QueryFragment] = { + val result = new ArrayBuffer[QueryFragment] + val queue = new Queue[SparkPlan] + queue.enqueue(root) + while (queue.nonEmpty) { + val current = queue.dequeue() + if (current.isInstanceOf[FragmentInput]) { + val fragmentInput = current.asInstanceOf[FragmentInput] + result += fragmentInput.childFragment + } else { + current.children.foreach(c => queue.enqueue(c)) + } + } + result + } + + private[sql] def findLeafFragment(root: QueryFragment): Seq[QueryFragment] = { + val result = new ArrayBuffer[QueryFragment] + val queue = new Queue[QueryFragment] + queue.enqueue(root) + while (queue.nonEmpty) { + val current = queue.dequeue() + if (current.children.isEmpty) { + result += current + } else { + current.children.foreach(c => queue.enqueue(c)) + } + } + result + } + + /** + * Estimates partition start indices for post-shuffle partitions based on + * mapOutputStatistics provided by all pre-shuffle stages. + */ + private[sql] def estimatePartitionStartIndices( + mapOutputStatistics: Array[MapOutputStatistics], + minNumPostShufflePartitions: Option[Int], + advisoryTargetPostShuffleInputSize: Long): Option[Array[Int]] = { + + // If minNumPostShufflePartitions is defined, it is possible that we need to use a + // value less than advisoryTargetPostShuffleInputSize as the target input size of + // a post shuffle task. + val targetPostShuffleInputSize = minNumPostShufflePartitions match { + case Some(numPartitions) => + val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum + // The max at here is to make sure that when we have an empty table, we + // only have a single post-shuffle partition. + // There is no particular reason that we pick 16. We just need a number to + // prevent maxPostShuffleInputSize from being set to 0. 
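+        // Illustrative example: with totalPostShuffleInputSize = 1GB and numPartitions = 5,
+        // maxPostShuffleInputSize is ceil(1GB / 5) ~= 205MB, so the effective target becomes
+        // min(~205MB, advisoryTargetPostShuffleInputSize). In other words, this branch can only
+        // lower the advisory size (to keep at least roughly numPartitions post-shuffle
+        // partitions), never raise it.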
+ val maxPostShuffleInputSize = + math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16) + math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) + + case None => advisoryTargetPostShuffleInputSize + } + + logInfo(s"advisoryTargetPostShuffleInputSize: $advisoryTargetPostShuffleInputSize, " + + s"targetPostShuffleInputSize $targetPostShuffleInputSize.") + + // Make sure we do get the same number of pre-shuffle partitions for those stages. + val distinctNumPreShufflePartitions = + mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct + // The reason that we are expecting a single value of the number of pre-shuffle partitions + // is that when we add Exchanges, we set the number of pre-shuffle partitions + // (i.e. map output partitions) using a static setting, which is the value of + // spark.sql.shuffle.partitions. Even if two input RDDs are having different + // number of partitions, they will have the same number of pre-shuffle partitions + // (i.e. map output partitions). + assert( + distinctNumPreShufflePartitions.length == 1, + "There should be only one distinct value of the number pre-shuffle partitions " + + "among registered Exchange operator.") + val numPreShufflePartitions = distinctNumPreShufflePartitions.head + + val partitionStartIndices = ArrayBuffer[Int]() + // The first element of partitionStartIndices is always 0. + partitionStartIndices += 0 + + var postShuffleInputSize = 0L + + var i = 0 + while (i < numPreShufflePartitions) { + // We calculate the total size of ith pre-shuffle partitions from all pre-shuffle stages. + // Then, we add the total size to postShuffleInputSize. + var j = 0 + while (j < mapOutputStatistics.length) { + postShuffleInputSize += mapOutputStatistics(j).bytesByPartitionId(i) + j += 1 + } + + // If the current postShuffleInputSize is equal or greater than the + // targetPostShuffleInputSize, We need to add a new element in partitionStartIndices. + if (postShuffleInputSize >= targetPostShuffleInputSize) { + if (i < numPreShufflePartitions - 1) { + // Next start index. + partitionStartIndices += i + 1 + } else { + // This is the last element. So, we do not need to append the next start index to + // partitionStartIndices. + } + // reset postShuffleInputSize. 
+ postShuffleInputSize = 0L + } + + i += 1 + } + + Some(partitionStartIndices.toArray) + } + + private[sql] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + var keyExpr: Expression = null + var width = 0 + keys.foreach { e => + e.dataType match { + case dt: IntegralType if dt.defaultSize <= 8 - width => + if (width == 0) { + if (e.dataType != LongType) { + keyExpr = Cast(e, LongType) + } else { + keyExpr = e + } + width = dt.defaultSize + } else { + val bits = dt.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + width -= bits + } + // TODO: support BooleanType, DateType and TimestampType + case other => + return keys + } + } + keyExpr :: Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5e19984debaa7..2ad22c23df0b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -148,6 +148,11 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ADAPTIVE_EXECUTION2_ENABLED = SQLConfigBuilder("spark.sql.adaptive2.enabled") + .doc("When true, enable adaptive query execution.") + .booleanConf + .createWithDefault(false) + val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS = SQLConfigBuilder("spark.sql.adaptive.minNumPostShufflePartitions") .internal() @@ -569,6 +574,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + def adaptiveExecution2Enabled: Boolean = getConf(ADAPTIVE_EXECUTION2_ENABLED) + def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index ebff7569798a6..eca9426348639 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.internal import java.io.File import java.util.Properties +import org.apache.spark.sql.catalyst.rules.Rule + import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration @@ -165,6 +167,13 @@ private[sql] class SessionState(sparkSession: SparkSession) { def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan) + def codegenForExecution(plan: SparkPlan): SparkPlan = { + codegenRules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + } + + protected def codegenRules: Seq[Rule[SparkPlan]] = + Seq(CollapseCodegenStages(sparkSession.sessionState.conf)) + def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) } From 25b00d73325f97e1a4ff5a3795b068387a80a53f Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 10 May 2016 17:09:59 +0800 Subject: [PATCH 2/7] update minor code --- .../sql/execution/adaptive/QueryFragment.scala | 16 ++++++++++++---- .../apache/spark/sql/internal/SessionState.scala | 5 ++--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala index 86580be735eb9..ca3f9c2c1abc5 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala @@ -10,13 +10,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{MapOutputStatistics, SimpleFutureAction, ShuffleDependency} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{SortExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.types.LongType /** * A physical plan tree is divided into a DAG tree of QueryFragment. @@ -136,6 +134,7 @@ trait QueryFragment extends SparkPlan { case agg @ TungstenAggregate(_, _, _, _, _, _, input @ FragmentInput(_)) if (!input.isOptimized())=> { + logInfo("Begin optimize agg, operator =\n" + agg.toString) optimizeAggregate(agg, input) } @@ -262,6 +261,8 @@ case class RootQueryFragment ( private[this] val stopped = new AtomicBoolean(false) + override def nodeName: String = s"RootQueryFragment (fragment id: ${id})" + protected[sql] override def stageFailed(exception: Throwable): Unit = { isThrowException = true this.exception = exception @@ -279,6 +280,7 @@ case class RootQueryFragment ( logInfo(s"Query Fragment ${id} finished") parent.setChildCompleted(child, shuffleDependency, statistics) if (parent.isAvailable) { + logInfo(s"Query Fragment ${parent.id} is available") eventQueue.add(parent) } case scala.util.Failure(exception) => @@ -327,7 +329,10 @@ case class RootQueryFragment ( assert(this.exception != null) throw exception } else { - rootPlan.execute() + logInfo(s"== Submit Query Fragment ${id} Physical plan ==") + val executedPlan = sqlContext.sparkSession.sessionState.codegenForExecution(rootPlan) + logInfo(stringOrError(executedPlan.toString)) + executedPlan.execute() } } @@ -342,4 +347,7 @@ case class RootQueryFragment ( case class UnaryQueryFragment ( children: Seq[QueryFragment], id: Long, - isRoot: Boolean = false) extends QueryFragment {} + isRoot: Boolean = false) extends QueryFragment { + + override def nodeName: String = s"UnaryQueryFragment (fragment id: ${id})" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index eca9426348639..9d6648fe55dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -168,12 +168,11 @@ private[sql] class SessionState(sparkSession: SparkSession) { def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan) def codegenForExecution(plan: SparkPlan): SparkPlan = { + val codegenRules: Seq[Rule[SparkPlan]] = + Seq(CollapseCodegenStages(sparkSession.sessionState.conf)) codegenRules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } } - protected def codegenRules: Seq[Rule[SparkPlan]] = - Seq(CollapseCodegenStages(sparkSession.sessionState.conf)) - def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) } From 73de7b2d4a185edb22b896464daa4b18ee32914d Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 10 May 2016 22:18:07 +0800 Subject: [PATCH 3/7] code 
refactor --- .../execution/adaptive/QueryFragment.scala | 62 +++++++++---------- .../spark/sql/internal/SessionState.scala | 6 -- 2 files changed, 30 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala index ca3f9c2c1abc5..8fc74421bb4a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala @@ -4,6 +4,8 @@ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.{LinkedBlockingDeque, BlockingQueue} import java.util.{HashMap => JHashMap, Map => JMap} +import org.apache.spark.sql.catalyst.rules.Rule + import scala.concurrent.ExecutionContext.Implicits.global import scala.collection.mutable.ArrayBuffer @@ -11,7 +13,7 @@ import org.apache.spark.{MapOutputStatistics, SimpleFutureAction, ShuffleDepende import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{SortExec, SparkPlan} +import org.apache.spark.sql.execution.{CollapseCodegenStages, SortExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.joins._ @@ -78,15 +80,19 @@ trait QueryFragment extends SparkPlan { this.rootPlan = root } - protected[sql] def isAvailable: Boolean = nextChildIndex >= children.size + protected[sql] def isAvailable: Boolean = synchronized { + nextChildIndex >= children.size + } + private[sql] def codegenForExecution(plan: SparkPlan): SparkPlan = { + CollapseCodegenStages(sqlContext.conf).apply(plan) + } protected def doExecute(): RDD[InternalRow] = null protected[sql] def adaptiveExecute(): (ShuffleDependency[Int, InternalRow, InternalRow], SimpleFutureAction[MapOutputStatistics]) = synchronized { - val executedPlan = sqlContext.sparkSession.sessionState.codegenForExecution(exchange) - .asInstanceOf[ShuffleExchange] + val executedPlan = codegenForExecution(exchange).asInstanceOf[ShuffleExchange] logInfo(s"== Submit Query Fragment ${id} Physical plan ==") logInfo(stringOrError(executedPlan.toString)) val shuffleDependency = executedPlan.prepareShuffleDependency() @@ -117,15 +123,9 @@ trait QueryFragment extends SparkPlan { } protected[sql] def optimizeOperator(): Unit = synchronized { - val executedPlan = if (isRoot) { - rootPlan - } else { - exchange - } - // Optimize plan val optimizedPlan = executedPlan.transformDown { - case operator @ SortMergeJoinExec(leftKeys, rightKeys, _, _, - left@SortExec(_, _, _, _), right@SortExec(_, _, _, _)) => { + case operator @ SortMergeJoinExec(leftKeys, rightKeys, _, _, left@SortExec(_, _, _, _), + right@SortExec(_, _, _, _)) => { logInfo("Begin optimize join, operator =\n" + operator.toString) val newOperator = optimizeJoin(operator, left, right) logInfo("After optimize join, operator =\n" + newOperator.toString) @@ -133,7 +133,7 @@ trait QueryFragment extends SparkPlan { } case agg @ TungstenAggregate(_, _, _, _, _, _, input @ FragmentInput(_)) - if (!input.isOptimized())=> { + if (!input.isOptimized())=> { logInfo("Begin optimize agg, operator =\n" + agg.toString) optimizeAggregate(agg, input) } @@ -224,20 +224,25 @@ trait QueryFragment extends SparkPlan { if (leftSizeInBytes <= 
sqlContext.conf.autoBroadcastJoinThreshold) { val keys = Utils.rewriteKeyExpr(joinPlan.leftKeys).map( BindReferences.bindReference(_, left.child.output)) - newOperator = BroadcastHashJoinExec( - joinPlan.leftKeys, joinPlan.rightKeys, joinPlan.joinType, BuildLeft, joinPlan.condition, - BroadcastExchangeExec(HashedRelationBroadcastMode(keys), - left.child), + joinPlan.leftKeys, + joinPlan.rightKeys, + joinPlan.joinType, + BuildLeft, + joinPlan.condition, + BroadcastExchangeExec(HashedRelationBroadcastMode(keys), left.child), right.child) } else if (rightSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { val keys = Utils.rewriteKeyExpr(joinPlan.rightKeys).map( BindReferences.bindReference(_, right.child.output)) newOperator = BroadcastHashJoinExec( - joinPlan.leftKeys, joinPlan.rightKeys, joinPlan.joinType, BuildRight, joinPlan.condition, + joinPlan.leftKeys, + joinPlan.rightKeys, + joinPlan.joinType, + BuildRight, + joinPlan.condition, left.child, - BroadcastExchangeExec(HashedRelationBroadcastMode(keys), - right.child)) + BroadcastExchangeExec(HashedRelationBroadcastMode(keys), right.child)) } } newOperator @@ -269,9 +274,10 @@ case class RootQueryFragment ( stopped.set(true) } - private val eventQueue: BlockingQueue[QueryFragment] = new LinkedBlockingDeque[QueryFragment]() + private[this] val eventQueue: BlockingQueue[QueryFragment] = + new LinkedBlockingDeque[QueryFragment]() - protected def executeFragment(child: QueryFragment) = { + protected[sql] def executeFragment(child: QueryFragment) = { val (shuffleDependency, futureAction) = child.adaptiveExecute() val parent = child.getParentFragment() if (futureAction != null) { @@ -290,12 +296,12 @@ case class RootQueryFragment ( } else { parent.setChildCompleted(child, shuffleDependency, null) if (parent.isAvailable) { + logInfo(s"Query Fragment ${parent.id} is available") eventQueue.add(parent) } } } - protected override def doExecute(): RDD[InternalRow] = { assert(isRoot == true) isThrowException = false @@ -319,7 +325,6 @@ case class RootQueryFragment ( } else { executeFragment(fragment) } - } } } @@ -330,18 +335,11 @@ case class RootQueryFragment ( throw exception } else { logInfo(s"== Submit Query Fragment ${id} Physical plan ==") - val executedPlan = sqlContext.sparkSession.sessionState.codegenForExecution(rootPlan) + val executedPlan = codegenForExecution(rootPlan) logInfo(stringOrError(executedPlan.toString)) executedPlan.execute() } } - - /** Returns a string representation of the nodes in this tree */ - override def treeString: String = - rootPlan.generateTreeString(0, Nil, new StringBuilder).toString - - override def simpleString: String = "QueryFragment" - } case class UnaryQueryFragment ( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 9d6648fe55dbc..b1039971e6d83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -167,12 +167,6 @@ private[sql] class SessionState(sparkSession: SparkSession) { def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan) - def codegenForExecution(plan: SparkPlan): SparkPlan = { - val codegenRules: Seq[Rule[SparkPlan]] = - Seq(CollapseCodegenStages(sparkSession.sessionState.conf)) - codegenRules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } - } - def refreshTable(tableName: String): Unit = { 
catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) } From 6cc187eb9a6cbd1b38b63e2850c60507e058fd57 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 10 May 2016 22:27:43 +0800 Subject: [PATCH 4/7] remove unused import --- .../main/scala/org/apache/spark/sql/internal/SessionState.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index b1039971e6d83..ebff7569798a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.internal import java.io.File import java.util.Properties -import org.apache.spark.sql.catalyst.rules.Rule - import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration From 6f1105cb9bbc1abcc9891d2a9f8fde1bff058cf8 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 10 May 2016 22:36:38 +0800 Subject: [PATCH 5/7] add apache license --- .../sql/execution/adaptive/FragmentInput.scala | 17 +++++++++++++++++ .../sql/execution/adaptive/QueryFragment.scala | 17 +++++++++++++++++ .../adaptive/QueryFragmentTransformer.scala | 17 +++++++++++++++++ .../spark/sql/execution/adaptive/utils.scala | 17 +++++++++++++++++ 4 files changed, 68 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala index 14c852acc2d26..473083480d90f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.sql.execution.adaptive import org.apache.spark.rdd.RDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala index 8fc74421bb4a9..1dd6e2b256eed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.sql.execution.adaptive import java.util.concurrent.atomic.AtomicBoolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala index 82f006c6ba5aa..fc6fd427a1a2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.sql.execution.adaptive import java.util.concurrent.atomic.AtomicLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala index 605f5cb6defd8..d63ac53bcf7a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.sql.execution.adaptive import org.apache.spark.internal.Logging From 5de46b05092e77fc68dda8ef4dc1d549b4ef9373 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Wed, 11 May 2016 23:01:54 +0800 Subject: [PATCH 6/7] add ut --- .../spark/sql/execution/QueryExecution.scala | 2 +- .../execution/adaptive/QueryFragment.scala | 36 ++++++++++++------- .../spark/sql/execution/adaptive/utils.scala | 12 ++++--- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 74ed3338b50b2..523b060b23308 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,13 +21,13 @@ import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.spark.rdd.RDD -import org.apache.spark.sql.execution.adaptive.QueryFragmentTransformer import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.adaptive.QueryFragmentTransformer import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala index 1dd6e2b256eed..6505899df23cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala @@ -17,19 +17,18 @@ package org.apache.spark.sql.execution.adaptive -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.{LinkedBlockingDeque, BlockingQueue} import java.util.{HashMap => JHashMap, Map => JMap} +import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque} +import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.sql.catalyst.rules.Rule - -import scala.concurrent.ExecutionContext.Implicits.global import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global -import org.apache.spark.{MapOutputStatistics, SimpleFutureAction, ShuffleDependency} +import org.apache.spark.{MapOutputStatistics, ShuffleDependency, SimpleFutureAction} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{CollapseCodegenStages, SortExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} @@ -150,7 +149,7 @@ trait QueryFragment extends SparkPlan { } case agg @ TungstenAggregate(_, _, _, _, _, _, input @ FragmentInput(_)) - if (!input.isOptimized())=> { + if (!input.isOptimized()) => { logInfo("Begin optimize agg, operator =\n" + agg.toString) 
optimizeAggregate(agg, input) } @@ -188,7 +187,7 @@ trait QueryFragment extends SparkPlan { Utils.estimatePartitionStartIndices(aggStatistics.toArray, minNumPostShufflePartitions, advisoryTargetPostShuffleInputSize) } - val shuffledRowRdd= childFragments(0).getExchange().preparePostShuffleRDD( + val shuffledRowRdd = childFragments(0).getExchange().preparePostShuffleRDD( shuffleDependencies(fragmentsIndex.get(childFragments(0))), partitionStartIndices) childFragments(0).getFragmentInput().setShuffleRdd(shuffledRowRdd) childFragments(0).getFragmentInput().setOptimized() @@ -223,7 +222,7 @@ trait QueryFragment extends SparkPlan { val leftFragment = childFragments(0) val rightFragment = childFragments(1) - val leftShuffledRowRdd= leftFragment.getExchange().preparePostShuffleRDD( + val leftShuffledRowRdd = leftFragment.getExchange().preparePostShuffleRDD( shuffleDependencies(fragmentsIndex.get(leftFragment)), partitionStartIndices) val rightShuffledRowRdd = rightFragment.getExchange().preparePostShuffleRDD( shuffleDependencies(fragmentsIndex.get(rightFragment)), partitionStartIndices) @@ -238,7 +237,17 @@ trait QueryFragment extends SparkPlan { if (sqlContext.conf.autoBroadcastJoinThreshold > 0) { val leftSizeInBytes = childSizeInBytes(0) val rightSizeInBytes = childSizeInBytes(1) - if (leftSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { + val joinType = joinPlan.joinType + def canBuildLeft(joinType: JoinType): Boolean = joinType match { + case Inner | RightOuter => true + case _ => false + } + def canBuildRight(joinType: JoinType): Boolean = joinType match { + case Inner | LeftOuter | LeftSemi | LeftAnti => true + case j: ExistenceJoin => true + case _ => false + } + if (leftSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold && canBuildLeft(joinType)) { val keys = Utils.rewriteKeyExpr(joinPlan.leftKeys).map( BindReferences.bindReference(_, left.child.output)) newOperator = BroadcastHashJoinExec( @@ -249,7 +258,8 @@ trait QueryFragment extends SparkPlan { joinPlan.condition, BroadcastExchangeExec(HashedRelationBroadcastMode(keys), left.child), right.child) - } else if (rightSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { + } else if (rightSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold + && canBuildRight(joinType)) { val keys = Utils.rewriteKeyExpr(joinPlan.rightKeys).map( BindReferences.bindReference(_, right.child.output)) newOperator = BroadcastHashJoinExec( @@ -260,13 +270,13 @@ trait QueryFragment extends SparkPlan { joinPlan.condition, left.child, BroadcastExchangeExec(HashedRelationBroadcastMode(keys), right.child)) - } + } } newOperator } /** Returns a string representation of the nodes in this tree */ - override def treeString: String = + override def treeString: String = executedPlan.generateTreeString(0, Nil, new StringBuilder).toString override def simpleString: String = "QueryFragment" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala index d63ac53bcf7a0..04931cfdc2b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.execution.adaptive +import scala.collection.mutable.{ArrayBuffer, Queue} + +import org.apache.spark.MapOutputStatistics import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.types.{IntegerType, LongType, IntegralType} -import org.apache.spark.MapOutputStatistics import org.apache.spark.sql.execution.SparkPlan - -import scala.collection.mutable.{Queue, ArrayBuffer} +import org.apache.spark.sql.types.{IntegralType, LongType} /** * Utility functions used by the query fragment. @@ -49,7 +49,9 @@ private[sql] object Utils extends Logging { private[sql] def findLeafFragment(root: QueryFragment): Seq[QueryFragment] = { val result = new ArrayBuffer[QueryFragment] val queue = new Queue[QueryFragment] - queue.enqueue(root) + if (!root.children.isEmpty) { + root.children.foreach(c => queue.enqueue(c)) + } while (queue.nonEmpty) { val current = queue.dequeue() if (current.children.isEmpty) { From a18be55928dcc3f8affef6ee6c5733975f64dba1 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Wed, 11 May 2016 23:30:41 +0800 Subject: [PATCH 7/7] add ut --- .../adaptive/QueryFragmentSuite.scala | 181 ++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentSuite.scala new file mode 100644 index 0000000000000..e671a1002f7e7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentSuite.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} + +class QueryFragmentSuite extends QueryTest with SQLTestUtils with SharedSQLContext { + import testImplicits._ + + setupTestData() + + test("adaptive optimization: transform sort merge join to broadcast join for inner join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val numInputPartitions: Int = 2 + val df1 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key1", "id as value1") + .groupBy("key1") + .agg($"key1", count("value1") as "cnt1") + val df2 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key2", "id as value2") + .groupBy("key2") + .agg($"key2", count("value2") as "cnt2") + val join1 = df1.join(df2, col("key1") === col("key2")) + .select(col("key1"), col("cnt1"), col("cnt2")) + checkAnswer(join1, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "2000 as cnt2").collect()) + + val df3 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id as key3", "id as value3") + .groupBy("key3") + .agg($"key3", count("value3") as "cnt3") + val join2 = df3.join(df1, col("key3") === col("key1")) + .select(col("key1"), col("cnt1"), col("cnt3")) + checkAnswer(join2, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "1 as cnt3").collect()) + } + } + + test("adaptive optimization: transform sort merge join to broadcast join for outer join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val numInputPartitions: Int = 2 + val df1 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key1", "id as value1") + .groupBy("key1") + .agg($"key1", count("value1") as "cnt1") + val df2 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key2", "id as value2") + .groupBy("key2") + .agg($"key2", count("value2") as "cnt2") + val join1 = df1.join(df2, col("key1") === col("key2"), "left_outer") + .select(col("key1"), col("cnt1"), col("cnt2")) + checkAnswer(join1, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "2000 as cnt2").collect()) + + val join2 = df1.join(df2, col("key1") === col("key2"), "right_outer") + .select(col("key1"), col("cnt1"), col("cnt2")) + checkAnswer(join2, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "2000 as cnt2").collect()) + + val df3 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id as key3", "id as value3") + .groupBy("key3") + .agg($"key3", count("value3") as "cnt3") + val join3 = df3.join(df1, col("key3") === col("key1"), "left_outer") + .select(col("key1"), col("cnt1"), col("cnt3")) + checkAnswer(join3, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "1 as cnt3") + .union(sqlContext.range(0, 99950).selectExpr("null as key", "null as cnt1", "1 as cnt3")) + .collect()) + + val join4 = df3.join(df1, col("key3") === col("key1"), "right_outer") + .select(col("key1"), col("cnt1"), col("cnt3")) + checkAnswer(join4, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "1 as cnt3").collect()) + } + } + + test("adaptive optimization: transform sort merge join to broadcast join for left semi join") { + 
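+    // The aggregated inputs stay well under the 100000-byte broadcast threshold configured
+    // below, so the adaptive pass is expected to rewrite the planned sort merge join into a
+    // broadcast hash join at runtime; checkAnswer only verifies that the results are correct.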
withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val numInputPartitions: Int = 2 + val df1 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key1", "id as value1") + .groupBy("key1") + .agg($"key1", count("value1") as "cnt1") + val df2 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key2", "id as value2") + .groupBy("key2") + .agg($"key2", count("value2") as "cnt2") + val join1 = df1.join(df2, col("key1") === col("key2"), "leftsemi") + .select(col("key1"), col("cnt1")) + + checkAnswer(join1, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1").collect()) + + val df3 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id as key3", "id as value3") + .groupBy("key3") + .agg($"key3", count("value3") as "cnt3") + val join2 = df3.join(df1, col("key3") === col("key1"), "leftsemi") + .select(col("key3"), col("cnt3")) + + checkAnswer(join2, + sqlContext.range(0, 50).selectExpr("id as key3", "1 as cnt3").collect()) + } + } + + test("adaptive optimization: transform sort merge join to broadcast join for left anti join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val numInputPartitions: Int = 2 + val df1 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 100 as key1", "id as value1") + .groupBy("key1") + .agg($"key1", count("value1") as "cnt1") + val df2 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key2", "id as value2") + .groupBy("key2") + .agg($"key2", count("value2") as "cnt2") + val join1 = df1.join(df2, col("key1") === col("key2"), "leftanti") + .select(col("key1"), col("cnt1")) + checkAnswer(join1, + sqlContext.range(50, 100).selectExpr("id as key", "1000 as cnt1").collect()) + + val df3 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id as key3", "id as value3") + .groupBy("key3") + .agg($"key3", count("value3") as "cnt3") + val join2 = df3.join(df1, col("key3") === col("key1"), "leftanti") + .select(col("key3"), col("cnt3")) + + checkAnswer(join2, + sqlContext.range(100, 100000).selectExpr("id as key3", "1 as cnt3").collect()) + } + } + + test("adaptive optimization: transform sort merge join to broadcast join for existence join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val numInputPartitions: Int = 2 + sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key1", "id as value1") + .registerTempTable("testData") + sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key2", "id as value2") + .registerTempTable("testData2") + val join1 = sqlContext.sql("select key1, cnt1 from " + + "(select key1, count(value1) as cnt1 from testData group by key1) t1 " + + "where key1 in (select distinct key2 from testData2)") + checkAnswer(join1, + sqlContext.range(0, 50).selectExpr("id as key1", "2000 as cnt1").collect()) + sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id as key3", "id as value3") + .registerTempTable("testData3") + val join2 = sqlContext.sql("select key3, value3 from testData3 " + + "where key3 in (select distinct key2 from testData2)") + checkAnswer(join2, + sqlContext.range(0, 50).selectExpr("id as key3", "id as value3").collect()) + } + } +}
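
Usage sketch (illustrative; the app name, dataset and threshold below are placeholders): the new
runtime path is gated behind spark.sql.adaptive2.enabled, so driving it end to end only requires
enabling that flag on the session, for example:

    import org.apache.spark.sql.SparkSession

    object Adaptive2Demo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("adaptive2-demo")                                // placeholder app name
          .config("spark.sql.adaptive2.enabled", "true")            // flag added by this series
          .config("spark.sql.autoBroadcastJoinThreshold", "100000") // small threshold, as in the tests
          .getOrCreate()

        // Two aggregated inputs joined on their group key: planned as a sort merge join,
        // and eligible for the runtime rewrite to a broadcast hash join.
        val left = spark.range(0, 100000).selectExpr("id % 50 as key", "id as v")
          .groupBy("key").count()
        val right = spark.range(0, 100000).selectExpr("id % 50 as key", "id as v")
          .groupBy("key").count()

        left.join(right, "key").show()
        spark.stop()
      }
    }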