[SPARK-13123][SQL] Implement whole stage codegen for sort. #11008
Changes from all commits
The diff shown here touches `Sort.scala` (the `org.apache.spark.sql.execution.Sort` physical operator). The imports drop `InternalAccumulator` and pull in the codegen classes:

```diff
@@ -17,10 +17,11 @@
 package org.apache.spark.sql.execution

-import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.metric.SQLMetrics

```
```diff
@@ -37,7 +38,7 @@ case class Sort(
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {

   override def output: Seq[Attribute] = child.output

```
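The rest of the diff leans on a handful of `CodegenSupport` members. Their signatures, inferred purely from how `Sort` uses them below, look roughly like this (a summary sketch, not the actual trait definition in the Spark source):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}

// Sketch of the CodegenSupport surface this patch relies on; names and shapes are
// inferred from the diff, so treat the details as approximate.
trait CodegenSupportSketch {
  // The RDD that feeds the generated code; Sort simply delegates to its child.
  def upstream(): RDD[InternalRow]

  // Emits Java source that produces rows and pushes them to the parent operator.
  protected def doProduce(ctx: CodegenContext): String

  // Emits Java source that handles one row arriving from the child operator.
  def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String

  // Provided by the real trait: ask the child to produce rows into `this`, and hand a
  // row (as column variables, or as the name of an UnsafeRow variable) to the parent.
  def produce(ctx: CodegenContext, parent: CodegenSupportSketch): String
  def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String
}
```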
The sorter construction moves out of `doExecute` into a reusable `createSorter()` method, so the generated code can call it as well:

```diff
@@ -50,34 +51,38 @@ case class Sort(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))

+  def createSorter(): UnsafeExternalRowSorter = {
+    val ordering = newOrdering(sortOrder, output)
+
+    // The comparator for comparing prefix
+    val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+    val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+    // The generator for prefix
+    val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+    val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+      override def computePrefix(row: InternalRow): Long = {
+        prefixProjection.apply(row).getLong(0)
+      }
+    }
+
+    val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+    val sorter = new UnsafeExternalRowSorter(
+      schema, ordering, prefixComparator, prefixComputer, pageSize)
+    if (testSpillFrequency > 0) {
+      sorter.setTestSpillFrequency(testSpillFrequency)
+    }
+    sorter
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val schema = child.schema
     val childOutput = child.output

     val dataSize = longMetric("dataSize")
     val spillSize = longMetric("spillSize")

     child.execute().mapPartitionsInternal { iter =>
-      val ordering = newOrdering(sortOrder, childOutput)
-
-      // The comparator for comparing prefix
-      val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
-      val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
-
-      // The generator for prefix
-      val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
-      val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
-        override def computePrefix(row: InternalRow): Long = {
-          prefixProjection.apply(row).getLong(0)
-        }
-      }
-
-      val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-      val sorter = new UnsafeExternalRowSorter(
-        schema, ordering, prefixComparator, prefixComputer, pageSize)
-      if (testSpillFrequency > 0) {
-        sorter.setTestSpillFrequency(testSpillFrequency)
-      }
+      val sorter = createSorter()

       val metrics = TaskContext.get().taskMetrics()
       // Remember spill data size of this task before execute this operator so that we can
```
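Both execution paths now share `createSorter()`. Purely as an illustration (not part of the PR), the two paths can be compared from a 2.0-era `spark-shell`, assuming the whole-stage codegen switch of that code base, `spark.sql.codegen.wholeStage`:

```scala
// Illustrative only: toggle whole-stage codegen around a simple sorted query.
// Assumes a spark-shell where `sqlContext` is predefined and the config key
// spark.sql.codegen.wholeStage exists (both assumptions, not taken from this PR).
sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
println(sqlContext.range(10L * 1000 * 1000).sort("id").queryExecution.executedPlan)
// Sort shows up as a regular operator and runs through doExecute().

sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
println(sqlContext.range(10L * 1000 * 1000).sort("id").queryExecution.executedPlan)
// Sort is planned inside a WholeStageCodegen node and is driven by doProduce/doConsume.
```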
Finally, the codegen hooks themselves. `doProduce` drains the child into the sorter the first time the generated code runs, then loops over the sorted iterator; `doConsume` packs each incoming row into an `UnsafeRow` and inserts it into the sorter:

```diff
@@ -93,4 +98,63 @@ case class Sort(
       sortedIterator
     }
   }
+
+  override def upstream(): RDD[InternalRow] = {
+    child.asInstanceOf[CodegenSupport].upstream()
+  }
+
+  // Name of sorter variable used in codegen.
+  private var sorterVariable: String = _
+
+  override protected def doProduce(ctx: CodegenContext): String = {
+    val needToSort = ctx.freshName("needToSort")
+    ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
+
+
+    // Initialize the class member variables. This includes the instance of the Sorter and
+    // the iterator to return sorted rows.
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    sorterVariable = ctx.freshName("sorter")
+    ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
+      s"$sorterVariable = $thisPlan.createSorter();")
+    val sortedIterator = ctx.freshName("sortedIter")
+    ctx.addMutableState("scala.collection.Iterator<UnsafeRow>", sortedIterator, "")
+
+    val addToSorter = ctx.freshName("addToSorter")
+    ctx.addNewFunction(addToSorter,
+      s"""
+        | private void $addToSorter() throws java.io.IOException {
+        |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+        | }
+      """.stripMargin.trim)
+
+    val outputRow = ctx.freshName("outputRow")
+    s"""
+      | if ($needToSort) {
+      |   $addToSorter();
+      |   $sortedIterator = $sorterVariable.sort();
+      |   $needToSort = false;
+      | }
+      |
+      | while ($sortedIterator.hasNext()) {
+      |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
+      |   ${consume(ctx, null, outputRow)}
+      | }
    """.stripMargin.trim
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+      BoundReference(i, attr.dataType, attr.nullable)
+    }
+
+    ctx.currentVars = input
+    val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+
+    s"""
+      | // Convert the input attributes to an UnsafeRow and add it to the sorter
+      | ${code.code}
+      | $sorterVariable.insertRow(${code.value});
+    """.stripMargin.trim
+  }
 }
```

A reviewer commented on the `GenerateUnsafeProjection.createCode(...)` call in `doConsume`:

> If the child can produce UnsafeRow (for example, Exchange), we should have a way to avoid this unpack and pack again, or we will see regression (generated version slower than non-generated). I think we can pass the variable for input row into …
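As a rough sketch of that suggestion (hypothetical code, not from this PR: the extra `row` parameter and the null check are assumptions about how the API could evolve), `doConsume` could accept the name of an UnsafeRow variable that the child has already materialized and insert it directly, falling back to the projection otherwise:

```scala
// Hypothetical sketch of the reviewer's idea; not part of this patch.
def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
  if (row != null) {
    // The child already produced an UnsafeRow variable (e.g. rows coming out of
    // Exchange): insert it as-is, skipping the unpack/repack round trip.
    s"$sorterVariable.insertRow((UnsafeRow) $row);"
  } else {
    // Otherwise pack the column variables into an UnsafeRow first, as the PR does today.
    val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
      BoundReference(i, attr.dataType, attr.nullable)
    }
    ctx.currentVars = input
    val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
    s"""
       | ${code.code}
       | $sorterVariable.insertRow(${code.value});
     """.stripMargin.trim
  }
}
```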
Review discussion on the mutable `sorterVariable` state used to pass the variable name from `doProduce` to `doConsume`:

> this is pretty ghetto... (although I understand maybe it's the simplest way to implement this)

> Why? This is the state that needs to be kept between the two member functions in this class.

> It's OK here, as discussed offline. I just found mutable state used as a way to pass variable names through pretty brittle. Maybe it would be good to have a more general abstraction for this in codegen, but it's not that big of a deal right now.
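One possible shape for that "more general abstraction" (purely a sketch; the registry class and its methods are invented here and do not exist in Spark) would be to let the codegen context own per-operator state, so `doProduce` registers generated variable names and `doConsume` looks them up, instead of each operator carrying a mutable field:

```scala
import scala.collection.mutable

// Hypothetical sketch only: a name registry a CodegenContext could expose so operators
// don't need mutable fields like `sorterVariable` to pass names between doProduce and
// doConsume. Nothing here is taken from the Spark code base.
final class CodegenVarRegistry {
  private val names = mutable.Map.empty[(AnyRef, String), String]

  // Called from doProduce: remember the fresh name generated for (operator, key).
  def register(operator: AnyRef, key: String, freshName: String): Unit =
    names((operator, key)) = freshName

  // Called from doConsume: retrieve the name that doProduce registered earlier.
  def lookup(operator: AnyRef, key: String): String =
    names((operator, key))
}
```

With something like this hanging off the context, `doProduce` would call `registry.register(this, "sorter", sorterVariable)` and `doConsume` would call `registry.lookup(this, "sorter")`, and the `private var` on the operator could go away.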