UnsafeExternalRowSorter.java
@@ -36,7 +36,7 @@
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;

-final class UnsafeExternalRowSorter {
+public final class UnsafeExternalRowSorter {

/**
* If positive, forces records to be spilled to disk at the given frequency (measured in numbers
@@ -84,8 +84,7 @@ void setTestSpillFrequency(int frequency) {
testSpillFrequency = frequency;
}

-  @VisibleForTesting
-  void insertRow(UnsafeRow row) throws IOException {
+  public void insertRow(UnsafeRow row) throws IOException {
final long prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
row.getBaseObject(),
@@ -110,8 +109,7 @@ private void cleanupResources() {
sorter.cleanupResources();
}

-  @VisibleForTesting
-  Iterator<UnsafeRow> sort() throws IOException {
+  public Iterator<UnsafeRow> sort() throws IOException {
try {
final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
if (!sortedIterator.hasNext()) {
@@ -160,7 +158,6 @@ public UnsafeRow next() {
}
}

-
public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
while (inputIterator.hasNext()) {
insertRow(inputIterator.next());
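Context for the visibility changes above: once `insertRow` and `sort` are public, generated whole-stage code can drive the sorter directly, using the same insert-then-sort protocol that the `sort(Iterator)` convenience method follows. A toy stand-in for that protocol (`ToySorter` is an invented class for illustration, not Spark code):

```scala
// Toy stand-in for UnsafeExternalRowSorter's calling protocol (invented
// class, not Spark code): feed rows one at a time, then ask once for a
// sorted iterator.
import scala.collection.mutable

final class ToySorter[T](implicit ord: Ordering[T]) {
  private val buf = mutable.ArrayBuffer.empty[T]
  def insertRow(row: T): Unit = buf += row       // analogue of insertRow(UnsafeRow)
  def sort(): Iterator[T] = buf.sorted.iterator  // analogue of sort(): Iterator[UnsafeRow]
}

val sorter = new ToySorter[Int]
Seq(3, 1, 2).foreach(sorter.insertRow)
println(sorter.sort().toList) // List(1, 2, 3)
```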
sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala (87 additions, 23 deletions)
@@ -17,10 +17,11 @@

package org.apache.spark.sql.execution

-import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.metric.SQLMetrics

@@ -37,7 +38,7 @@ case class Sort(
global: Boolean,
child: SparkPlan,
testSpillFrequency: Int = 0)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {

override def output: Seq[Attribute] = child.output

@@ -50,34 +51,38 @@ case class Sort(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))

+  def createSorter(): UnsafeExternalRowSorter = {
+    val ordering = newOrdering(sortOrder, output)
+
+    // The comparator for comparing prefix
+    val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+    val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+    // The generator for prefix
+    val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+    val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+      override def computePrefix(row: InternalRow): Long = {
+        prefixProjection.apply(row).getLong(0)
+      }
+    }
+
+    val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+    val sorter = new UnsafeExternalRowSorter(
+      schema, ordering, prefixComparator, prefixComputer, pageSize)
+    if (testSpillFrequency > 0) {
+      sorter.setTestSpillFrequency(testSpillFrequency)
+    }
+    sorter
+  }
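For readers new to this machinery: the prefix comparator and prefix computer built above let the underlying sorter settle most comparisons with a single 8-byte compare, consulting the full row ordering only when prefixes tie. A self-contained sketch of the idea (illustrative names, not Spark APIs):

```scala
// Sketch of prefix-first comparison (illustrative names, not Spark APIs):
// order by the cheap 8-byte prefix, fall back to the full ordering on ties.
case class Rec(prefix: Long, payload: String)

val fullOrdering: Ordering[Rec] = Ordering.by(_.payload)

val prefixFirst: Ordering[Rec] = new Ordering[Rec] {
  def compare(a: Rec, b: Rec): Int = {
    val byPrefix = java.lang.Long.compare(a.prefix, b.prefix)
    if (byPrefix != 0) byPrefix else fullOrdering.compare(a, b)
  }
}

println(Seq(Rec(2L, "b"), Rec(1L, "z"), Rec(1L, "a")).sorted(prefixFirst))
// List(Rec(1,a), Rec(1,z), Rec(2,b))
```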

protected override def doExecute(): RDD[InternalRow] = {
-    val schema = child.schema
-    val childOutput = child.output

val dataSize = longMetric("dataSize")
val spillSize = longMetric("spillSize")

child.execute().mapPartitionsInternal { iter =>
-      val ordering = newOrdering(sortOrder, childOutput)
-
-      // The comparator for comparing prefix
-      val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
-      val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
-
-      // The generator for prefix
-      val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
-      val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
-        override def computePrefix(row: InternalRow): Long = {
-          prefixProjection.apply(row).getLong(0)
-        }
-      }
-
-      val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-      val sorter = new UnsafeExternalRowSorter(
-        schema, ordering, prefixComparator, prefixComputer, pageSize)
-      if (testSpillFrequency > 0) {
-        sorter.setTestSpillFrequency(testSpillFrequency)
-      }
+      val sorter = createSorter()

val metrics = TaskContext.get().taskMetrics()
// Remember spill data size of this task before execute this operator so that we can
@@ -93,4 +98,63 @@
sortedIterator
}
}

+  override def upstream(): RDD[InternalRow] = {
+    child.asInstanceOf[CodegenSupport].upstream()
+  }
+
+  // Name of sorter variable used in codegen.
+  private var sorterVariable: String = _
[Review thread on sorterVariable]

Contributor: this is pretty ghetto... (although I understand maybe it's the simplest way to implement this)

Author: Why? This is the state that needs to be kept between the two member functions in this class.

Contributor: it's ok here as discussed offline. I just found mutable state in here as a way to pass variable names through pretty brittle. Maybe good to have a more general abstraction for this in codegen, but not that big of a deal right now.
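Reduced to a standalone sketch, the pattern under debate looks like this (the `SortCodegenSketch` class and its methods are invented stand-ins, not Spark's CodegenSupport API): one callback allocates a generated-variable name and stashes it in a field, and a later callback on the same object splices that name into the code it emits.

```scala
// Invented stand-in for the mutable-state pattern (not Spark's API):
// produce() picks a fresh variable name and stores it; consume() reads it.
class SortCodegenSketch {
  private var counter = 0
  private def freshName(prefix: String): String = { counter += 1; s"$prefix$counter" }

  // The field carrying state between the two callbacks.
  private var sorterVariable: String = _

  def produce(): String = {
    sorterVariable = freshName("sorter")
    s"$sorterVariable = plan.createSorter();"
  }

  def consume(rowVar: String): String = s"$sorterVariable.insertRow($rowVar);"
}

val g = new SortCodegenSketch
println(g.produce())        // sorter1 = plan.createSorter();
println(g.consume("row1"))  // sorter1.insertRow(row1);
```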


+  override protected def doProduce(ctx: CodegenContext): String = {
+    val needToSort = ctx.freshName("needToSort")
+    ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
+
+    // Initialize the class member variables. This includes the instance of the Sorter and
+    // the iterator to return sorted rows.
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    sorterVariable = ctx.freshName("sorter")
+    ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
+      s"$sorterVariable = $thisPlan.createSorter();")
+    val sortedIterator = ctx.freshName("sortedIter")
+    ctx.addMutableState("scala.collection.Iterator<UnsafeRow>", sortedIterator, "")
+
+    val addToSorter = ctx.freshName("addToSorter")
+    ctx.addNewFunction(addToSorter,
+      s"""
+        | private void $addToSorter() throws java.io.IOException {
+        |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+        | }
+      """.stripMargin.trim)
+
+    val outputRow = ctx.freshName("outputRow")
+    s"""
+      | if ($needToSort) {
+      |   $addToSorter();
+      |   $sortedIterator = $sorterVariable.sort();
+      |   $needToSort = false;
+      | }
+      |
+      | while ($sortedIterator.hasNext()) {
+      |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
+      |   ${consume(ctx, null, outputRow)}
+      | }
+    """.stripMargin.trim
+  }
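For intuition about the `ctx.freshName` and `ctx.addMutableState` calls above: the context hands out collision-free names and collects field declarations plus initializers that the enclosing generated class later emits. A minimal stand-in (`CtxSketch` is invented, not Spark's CodegenContext):

```scala
// Invented stand-in for the bookkeeping role of CodegenContext: collect
// (type, name, initializer) triples and emit them into the class skeleton.
import scala.collection.mutable.ArrayBuffer

class CtxSketch {
  private var counter = 0
  private val states = ArrayBuffer.empty[(String, String, String)]

  def freshName(prefix: String): String = { counter += 1; s"$prefix$counter" }

  def addMutableState(javaType: String, name: String, init: String): Unit =
    states += ((javaType, name, init))

  def declareMutableStates(): String =
    states.map { case (t, n, _) => s"private $t $n;" }.mkString("\n")

  def initMutableStates(): String =
    states.map(_._3).filter(_.nonEmpty).mkString("\n")
}

val ctx = new CtxSketch
val needToSort = ctx.freshName("needToSort")
ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
println(ctx.declareMutableStates()) // private boolean needToSort1;
println(ctx.initMutableStates())    // needToSort1 = true;
```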

+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+      BoundReference(i, attr.dataType, attr.nullable)
+    }
+
+    ctx.currentVars = input
+    val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
[Review thread on the projection in doConsume]

Contributor: If the child can produce UnsafeRow (for example, Exchange), we should have a way to avoid this unpack-and-pack again, or we will see a regression (the generated version slower than the non-generated one).

I think we can pass the variable for the input row into doConsume; it could be null. It's better to do this after #11274, then we don't need to worry about whether we should create variables for the input or not.
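A rough sketch of the reviewer's suggestion (the signature below is hypothetical; the PR defers the change until after #11274): if the caller can also hand down the name of an UnsafeRow variable that already exists in the generated code, Sort can insert that row directly and skip the projection.

```scala
// Hypothetical shape of the suggestion, not code from this PR: take an
// optional pre-existing UnsafeRow variable name and bypass the projection.
def consumeSketch(rowVar: Option[String], projectAndInsert: => String): String =
  rowVar match {
    case Some(row) => s"sorter.insertRow($row);" // child already emitted an UnsafeRow
    case None      => projectAndInsert           // fall back to pack-into-UnsafeRow
  }

println(consumeSketch(Some("unsafeRow1"), "/* projection + insertRow */"))
println(consumeSketch(None, "/* projection + insertRow */"))
```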


s"""
| // Convert the input attributes to an UnsafeRow and add it to the sorter
| ${code.code}
| $sorterVariable.insertRow(${code.value});
""".stripMargin.trim
}
}
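One way to exercise the new path end to end (hedged: the calls below are spark-shell usage as of the master branch this PR targets, and `range` produces a single `id` column): sort a DataFrame and check that the executed plan wraps Sort in a WholeStageCodegen node.

```scala
// Hedged end-to-end check from a spark-shell of this era: with codegen
// enabled, Sort should show up inside a WholeStageCodegen node.
import org.apache.spark.sql.functions.col

val df = sqlContext.range(0, 1000).sort(col("id").desc)
df.collect() // runs the generated code path
println(df.queryExecution.executedPlan) // expect WholeStageCodegen wrapping Sort
```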
WholeStageCodegen.scala
@@ -229,6 +229,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])

private Object[] references;
${ctx.declareMutableStates()}
+
${ctx.declareAddedFunctions()}

public GeneratedIterator(Object[] references) {
@@ -240,7 +241,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
$code
}
}
"""
""".trim

// try to compile, helpful for debug
// println(s"${CodeFormatter.format(source)}")
@@ -277,7 +278,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
s"""
| currentRow = $row;
| return;
""".stripMargin
""".stripMargin.trim
} else {
assert(input != null)
if (input.nonEmpty) {
@@ -291,13 +292,13 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
| ${code.code.trim}
| currentRow = ${code.value};
| return;
""".stripMargin
""".stripMargin.trim
} else {
// There is no columns
s"""
| currentRow = unsafeRow;
| return;
""".stripMargin
""".stripMargin.trim
}
}
}
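The `.trim` additions in this file all fix the same cosmetic issue: `stripMargin` keeps the newline that follows the opening triple quote and the indentation before the closing one, so untrimmed fragments splice stray blank lines into the generated source. A quick illustration (`row1` is just an example identifier):

```scala
// Why the .trim calls: stripMargin keeps the newline after the opening
// triple quote and the indentation before the closing one.
val fragment =
  s"""
     | currentRow = row1;
     | return;
   """.stripMargin

println("[" + fragment + "]")      // leading "\n" and trailing "\n   " survive
println("[" + fragment.trim + "]") // just the two statements
```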
TungstenAggregate.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{DecimalType, StructType}
+import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator

case class TungstenAggregate(
@@ -287,7 +287,6 @@ case class TungstenAggregate(
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
}

-
/**
* Update peak execution memory, called in generated Java class.
*/