diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index cb3c46a98bfb4..a5fdc93e0595e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.adaptive.QueryFragmentTransformer import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.internal.SQLConf @@ -79,7 +80,12 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) + lazy val executedPlan: SparkPlan = + if (sparkSession.sessionState.conf.adaptiveExecution2Enabled) { + prepareForAdaptiveExecution(sparkPlan) + } else { + prepareForExecution(sparkPlan) + } /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() @@ -100,6 +106,17 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { CollapseCodegenStages(sparkSession.sessionState.conf), ReuseExchange(sparkSession.sessionState.conf)) + protected def prepareForAdaptiveExecution(plan: SparkPlan): SparkPlan = { + preparationsForAdaptive.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + } + + protected def preparationsForAdaptive: Seq[Rule[SparkPlan]] = Seq( + python.ExtractPythonUDFs, + PlanSubqueries(sparkSession), + EnsureRequirements(sparkSession.sessionState.conf), + ReuseExchange(sparkSession.sessionState.conf), + QueryFragmentTransformer(sparkSession.sessionState.conf)) + protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala new file mode 100644 index 0000000000000..473083480d90f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FragmentInput.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.{LeafExecNode, ShuffledRowRDD, SparkPlan} + +/** + * FragmentInput is the leaf node of a parent fragment that connects it to a child fragment. + */ +case class FragmentInput(@transient childFragment: QueryFragment) extends LeafExecNode { + + private[this] var optimized: Boolean = false + + private[this] var inputPlan: SparkPlan = null + + private[this] var shuffledRdd: ShuffledRowRDD = null + + override def output: Seq[Attribute] = inputPlan.output + + private[sql] def setOptimized() = { + this.optimized = true + } + + private[sql] def isOptimized(): Boolean = this.optimized + + private[sql] def setShuffleRdd(shuffledRdd: ShuffledRowRDD) = { + this.shuffledRdd = shuffledRdd + } + + private[sql] def setInputPlan(inputPlan: SparkPlan) = { + this.inputPlan = inputPlan + } + + override protected def doExecute(): RDD[InternalRow] = { + if (shuffledRdd != null) { + shuffledRdd + } else { + inputPlan.execute() + } + } + + override def simpleString: String = "FragmentInput" + + override def innerChildren: Seq[SparkPlan] = inputPlan :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala new file mode 100644 index 0000000000000..6505899df23cd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import java.util.{HashMap => JHashMap, Map => JMap} +import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global + +import org.apache.spark.{MapOutputStatistics, ShuffleDependency, SimpleFutureAction} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution.{CollapseCodegenStages, SortExec, SparkPlan} +import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.joins._ + +/** + * A physical plan tree is divided into a DAG of QueryFragments. + * A QueryFragment is a basic execution unit that can be optimized + * according to the statistics of its child fragments.
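+ * Execution proceeds bottom-up: leaf fragments submit their shuffle map stages first + * (see RootQueryFragment.executeFragment); once all child fragments of a fragment have + * reported their MapOutputStatistics, the fragment re-optimizes its operators (for example, + * a SortMergeJoinExec over sufficiently small inputs is rewritten into a + * BroadcastHashJoinExec, and post-shuffle partitions are coalesced) before it is executed.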
+ */ +trait QueryFragment extends SparkPlan { + + def children: Seq[QueryFragment] + def isRoot: Boolean + def id: Long + + private[this] var exchange: ShuffleExchange = null + + protected[this] var rootPlan: SparkPlan = null + + private[this] var fragmentInput: FragmentInput = null + + private[this] var parentFragment: QueryFragment = null + + var nextChildIndex: Int = 0 + + private[this] val fragmentsIndex: JMap[QueryFragment, Integer] = + new JHashMap[QueryFragment, Integer](children.size) + + private[this] val shuffleDependencies = + new Array[ShuffleDependency[Int, InternalRow, InternalRow]](children.size) + + private[this] val mapOutputStatistics = new Array[MapOutputStatistics](children.size) + + private[this] val advisoryTargetPostShuffleInputSize: Long = + sqlContext.conf.targetPostShuffleInputSize + + override def output: Seq[Attribute] = executedPlan.output + + private[this] def executedPlan: SparkPlan = if (isRoot) { + rootPlan + } else { + exchange + } + + private[sql] def setParentFragment(fragment: QueryFragment) = { + this.parentFragment = fragment + } + + private[sql] def getParentFragment() = parentFragment + + private[sql] def setFragmentInput(fragmentInput: FragmentInput) = { + this.fragmentInput = fragmentInput + } + + private[sql] def getFragmentInput() = fragmentInput + + private[sql] def setExchange(exchange: ShuffleExchange) = { + this.exchange = exchange + } + + private[sql] def getExchange(): ShuffleExchange = exchange + + private[sql] def setRootPlan(root: SparkPlan) = { + this.rootPlan = root + } + + protected[sql] def isAvailable: Boolean = synchronized { + nextChildIndex >= children.size + } + + private[sql] def codegenForExecution(plan: SparkPlan): SparkPlan = { + CollapseCodegenStages(sqlContext.conf).apply(plan) + } + + protected def doExecute(): RDD[InternalRow] = null + + protected[sql] def adaptiveExecute(): (ShuffleDependency[Int, InternalRow, InternalRow], + SimpleFutureAction[MapOutputStatistics]) = synchronized { + val executedPlan = codegenForExecution(exchange).asInstanceOf[ShuffleExchange] + logInfo(s"== Submit Query Fragment ${id} Physical plan ==") + logInfo(stringOrError(executedPlan.toString)) + val shuffleDependency = executedPlan.prepareShuffleDependency() + if (shuffleDependency.rdd.partitions.length != 0) { + val futureAction: SimpleFutureAction[MapOutputStatistics] = + sqlContext.sparkContext.submitMapStage[Int, InternalRow, InternalRow](shuffleDependency) + (shuffleDependency, futureAction) + } else { + (shuffleDependency, null) + } + } + + protected[sql] def stageFailed(exception: Throwable): Unit = synchronized { + this.parentFragment.stageFailed(exception) + } + + protected def stringOrError[A](f: => A): String = + try f.toString catch { case e: Throwable => e.toString } + + protected[sql] def setChildCompleted( + child: QueryFragment, + shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], + statistics: MapOutputStatistics): Unit = synchronized { + fragmentsIndex.put(child, this.nextChildIndex) + shuffleDependencies(this.nextChildIndex) = shuffleDependency + mapOutputStatistics(this.nextChildIndex) = statistics + this.nextChildIndex += 1 + } + + protected[sql] def optimizeOperator(): Unit = synchronized { + val optimizedPlan = executedPlan.transformDown { + case operator @ SortMergeJoinExec(leftKeys, rightKeys, _, _, left@SortExec(_, _, _, _), + right@SortExec(_, _, _, _)) => { + logInfo("Begin optimize join, operator =\n" + operator.toString) + val newOperator = optimizeJoin(operator, left, right) + logInfo("After 
optimize join, operator =\n" + newOperator.toString) + newOperator + } + + case agg @ TungstenAggregate(_, _, _, _, _, _, input @ FragmentInput(_)) + if (!input.isOptimized()) => { + logInfo("Begin optimize agg, operator =\n" + agg.toString) + optimizeAggregate(agg, input) + } + + case operator: SparkPlan => operator + } + + if (isRoot) { + rootPlan = optimizedPlan + } else { + exchange = optimizedPlan.asInstanceOf[ShuffleExchange] + } + } + + private[this] def minNumPostShufflePartitions: Option[Int] = { + val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions + if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None + } + + private[this] def optimizeAggregate(agg: SparkPlan, input: FragmentInput): SparkPlan = { + val childFragments = Seq(input.childFragment) + val aggStatistics = new ArrayBuffer[MapOutputStatistics]() + var i = 0 + while(i < childFragments.length) { + val statistics = mapOutputStatistics(fragmentsIndex.get(childFragments(i))) + if (statistics != null) { + aggStatistics += statistics + } + i += 1 + } + val partitionStartIndices = + if (aggStatistics.length == 0) { + None + } else { + Utils.estimatePartitionStartIndices(aggStatistics.toArray, minNumPostShufflePartitions, + advisoryTargetPostShuffleInputSize) + } + val shuffledRowRdd = childFragments(0).getExchange().preparePostShuffleRDD( + shuffleDependencies(fragmentsIndex.get(childFragments(0))), partitionStartIndices) + childFragments(0).getFragmentInput().setShuffleRdd(shuffledRowRdd) + childFragments(0).getFragmentInput().setOptimized() + agg + } + + private[this] def optimizeJoin(joinPlan: SortMergeJoinExec, left: SortExec, right: SortExec) + : SparkPlan = { + // TODO Optimize skew join + val childFragments = Utils.findChildFragment(joinPlan) + assert(childFragments.length == 2) + val joinStatistics = new ArrayBuffer[MapOutputStatistics]() + val childSizeInBytes = new Array[Long](childFragments.length) + var i = 0 + while(i < childFragments.length) { + val statistics = mapOutputStatistics(fragmentsIndex.get(childFragments(i))) + if (statistics != null) { + joinStatistics += statistics + childSizeInBytes(i) = statistics.bytesByPartitionId.sum + } else { + childSizeInBytes(i) = 0 + } + i += 1 + } + val partitionStartIndices = + if (joinStatistics.length == 0) { + None + } else { + Utils.estimatePartitionStartIndices(joinStatistics.toArray, minNumPostShufflePartitions, + advisoryTargetPostShuffleInputSize) + } + + val leftFragment = childFragments(0) + val rightFragment = childFragments(1) + val leftShuffledRowRdd = leftFragment.getExchange().preparePostShuffleRDD( + shuffleDependencies(fragmentsIndex.get(leftFragment)), partitionStartIndices) + val rightShuffledRowRdd = rightFragment.getExchange().preparePostShuffleRDD( + shuffleDependencies(fragmentsIndex.get(rightFragment)), partitionStartIndices) + + leftFragment.getFragmentInput().setShuffleRdd(leftShuffledRowRdd) + leftFragment.getFragmentInput().setOptimized() + rightFragment.getFragmentInput().setShuffleRdd(rightShuffledRowRdd) + rightFragment.getFragmentInput().setOptimized() + + var newOperator: SparkPlan = joinPlan + + if (sqlContext.conf.autoBroadcastJoinThreshold > 0) { + val leftSizeInBytes = childSizeInBytes(0) + val rightSizeInBytes = childSizeInBytes(1) + val joinType = joinPlan.joinType + def canBuildLeft(joinType: JoinType): Boolean = joinType match { + case Inner | RightOuter => true + case _ => false + } + def canBuildRight(joinType: JoinType): Boolean = joinType match { + case Inner | LeftOuter | 
LeftSemi | LeftAnti => true + case j: ExistenceJoin => true + case _ => false + } + if (leftSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold && canBuildLeft(joinType)) { + val keys = Utils.rewriteKeyExpr(joinPlan.leftKeys).map( + BindReferences.bindReference(_, left.child.output)) + newOperator = BroadcastHashJoinExec( + joinPlan.leftKeys, + joinPlan.rightKeys, + joinPlan.joinType, + BuildLeft, + joinPlan.condition, + BroadcastExchangeExec(HashedRelationBroadcastMode(keys), left.child), + right.child) + } else if (rightSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold + && canBuildRight(joinType)) { + val keys = Utils.rewriteKeyExpr(joinPlan.rightKeys).map( + BindReferences.bindReference(_, right.child.output)) + newOperator = BroadcastHashJoinExec( + joinPlan.leftKeys, + joinPlan.rightKeys, + joinPlan.joinType, + BuildRight, + joinPlan.condition, + left.child, + BroadcastExchangeExec(HashedRelationBroadcastMode(keys), right.child)) + } + } + newOperator + } + + /** Returns a string representation of the nodes in this tree */ + override def treeString: String = + executedPlan.generateTreeString(0, Nil, new StringBuilder).toString + + override def simpleString: String = "QueryFragment" +} + +case class RootQueryFragment ( + children: Seq[QueryFragment], + id: Long, + isRoot: Boolean = false) extends QueryFragment { + + private[this] var isThrowException = false + + private[this] var exception: Throwable = null + + private[this] val stopped = new AtomicBoolean(false) + + override def nodeName: String = s"RootQueryFragment (fragment id: ${id})" + + protected[sql] override def stageFailed(exception: Throwable): Unit = { + isThrowException = true + this.exception = exception + stopped.set(true) + } + + private[this] val eventQueue: BlockingQueue[QueryFragment] = + new LinkedBlockingDeque[QueryFragment]() + + protected[sql] def executeFragment(child: QueryFragment) = { + val (shuffleDependency, futureAction) = child.adaptiveExecute() + val parent = child.getParentFragment() + if (futureAction != null) { + futureAction.onComplete { + case scala.util.Success(statistics) => + logInfo(s"Query Fragment ${id} finished") + parent.setChildCompleted(child, shuffleDependency, statistics) + if (parent.isAvailable) { + logInfo(s"Query Fragment ${parent.id} is available") + eventQueue.add(parent) + } + case scala.util.Failure(exception) => + logInfo(s"Query Fragment ${id} failed, exception is ${exception}") + parent.stageFailed(exception) + } + } else { + parent.setChildCompleted(child, shuffleDependency, null) + if (parent.isAvailable) { + logInfo(s"Query Fragment ${parent.id} is available") + eventQueue.add(parent) + } + } + } + + protected override def doExecute(): RDD[InternalRow] = { + assert(isRoot == true) + isThrowException = false + stopped.set(false) + val children = Utils.findLeafFragment(this) + if (!children.isEmpty) { + children.foreach { child => executeFragment(child) } + } else { + stopped.set(true) + } + + val executeThread = new Thread("Fragment execute") { + setDaemon(true) + + override def run(): Unit = { + while (!stopped.get) { + val fragment = eventQueue.take() + fragment.optimizeOperator() + if (fragment.isInstanceOf[RootQueryFragment]) { + stopped.set(true) + } else { + executeFragment(fragment) + } + } + } + } + executeThread.start() + executeThread.join() + if (isThrowException) { + assert(this.exception != null) + throw exception + } else { + logInfo(s"== Submit Query Fragment ${id} Physical plan ==") + val executedPlan = codegenForExecution(rootPlan) + 
logInfo(stringOrError(executedPlan.toString)) + executedPlan.execute() + } + } +} + +case class UnaryQueryFragment ( + children: Seq[QueryFragment], + id: Long, + isRoot: Boolean = false) extends QueryFragment { + + override def nodeName: String = s"UnaryQueryFragment (fragment id: ${id})" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala new file mode 100644 index 0000000000000..fc6fd427a1a2d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentTransformer.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.ExecutedCommandExec +import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.internal.SQLConf + +/** + * Transforms a physical plan tree into a query fragment tree. + */ +case class QueryFragmentTransformer(conf: SQLConf, maxIterations: Int = 100) + extends Rule[SparkPlan] { + + private val nextFragmentId = new AtomicLong(0) + + def apply(plan: SparkPlan): SparkPlan = { + val newPlan = plan.transformUp { + case operator: SparkPlan => withQueryFragment(operator) + } + if (newPlan.isInstanceOf[ExecutedCommandExec]) { + newPlan + } else { + val childFragments = Utils.findChildFragment(newPlan) + val newFragment = new RootQueryFragment(childFragments, + nextFragmentId.getAndIncrement(), true) + childFragments.foreach(child => child.setParentFragment(newFragment)) + newFragment.setRootPlan(newPlan) + newFragment + } + } + + private[this] def withQueryFragment(operator: SparkPlan): SparkPlan = { + val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution + val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering + val children: Seq[SparkPlan] = operator.children + assert(requiredChildDistributions.length == children.length) + assert(requiredChildOrderings.length == children.length) + + val supportsAdaptiveExecution = + if (children.exists(_.isInstanceOf[ShuffleExchange])) { + // Right now, adaptive execution only supports HashPartitioning.
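+ // Exchanges with other partitionings are left on the default (non-adaptive) path, + // presumably because coalescing contiguous post-shuffle partitions (see + // Utils.estimatePartitionStartIndices) is only known to be safe for hash-based shuffles.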
+ children.forall { + case e @ ShuffleExchange(hash: HashPartitioning, _, _) => true + case child => + child.outputPartitioning match { + case hash: HashPartitioning => true + case collection: PartitioningCollection => + collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) + case _ => false + } + } + } else { + // In this case, although we do not have Exchange operators, we may still need to + // shuffle data when we have more than one children because data generated by + // these children may not be partitioned in the same way. + // Please see the comment in withCoordinator for more details. + val supportsDistribution = + requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + children.length > 1 && supportsDistribution + } + + val withFragments = + if (supportsAdaptiveExecution) { + children.zip(requiredChildDistributions).map { + case (e: ShuffleExchange, _) => + // This child is an Exchange, we need to add the fragment. + val childFragments = Utils.findChildFragment(e) + val newFragment = new UnaryQueryFragment(childFragments, + nextFragmentId.getAndIncrement(), false) + childFragments.foreach(child => child.setParentFragment(newFragment)) + val fragmentInput = FragmentInput(newFragment) + fragmentInput.setInputPlan(e) + newFragment.setExchange(e) + newFragment.setFragmentInput(fragmentInput) + fragmentInput + + case (child, distribution) => + child + } + } else { + children + } + + operator.withNewChildren(withFragments) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala new file mode 100644 index 0000000000000..04931cfdc2b3e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.mutable.{ArrayBuffer, Queue} + +import org.apache.spark.MapOutputStatistics +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{IntegralType, LongType} + +/** + * Utility functions used by the query fragment. 
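+ * findChildFragment and findLeafFragment walk a plan or fragment tree breadth-first; + * estimatePartitionStartIndices merges adjacent small post-shuffle partitions up to a + * target input size; rewriteKeyExpr packs narrow integral join keys into a single long + * for broadcast hash joins, falling back to the original keys when they do not fit.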
+ */ +private[sql] object Utils extends Logging { + + private[sql] def findChildFragment(root: SparkPlan): Seq[QueryFragment] = { + val result = new ArrayBuffer[QueryFragment] + val queue = new Queue[SparkPlan] + queue.enqueue(root) + while (queue.nonEmpty) { + val current = queue.dequeue() + if (current.isInstanceOf[FragmentInput]) { + val fragmentInput = current.asInstanceOf[FragmentInput] + result += fragmentInput.childFragment + } else { + current.children.foreach(c => queue.enqueue(c)) + } + } + result + } + + private[sql] def findLeafFragment(root: QueryFragment): Seq[QueryFragment] = { + val result = new ArrayBuffer[QueryFragment] + val queue = new Queue[QueryFragment] + if (!root.children.isEmpty) { + root.children.foreach(c => queue.enqueue(c)) + } + while (queue.nonEmpty) { + val current = queue.dequeue() + if (current.children.isEmpty) { + result += current + } else { + current.children.foreach(c => queue.enqueue(c)) + } + } + result + } + + /** + * Estimates partition start indices for post-shuffle partitions based on + * mapOutputStatistics provided by all pre-shuffle stages. + */ + private[sql] def estimatePartitionStartIndices( + mapOutputStatistics: Array[MapOutputStatistics], + minNumPostShufflePartitions: Option[Int], + advisoryTargetPostShuffleInputSize: Long): Option[Array[Int]] = { + + // If minNumPostShufflePartitions is defined, it is possible that we need to use a + // value less than advisoryTargetPostShuffleInputSize as the target input size of + // a post shuffle task. + val targetPostShuffleInputSize = minNumPostShufflePartitions match { + case Some(numPartitions) => + val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum + // The max at here is to make sure that when we have an empty table, we + // only have a single post-shuffle partition. + // There is no particular reason that we pick 16. We just need a number to + // prevent maxPostShuffleInputSize from being set to 0. + val maxPostShuffleInputSize = + math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16) + math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) + + case None => advisoryTargetPostShuffleInputSize + } + + logInfo(s"advisoryTargetPostShuffleInputSize: $advisoryTargetPostShuffleInputSize, " + + s"targetPostShuffleInputSize $targetPostShuffleInputSize.") + + // Make sure we do get the same number of pre-shuffle partitions for those stages. + val distinctNumPreShufflePartitions = + mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct + // The reason that we are expecting a single value of the number of pre-shuffle partitions + // is that when we add Exchanges, we set the number of pre-shuffle partitions + // (i.e. map output partitions) using a static setting, which is the value of + // spark.sql.shuffle.partitions. Even if two input RDDs are having different + // number of partitions, they will have the same number of pre-shuffle partitions + // (i.e. map output partitions). + assert( + distinctNumPreShufflePartitions.length == 1, + "There should be only one distinct value of the number pre-shuffle partitions " + + "among registered Exchange operator.") + val numPreShufflePartitions = distinctNumPreShufflePartitions.head + + val partitionStartIndices = ArrayBuffer[Int]() + // The first element of partitionStartIndices is always 0. 
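+ // For example (hypothetical sizes): with a single map stage whose per-partition bytes + // are [70, 30, 20, 5, 5] and a target size of 100, the loop below produces [0, 2], i.e. + // pre-shuffle partitions 0-1 form one post-shuffle partition and 2-4 form another.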
+ partitionStartIndices += 0 + + var postShuffleInputSize = 0L + + var i = 0 + while (i < numPreShufflePartitions) { + // We calculate the total size of ith pre-shuffle partitions from all pre-shuffle stages. + // Then, we add the total size to postShuffleInputSize. + var j = 0 + while (j < mapOutputStatistics.length) { + postShuffleInputSize += mapOutputStatistics(j).bytesByPartitionId(i) + j += 1 + } + + // If the current postShuffleInputSize is equal or greater than the + // targetPostShuffleInputSize, We need to add a new element in partitionStartIndices. + if (postShuffleInputSize >= targetPostShuffleInputSize) { + if (i < numPreShufflePartitions - 1) { + // Next start index. + partitionStartIndices += i + 1 + } else { + // This is the last element. So, we do not need to append the next start index to + // partitionStartIndices. + } + // reset postShuffleInputSize. + postShuffleInputSize = 0L + } + + i += 1 + } + + Some(partitionStartIndices.toArray) + } + + private[sql] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + var keyExpr: Expression = null + var width = 0 + keys.foreach { e => + e.dataType match { + case dt: IntegralType if dt.defaultSize <= 8 - width => + if (width == 0) { + if (e.dataType != LongType) { + keyExpr = Cast(e, LongType) + } else { + keyExpr = e + } + width = dt.defaultSize + } else { + val bits = dt.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + width -= bits + } + // TODO: support BooleanType, DateType and TimestampType + case other => + return keys + } + } + keyExpr :: Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7933d12e284f1..a605fbc00ea64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -148,6 +148,11 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ADAPTIVE_EXECUTION2_ENABLED = SQLConfigBuilder("spark.sql.adaptive2.enabled") + .doc("When true, enable adaptive query execution.") + .booleanConf + .createWithDefault(false) + val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS = SQLConfigBuilder("spark.sql.adaptive.minNumPostShufflePartitions") .internal() @@ -571,6 +576,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + def adaptiveExecution2Enabled: Boolean = getConf(ADAPTIVE_EXECUTION2_ENABLED) + def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentSuite.scala new file mode 100644 index 0000000000000..e671a1002f7e7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentSuite.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} + +class QueryFragmentSuite extends QueryTest with SQLTestUtils with SharedSQLContext { + import testImplicits._ + + setupTestData() + + test("adaptive optimization: transform sort merge join to broadcast join for inner join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val numInputPartitions: Int = 2 + val df1 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key1", "id as value1") + .groupBy("key1") + .agg($"key1", count("value1") as "cnt1") + val df2 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key2", "id as value2") + .groupBy("key2") + .agg($"key2", count("value2") as "cnt2") + val join1 = df1.join(df2, col("key1") === col("key2")) + .select(col("key1"), col("cnt1"), col("cnt2")) + checkAnswer(join1, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "2000 as cnt2").collect()) + + val df3 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id as key3", "id as value3") + .groupBy("key3") + .agg($"key3", count("value3") as "cnt3") + val join2 = df3.join(df1, col("key3") === col("key1")) + .select(col("key1"), col("cnt1"), col("cnt3")) + checkAnswer(join2, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "1 as cnt3").collect()) + } + } + + test("adaptive optimization: transform sort merge join to broadcast join for outer join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val numInputPartitions: Int = 2 + val df1 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key1", "id as value1") + .groupBy("key1") + .agg($"key1", count("value1") as "cnt1") + val df2 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key2", "id as value2") + .groupBy("key2") + .agg($"key2", count("value2") as "cnt2") + val join1 = df1.join(df2, col("key1") === col("key2"), "left_outer") + .select(col("key1"), col("cnt1"), col("cnt2")) + checkAnswer(join1, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "2000 as cnt2").collect()) + + val join2 = df1.join(df2, col("key1") === col("key2"), "right_outer") + .select(col("key1"), col("cnt1"), col("cnt2")) + checkAnswer(join2, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "2000 as cnt2").collect()) + + val df3 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id as key3", "id as value3") + .groupBy("key3") + .agg($"key3", count("value3") as "cnt3") + val join3 = df3.join(df1, col("key3") === col("key1"), "left_outer") + .select(col("key1"), col("cnt1"), col("cnt3")) + checkAnswer(join3, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "1 as cnt3") + .union(sqlContext.range(0, 99950).selectExpr("null as key", "null 
as cnt1", "1 as cnt3")) + .collect()) + + val join4 = df3.join(df1, col("key3") === col("key1"), "right_outer") + .select(col("key1"), col("cnt1"), col("cnt3")) + checkAnswer(join4, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "1 as cnt3").collect()) + } + } + + test("adaptive optimization: transform sort merge join to broadcast join for left semi join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val numInputPartitions: Int = 2 + val df1 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key1", "id as value1") + .groupBy("key1") + .agg($"key1", count("value1") as "cnt1") + val df2 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key2", "id as value2") + .groupBy("key2") + .agg($"key2", count("value2") as "cnt2") + val join1 = df1.join(df2, col("key1") === col("key2"), "leftsemi") + .select(col("key1"), col("cnt1")) + + checkAnswer(join1, + sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1").collect()) + + val df3 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id as key3", "id as value3") + .groupBy("key3") + .agg($"key3", count("value3") as "cnt3") + val join2 = df3.join(df1, col("key3") === col("key1"), "leftsemi") + .select(col("key3"), col("cnt3")) + + checkAnswer(join2, + sqlContext.range(0, 50).selectExpr("id as key3", "1 as cnt3").collect()) + } + } + + test("adaptive optimization: transform sort merge join to broadcast join for left anti join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val numInputPartitions: Int = 2 + val df1 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 100 as key1", "id as value1") + .groupBy("key1") + .agg($"key1", count("value1") as "cnt1") + val df2 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key2", "id as value2") + .groupBy("key2") + .agg($"key2", count("value2") as "cnt2") + val join1 = df1.join(df2, col("key1") === col("key2"), "leftanti") + .select(col("key1"), col("cnt1")) + checkAnswer(join1, + sqlContext.range(50, 100).selectExpr("id as key", "1000 as cnt1").collect()) + + val df3 = sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id as key3", "id as value3") + .groupBy("key3") + .agg($"key3", count("value3") as "cnt3") + val join2 = df3.join(df1, col("key3") === col("key1"), "leftanti") + .select(col("key3"), col("cnt3")) + + checkAnswer(join2, + sqlContext.range(100, 100000).selectExpr("id as key3", "1 as cnt3").collect()) + } + } + + test("adaptive optimization: transform sort merge join to broadcast join for existence join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val numInputPartitions: Int = 2 + sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key1", "id as value1") + .registerTempTable("testData") + sqlContext.range(0, 100000, 1, numInputPartitions) + .selectExpr("id % 50 as key2", "id as value2") + .registerTempTable("testData2") + val join1 = sqlContext.sql("select key1, cnt1 from " + + "(select key1, count(value1) as cnt1 from testData group by key1) t1 " + + "where key1 in (select distinct key2 from testData2)") + checkAnswer(join1, + sqlContext.range(0, 50).selectExpr("id as key1", "2000 as cnt1").collect()) + sqlContext.range(0, 100000, 1, numInputPartitions) + 
.selectExpr("id as key3", "id as value3") + .registerTempTable("testData3") + val join2 = sqlContext.sql("select key3, value3 from testData3 " + + "where key3 in (select distinct key2 from testData2)") + checkAnswer(join2, + sqlContext.range(0, 50).selectExpr("id as key3", "id as value3").collect()) + } + } +}