@@ -36,7 +36,7 @@
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;

final class UnsafeExternalRowSorter {
public final class UnsafeExternalRowSorter {

/**
* If positive, forces records to be spilled to disk at the given frequency (measured in numbers
@@ -84,8 +84,7 @@ void setTestSpillFrequency(int frequency) {
testSpillFrequency = frequency;
}

@VisibleForTesting
void insertRow(UnsafeRow row) throws IOException {
public void insertRow(UnsafeRow row) throws IOException {
final long prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
row.getBaseObject(),
@@ -110,8 +109,7 @@ private void cleanupResources() {
sorter.cleanupResources();
}

@VisibleForTesting
Iterator<UnsafeRow> sort() throws IOException {
public Iterator<UnsafeRow> sort() throws IOException {
try {
final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
if (!sortedIterator.hasNext()) {
@@ -160,7 +158,6 @@ public UnsafeRow next() {
}
}


public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
while (inputIterator.hasNext()) {
insertRow(inputIterator.next());
124 changes: 99 additions & 25 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -17,10 +17,12 @@

package org.apache.spark.sql.execution

import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.metric.SQLMetrics

@@ -37,7 +39,7 @@ case class Sort(
global: Boolean,
child: SparkPlan,
testSpillFrequency: Int = 0)
extends UnaryNode {
extends UnaryNode with CodegenSupport {

override def output: Seq[Attribute] = child.output

@@ -50,34 +52,36 @@ case class Sort(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))

protected override def doExecute(): RDD[InternalRow] = {
val schema = child.schema
val childOutput = child.output
def createSorter(): UnsafeExternalRowSorter = {
val ordering = newOrdering(sortOrder, output)

// The comparator for comparing prefix
val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)

// The generator for prefix
val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
override def computePrefix(row: InternalRow): Long = {
prefixProjection.apply(row).getLong(0)
}
}

val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val sorter = new UnsafeExternalRowSorter(
schema, ordering, prefixComparator, prefixComputer, pageSize)
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
sorter
}

protected override def doExecute(): RDD[InternalRow] = {
val dataSize = longMetric("dataSize")
val spillSize = longMetric("spillSize")

child.execute().mapPartitionsInternal { iter =>
val ordering = newOrdering(sortOrder, childOutput)

// The comparator for comparing prefix
val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)

// The generator for prefix
val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
override def computePrefix(row: InternalRow): Long = {
prefixProjection.apply(row).getLong(0)
}
}

val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val sorter = new UnsafeExternalRowSorter(
schema, ordering, prefixComparator, prefixComputer, pageSize)
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
val sorter = createSorter()

val metrics = TaskContext.get().taskMetrics()
// Remember spill data size of this task before executing this operator so that we can
@@ -93,4 +97,74 @@
sortedIterator
}
}

override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}

// Name of sorter variable used in codegen.
private var sorterVariable: String = _

override protected def doProduce(ctx: CodegenContext): String = {
val needToSort = ctx.freshName("needToSort")
ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")


// Initialize the class member variables. This includes the instance of the Sorter and
// the iterator to return sorted rows.
val thisPlan = ctx.addReferenceObj("plan", this)
sorterVariable = ctx.freshName("sorter")
ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
s"$sorterVariable = $thisPlan.createSorter();")
val metrics = ctx.freshName("metrics")
ctx.addMutableState(classOf[TaskMetrics].getName, metrics,
s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();")
val sortedIterator = ctx.freshName("sortedIter")
ctx.addMutableState("scala.collection.Iterator<UnsafeRow>", sortedIterator, "")

val addToSorter = ctx.freshName("addToSorter")
ctx.addNewFunction(addToSorter,
s"""
| private void $addToSorter() throws java.io.IOException {
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| }
""".stripMargin.trim)

val outputRow = ctx.freshName("outputRow")
val dataSize = metricTerm(ctx, "dataSize")
val spillSize = metricTerm(ctx, "spillSize")
val spillSizeBefore = ctx.freshName("spillSizeBefore")
Review comment (Contributor): This can just be a local var. Just remove the ".addMutableState" below and fix line 141.

Reply (Member, Author): Thanks, fixed.

s"""
| if ($needToSort) {
| $addToSorter();
| Long $spillSizeBefore = $metrics.memoryBytesSpilled();
| $sortedIterator = $sorterVariable.sort();
| $dataSize.add($sorterVariable.getPeakMemoryUsage());
| $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
| $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
| $needToSort = false;
| }
|
| while ($sortedIterator.hasNext()) {
| UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
| ${consume(ctx, null, outputRow)}
| if (shouldStop()) return;
| }
""".stripMargin.trim
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}

ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs)

s"""
| // Convert the input attributes to an UnsafeRow and add it to the sorter
| ${code.code}
Review comment (Contributor), illustrated with a spark-shell sketch after this file's diff: This may cause a performance regression: when Sort sits on top of Exchange (or another operator that produces UnsafeRow), we create variables from the UnsafeRow and then build another UnsafeRow from those variables. See #11008 (comment).

@yhuai Should we revert this patch or fix this in a follow-up PR?
| $sorterVariable.insertRow(${code.value});
""".stripMargin.trim
}
}
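
To make the plan shape in the review comment above concrete, here is a small spark-shell sketch. It is an illustration only, not part of the patch, and it assumes a running sqlContext as in the test suites below:

import org.apache.spark.sql.functions.col

// A global sort over multiple partitions places a range-partitioning Exchange
// directly below the Sort. That Exchange already emits UnsafeRows, which the
// generated doConsume above unpacks into variables and then rebuilds into a
// fresh UnsafeRow before inserting into the sorter.
val df = sqlContext.range(1000).repartition(4).sort(col("id"))
df.explain()   // physical plan shows Sort directly above an Exchange
df.collect()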
@@ -287,7 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
${code.trim}
}
}
"""
""".trim

// try to compile, helpful for debug
val cleanedSource = CodeFormatter.stripExtraNewLines(source)
@@ -338,7 +338,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
// There is an UnsafeRow already
s"""
|append($row.copy());
""".stripMargin
""".stripMargin.trim
} else {
assert(input != null)
if (input.nonEmpty) {
@@ -351,12 +351,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
s"""
|${code.code.trim}
|append(${code.value}.copy());
""".stripMargin
""".stripMargin.trim
} else {
// There is no columns
s"""
|append(unsafeRow);
""".stripMargin
""".stripMargin.trim
}
}
}
@@ -69,4 +69,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
}

test("Sort should be included in WholeStageCodegen") {
val df = sqlContext.range(3, 0, -1).sort(col("id"))
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}
}
@@ -154,6 +154,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
)
}

test("Sort metrics") {
// Assume the execution plan is
// WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1))
val df = sqlContext.range(10).sort('id)
testSparkPlanMetrics(df, 2, Map.empty)
}

test("SortMergeJoin metrics") {
// Because SortMergeJoin may skip different rows if the number of partitions is different, this
// test should use the deterministic number of partitions.