[SPARK-13123][SQL] Implement whole stage codegen for sort #11359
Changes from all commits: 11e26c9, 564a5b3, 7f50b6a, d50ca8e, aceab91, fa7c991, 02aa3d0, c953a60, 5674226, 097376d, 4651ce9, 65ed647
@@ -17,10 +17,12 @@
 package org.apache.spark.sql.execution

-import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -37,7 +39,7 @@ case class Sort(
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {

   override def output: Seq[Attribute] = child.output
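For context (not part of this diff): CodegenSupport is the trait that whole-stage code generation uses to stitch physical operators into one generated function. Below is a simplified sketch of the three members that Sort overrides further down, with signatures copied from this diff; the real trait in Spark has more to it (produce/consume plumbing, etc.), so treat this as an orientation aid rather than the actual definition.

```scala
// Simplified sketch only -- not the real trait from Spark's source tree.
// RDD, InternalRow, CodegenContext and ExprCode are Spark-internal types.
trait CodegenSupport {
  // The RDD(s) that feed rows into the generated code; Sort simply forwards to its child.
  def upstreams(): Seq[RDD[InternalRow]]

  // Emits the "driving" loop of the generated code for this operator.
  protected def doProduce(ctx: CodegenContext): String

  // Emits the per-row code that handles input pushed up by the child.
  def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String
}
```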
@@ -50,34 +52,36 @@ case class Sort(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))

-  protected override def doExecute(): RDD[InternalRow] = {
-    val schema = child.schema
-    val childOutput = child.output
+  def createSorter(): UnsafeExternalRowSorter = {
+    val ordering = newOrdering(sortOrder, output)
+
+    // The comparator for comparing prefix
+    val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+    val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+    // The generator for prefix
+    val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+    val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+      override def computePrefix(row: InternalRow): Long = {
+        prefixProjection.apply(row).getLong(0)
+      }
+    }
+
+    val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+    val sorter = new UnsafeExternalRowSorter(
+      schema, ordering, prefixComparator, prefixComputer, pageSize)
+    if (testSpillFrequency > 0) {
+      sorter.setTestSpillFrequency(testSpillFrequency)
+    }
+    sorter
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
     val dataSize = longMetric("dataSize")
     val spillSize = longMetric("spillSize")

     child.execute().mapPartitionsInternal { iter =>
-      val ordering = newOrdering(sortOrder, childOutput)
-
-      // The comparator for comparing prefix
-      val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
-      val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
-
-      // The generator for prefix
-      val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
-      val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
-        override def computePrefix(row: InternalRow): Long = {
-          prefixProjection.apply(row).getLong(0)
-        }
-      }
-
-      val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-      val sorter = new UnsafeExternalRowSorter(
-        schema, ordering, prefixComparator, prefixComputer, pageSize)
-      if (testSpillFrequency > 0) {
-        sorter.setTestSpillFrequency(testSpillFrequency)
-      }
+      val sorter = createSorter()

       val metrics = TaskContext.get().taskMetrics()
       // Remember spill data size of this task before execute this operator so that we can
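For reference (also not part of the diff): the surface of UnsafeExternalRowSorter that this operator relies on, reconstructed only from the calls visible above and in the codegen section below. The real class is written in Java inside Spark, so this Scala-style sketch and its parameter types are approximate.

```scala
// Approximate sketch of the constructor and methods Sort uses; not the actual class.
class UnsafeExternalRowSorter(
    schema: StructType,                 // row layout of the child's output
    ordering: Ordering[InternalRow],    // full-row comparator
    prefixComparator: PrefixComparator, // cheap comparator over the sort prefix
    prefixComputer: PrefixComputer,     // UnsafeExternalRowSorter.PrefixComputer in the real class
    pageSizeBytes: Long) {
  def setTestSpillFrequency(frequency: Int): Unit = ??? // force periodic spills in tests
  def insertRow(row: UnsafeRow): Unit = ???             // called from the generated doConsume code
  def sort(): Iterator[UnsafeRow] = ???                 // returns the sorted rows
  def getPeakMemoryUsage: Long = ???                    // feeds the dataSize / peak-memory metrics
}
```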
@@ -93,4 +97,74 @@ case class Sort(
       sortedIterator
     }
   }
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  // Name of sorter variable used in codegen.
+  private var sorterVariable: String = _
+
+  override protected def doProduce(ctx: CodegenContext): String = {
+    val needToSort = ctx.freshName("needToSort")
+    ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
+
+    // Initialize the class member variables. This includes the instance of the Sorter and
+    // the iterator to return sorted rows.
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    sorterVariable = ctx.freshName("sorter")
+    ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
+      s"$sorterVariable = $thisPlan.createSorter();")
+    val metrics = ctx.freshName("metrics")
+    ctx.addMutableState(classOf[TaskMetrics].getName, metrics,
+      s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();")
+    val sortedIterator = ctx.freshName("sortedIter")
+    ctx.addMutableState("scala.collection.Iterator<UnsafeRow>", sortedIterator, "")
+
+    val addToSorter = ctx.freshName("addToSorter")
+    ctx.addNewFunction(addToSorter,
+      s"""
+        | private void $addToSorter() throws java.io.IOException {
+        |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+        | }
+      """.stripMargin.trim)
+
+    val outputRow = ctx.freshName("outputRow")
+    val dataSize = metricTerm(ctx, "dataSize")
+    val spillSize = metricTerm(ctx, "spillSize")
+    val spillSizeBefore = ctx.freshName("spillSizeBefore")
+    s"""
+       | if ($needToSort) {
+       |   $addToSorter();
+       |   Long $spillSizeBefore = $metrics.memoryBytesSpilled();
+       |   $sortedIterator = $sorterVariable.sort();
+       |   $dataSize.add($sorterVariable.getPeakMemoryUsage());
+       |   $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
+       |   $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
+       |   $needToSort = false;
+       | }
+       |
+       | while ($sortedIterator.hasNext()) {
+       |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
+       |   ${consume(ctx, null, outputRow)}
+       |   if (shouldStop()) return;
+       | }
+     """.stripMargin.trim
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+      BoundReference(i, attr.dataType, attr.nullable)
+    }
+
+    ctx.currentVars = input
+    val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+
+    s"""
+       | // Convert the input attributes to an UnsafeRow and add it to the sorter
+       | ${code.code}
Contributor: This may cause a performance regression: when Sort sits on top of Exchange (or another operator that produces UnsafeRow), the generated code first creates variables from the incoming UnsafeRow and then builds another UnsafeRow from those variables. See #11008 (comment). @yhuai Should we revert this patch or fix it in a follow-up PR? (See the sketch after the end of this diff for an illustration of the extra copy.)
+       | $sorterVariable.insertRow(${code.value});
+     """.stripMargin.trim
+  }
 }
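To make the inline comment above concrete in plain Scala terms (an illustration of the concern, not the generated code itself): doConsume always routes the child's columns through an UnsafeProjection built from BoundReferences, even when the incoming row is already an UnsafeRow, so every row is unpacked into per-column values and then re-serialized into a fresh UnsafeRow before insertRow. A minimal sketch, using Catalyst APIs that appear in this diff; the two-column schema is made up:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
import org.apache.spark.sql.types.LongType

// Pretend this row already arrived as an UnsafeRow from an Exchange.
val incoming: InternalRow = InternalRow(1L, 2L)

// What Sort's doConsume effectively sets up: rebuild the row from bound column references.
val rebuild = UnsafeProjection.create(Seq(
  BoundReference(0, LongType, nullable = false),
  BoundReference(1, LongType, nullable = false)))

val copy = rebuild(incoming)   // a second row with identical contents
// Inserting `copy` instead of the original row is the redundant per-row work being discussed.
```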
Review comment: This can just be a local var. Just remove the ".addMutableState" below and fix line 141.
Reply: Thanks, fixed.
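As a closing usage note (not from the PR): once Sort participates in whole-stage codegen, a simple sorted query should land inside a generated subtree, which can be checked with explain(). A sketch, assuming a Spark build of that era with whole-stage codegen enabled (the spark.sql.codegen.wholeStage flag); the exact plan rendering varies by version:

```scala
// sqlContext is assumed to be in scope (e.g. in spark-shell of that era).
val df = sqlContext.range(0, 1000).sort("id")
df.explain()   // Sort should show up inside a WholeStageCodegen subtree
               // (rendered as a wrapper node or a "*" prefix, depending on the version)
df.collect()   // exercises the generated sort path
```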