UnsafeExternalRowSorter.java
@@ -53,6 +53,12 @@ public final class UnsafeExternalRowSorter {
private final PrefixComputer prefixComputer;
private final UnsafeExternalSorter sorter;

// This flag records whether cleanupResources() has been called. After cleanup,
// iterator.hasNext should always return false. A downstream operator triggers the
// resource cleanup once it finds that the iterator is no longer needed.
// See SPARK-21492 for more details.
private boolean isReleased = false;

public abstract static class PrefixComputer {

public static class Prefix {
@@ -159,7 +165,8 @@ public long getSortTimeNanos() {
return sorter.getSortTimeNanos();
}

private void cleanupResources() {
public void cleanupResources() {
isReleased = true;
sorter.cleanupResources();
}

@@ -178,7 +185,7 @@ public Iterator<UnsafeRow> sort() throws IOException {

@Override
public boolean hasNext() {
return sortedIterator.hasNext();
return !isReleased && sortedIterator.hasNext();
}

@Override
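Taken together, the hunks above make release visible to the returned iterator: once cleanupResources() flips isReleased, hasNext short-circuits instead of touching memory the sorter has already freed. A minimal Scala sketch of this guard pattern, with hypothetical names rather than Spark's actual API:

class ReleasableIterator[T](underlying: Iterator[T], releaseUnderlying: () => Unit)
  extends Iterator[T] {

  private var isReleased = false

  // Idempotent release: after the first call, hasNext permanently reports false.
  def cleanupResources(): Unit = {
    isReleased = true
    releaseUnderlying()
  }

  override def hasNext: Boolean = !isReleased && underlying.hasNext

  override def next(): T = underlying.next()
}

Downstream consumers can then call cleanupResources() as soon as they know no further rows are needed, without risking a use-after-free on the next hasNext probe.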
SortExec.scala
@@ -59,6 +59,14 @@ case class SortExec(
"peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))

private[sql] var rowSorter: UnsafeExternalRowSorter = _

/**
* This method is invoked only once per SortExec instance to initialize an
* UnsafeExternalRowSorter; both `plan.execute` and the generated code use it.
* In the code generation path, it is called from outside the class, so it has
* to be public.
*/
def createSorter(): UnsafeExternalRowSorter = {
val ordering = newOrdering(sortOrder, output)

@@ -84,13 +92,13 @@ case class SortExec(
}

val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val sorter = UnsafeExternalRowSorter.create(
rowSorter = UnsafeExternalRowSorter.create(
schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)

if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
rowSorter.setTestSpillFrequency(testSpillFrequency)
}
sorter
rowSorter
}

protected override def doExecute(): RDD[InternalRow] = {
@@ -186,4 +194,17 @@ case class SortExec(
|$sorterVariable.insertRow((UnsafeRow)${row.value});
""".stripMargin
}

/**
* In SortExec, we override cleanupResources to close the UnsafeExternalRowSorter.
*/
override protected[sql] def cleanupResources(): Unit = {
if (rowSorter != null) {
// rowSorter may be null here: for example, if the iterator in the current task is empty,
// a downstream physical node (like SortMergeJoinExec) can trigger cleanupResources
// before rowSorter is initialized in createSorter.
rowSorter.cleanupResources()
}
super.cleanupResources()
}
}
SparkPlan.scala
@@ -421,6 +421,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable
}
newOrdering(order, Seq.empty)
}

/**
* Cleans up the resources used by the physical operator (if any). In general, all resources
* should be cleaned up when the task finishes, but operators such as SortMergeJoinExec and
* LimitExec may want eager cleanup to free scarce resources (e.g., memory) earlier.
*/
protected[sql] def cleanupResources(): Unit = {
children.foreach(_.cleanupResources())
}
}

object SparkPlan {
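The default implementation above simply recurses into the children, so one call on a downstream operator sweeps the whole subtree; operators that actually hold resources override it and then delegate upward. A simplified, self-contained Scala sketch of how SparkPlan, SortExec, and the row sorter compose (hypothetical types, not the real Spark classes):

trait RowSorter {
  def cleanupResources(): Unit
}

abstract class Operator {
  def children: Seq[Operator]

  // Default behavior: propagate cleanup through the plan tree.
  def cleanupResources(): Unit = children.foreach(_.cleanupResources())
}

class Sort(child: Operator) extends Operator {
  override def children: Seq[Operator] = Seq(child)

  var sorter: RowSorter = _ // stays null if createSorter was never reached

  override def cleanupResources(): Unit = {
    if (sorter != null) sorter.cleanupResources() // guard for the empty-iterator case
    super.cleanupResources()
  }
}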
SortMergeJoinExec.scala
@@ -173,7 +173,8 @@ case class SortMergeJoinExec(
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
spillThreshold
spillThreshold,
cleanupResources
)
private[this] val joinRow = new JoinedRow

@@ -217,7 +218,8 @@
streamedIter = RowIterator.fromScala(leftIter),
bufferedIter = RowIterator.fromScala(rightIter),
inMemoryThreshold,
spillThreshold
spillThreshold,
cleanupResources
)
val rightNullRow = new GenericInternalRow(right.output.length)
new LeftOuterIterator(
@@ -231,7 +233,8 @@
streamedIter = RowIterator.fromScala(rightIter),
bufferedIter = RowIterator.fromScala(leftIter),
inMemoryThreshold,
spillThreshold
spillThreshold,
cleanupResources
)
val leftNullRow = new GenericInternalRow(left.output.length)
new RightOuterIterator(
@@ -265,7 +268,8 @@
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
spillThreshold
spillThreshold,
cleanupResources
)
private[this] val joinRow = new JoinedRow

@@ -300,7 +304,8 @@
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
spillThreshold
spillThreshold,
cleanupResources
)
private[this] val joinRow = new JoinedRow

@@ -342,7 +347,8 @@
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
spillThreshold
spillThreshold,
cleanupResources
)
private[this] val joinRow = new JoinedRow

@@ -622,6 +628,9 @@ case class SortMergeJoinExec(
(evaluateVariables(leftVars), "")
}

val thisPlan = ctx.addReferenceObj("plan", this)
val eagerCleanup = s"$thisPlan.cleanupResources();"

s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
| ${leftVarDecl.mkString("\n")}
@@ -635,6 +644,7 @@
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}
}
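For intuition: ctx.addReferenceObj stores the operator instance in the generated class's references array and returns a Java expression that resolves back to it at runtime, which is what lets the generated loop end with a direct call to the operator's cleanupResources(). A toy model of that mechanism (an assumed simplification, not Spark's CodegenContext):

import scala.collection.mutable.ArrayBuffer

class ToyCodegenContext {
  // Objects the generated class will receive as its `references` array.
  val references: ArrayBuffer[Any] = ArrayBuffer.empty

  def addReferenceObj(name: String, obj: Any): String = {
    references += obj
    val idx = references.size - 1
    // Return a Java expression the code template can interpolate.
    s"((${obj.getClass.getName}) references[$idx] /* $name */)"
  }
}

Interpolated where $thisPlan appears, this would yield generated Java shaped roughly like ((...SortMergeJoinExec) references[0] /* plan */).cleanupResources(); after the join loop.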
@@ -660,6 +670,7 @@
* @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by
* internal buffer
* @param spillThreshold Threshold for number of rows to be spilled by internal buffer
* @param eagerCleanupResources the eager cleanup function to be invoked when no join row is found
*/
private[joins] class SortMergeJoinScanner(
streamedKeyGenerator: Projection,
@@ -668,7 +679,8 @@ private[joins] class SortMergeJoinScanner(
streamedIter: RowIterator,
bufferedIter: RowIterator,
inMemoryThreshold: Int,
spillThreshold: Int) {
spillThreshold: Int,
eagerCleanupResources: () => Unit) {
private[this] var streamedRow: InternalRow = _
private[this] var streamedRowKey: InternalRow = _
private[this] var bufferedRow: InternalRow = _
@@ -692,7 +704,8 @@ private[joins] class SortMergeJoinScanner(
def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches

/**
* Advances both input iterators, stopping when we have found rows with matching join keys.
* Advances both input iterators, stopping when we have found rows with matching join keys. If no
* join rows are found, eagerly clean up resources.
* @return true if matching rows have been found and false otherwise. If this returns true, then
* [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join
* results.
@@ -702,7 +715,7 @@ private[joins] class SortMergeJoinScanner(
// Advance the streamed side of the join until we find the next row whose join key contains
// no nulls or we hit the end of the streamed iterator.
}
if (streamedRow == null) {
val found = if (streamedRow == null) {
// We have consumed the entire streamed iterator, so there can be no more matches.
matchJoinKey = null
bufferedMatches.clear()
@@ -742,17 +755,19 @@ private[joins] class SortMergeJoinScanner(
true
}
}
if (!found) eagerCleanupResources()
found
}

/**
* Advances the streamed input iterator and buffers all rows from the buffered input that
* have matching keys.
* have matching keys. If no join rows are found, eagerly clean up resources.
* @return true if the streamed iterator returned a row, false otherwise. If this returns true,
* then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer
* join results.
*/
final def findNextOuterJoinRows(): Boolean = {
if (!advancedStreamed()) {
val found = if (!advancedStreamed()) {
// We have consumed the entire streamed iterator, so there can be no more matches.
matchJoinKey = null
bufferedMatches.clear()
@@ -782,6 +797,8 @@
// If there is a streamed input then we always return true
true
}
if (!found) eagerCleanupResources()
found
}

// --- Private methods --------------------------------------------------------------------------
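One Scala detail at the call sites above: SortMergeJoinExec passes its zero-argument method cleanupResources where the scanner expects a () => Unit, so the compiler eta-expands the method into a function value and the scanner holds a deferred cleanup handle rather than invoking it at construction time. A tiny self-contained sketch (hypothetical names):

object EtaExpansionSketch {
  class Scanner(eagerCleanupResources: () => Unit) {
    // Mirrors findNextInnerJoinRows: cleanup fires only once the scan is exhausted.
    def onExhausted(): Unit = eagerCleanupResources()
  }

  def cleanupResources(): Unit = println("resources released")

  def main(args: Array[String]): Unit = {
    val scanner = new Scanner(cleanupResources _) // eta-expansion, nothing runs yet
    scanner.onExhausted() // prints "resources released"
  }
}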
33 changes: 32 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -21,11 +21,13 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import scala.language.existentials

import org.mockito.Mockito._

import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
import org.apache.spark.sql.execution.{BinaryExecNode, SortExec, SparkPlan}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -34,6 +36,23 @@ import org.apache.spark.sql.types.StructType
class JoinSuite extends QueryTest with SharedSQLContext {
import testImplicits._

private def attachCleanupResourceChecker(plan: SparkPlan): Unit = {
// SPARK-21492: Check that cleanupResources is eventually triggered in the SortExec nodes of
// every test case
plan.foreachUp {
case s: SortExec =>
val sortExec = spy(s)
verify(sortExec, atLeastOnce).cleanupResources()
verify(sortExec.rowSorter, atLeastOnce).cleanupResources()
case _ =>
}
}

override protected def checkAnswer(df: => DataFrame, rows: Seq[Row]): Unit = {
attachCleanupResourceChecker(df.queryExecution.sparkPlan)
super.checkAnswer(df, rows)
}

setupTestData()

def statisticSizeInByte(df: DataFrame): BigInt = {
@@ -927,4 +946,16 @@ class JoinSuite extends QueryTest with SharedSQLContext {
checkAnswer(df, Row(1, 100, 42, 200, 1, 42))
}
}

test("SPARK-21492: cleanupResource without code generation") {
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.SHUFFLE_PARTITIONS.key -> "1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df1 = spark.range(0, 10, 1, 2)
val df2 = spark.range(10).select($"id".as("b1"), (- $"id").as("b2"))
val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id")
checkAnswer(res, Row(0, 0, 0))
}
}
}
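To run just this suite while reviewing, the usual sbt pattern should work, e.g. build/sbt "sql/testOnly *JoinSuite"; the checkAnswer override above then applies the cleanup checker to every test case, both with and without whole-stage code generation.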