Commit 406eb1c

tejasapatil authored and hvanhovell committed
[SPARK-21595] Separate thresholds for buffering and spilling in ExternalAppendOnlyUnsafeRowArray
## What changes were proposed in this pull request?

[SPARK-21595](https://issues.apache.org/jira/browse/SPARK-21595) reported excessive spilling to disk because the default spill threshold for `ExternalAppendOnlyUnsafeRowArray` is quite small for the WINDOW operator. The old behaviour of the WINDOW operator (pre #16909) was to hold data in an array for the first 4096 records, after which it would switch to `UnsafeExternalSorter` and start spilling to disk upon reaching `spark.shuffle.spill.numElementsForceSpillThreshold` (or earlier if memory ran short due to excessive consumers). Currently, both decisions for `ExternalAppendOnlyUnsafeRowArray` — the switch from in-memory buffering to `UnsafeExternalSorter`, and `UnsafeExternalSorter` spilling to disk — are controlled by a single threshold. This PR separates the two to allow more granular control.

## How was this patch tested?

Added unit tests.

Author: Tejas Patil <[email protected]>

Closes #18843 from tejasapatil/SPARK-21595.

(cherry picked from commit 9443999)
Signed-off-by: Herman van Hovell <[email protected]>
1 parent c909496 commit 406eb1c

File tree: 9 files changed, +155 −70 lines
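
The core of the change is easiest to see in isolation: rows are appended to a plain in-memory array up to one threshold, then handed to an external sorter that itself force-spills at a second, independent threshold. A minimal, self-contained Scala sketch of that pattern (illustration only — `TwoThresholdBuffer` is hypothetical; the real class is `ExternalAppendOnlyUnsafeRowArray` in the second diff below):

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for ExternalAppendOnlyUnsafeRowArray's two-phase buffering.
class TwoThresholdBuffer[T](inMemoryThreshold: Int, spillThreshold: Int) {
  private val inMemory = new ArrayBuffer[T]()     // phase 1: rows guaranteed in memory
  private val overflow = new ArrayBuffer[T]()     // phase 2: stands in for UnsafeExternalSorter
  private val spilled = new ArrayBuffer[Seq[T]]() // phase 3: stands in for batches written to disk
  private var numRows = 0

  def add(row: T): Unit = {
    if (numRows < inMemoryThreshold) {
      inMemory += row           // below the first threshold: plain array, no sorter involved
    } else {
      overflow += row           // beyond it: the "external sorter" holds the rows
      if (overflow.size >= spillThreshold) {
        spilled += overflow.toSeq  // second threshold hit: force-spill the batch
        overflow.clear()
      }
    }
    numRows += 1
  }

  def iterator: Iterator[T] =
    inMemory.iterator ++ spilled.flatten.iterator ++ overflow.iterator
}
```

Pre-PR behaviour corresponds to both thresholds collapsed into one; splitting them lets small inputs stay in a cheap array while large inputs spill only when genuinely necessary.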

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 36 additions & 5 deletions
@@ -774,24 +774,47 @@ object SQLConf {
     .stringConf
     .createWithDefaultFunction(() => TimeZone.getDefault.getID)

+  val WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
+    buildConf("spark.sql.windowExec.buffer.in.memory.threshold")
+      .internal()
+      .doc("Threshold for number of rows guaranteed to be held in memory by the window operator")
+      .intConf
+      .createWithDefault(4096)
+
   val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD =
     buildConf("spark.sql.windowExec.buffer.spill.threshold")
       .internal()
-      .doc("Threshold for number of rows buffered in window operator")
+      .doc("Threshold for number of rows to be spilled by window operator")
       .intConf
-      .createWithDefault(4096)
+      .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
+
+  val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
+    buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold")
+      .internal()
+      .doc("Threshold for number of rows guaranteed to be held in memory by the sort merge " +
+        "join operator")
+      .intConf
+      .createWithDefault(Int.MaxValue)

   val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD =
     buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold")
       .internal()
-      .doc("Threshold for number of rows buffered in sort merge join operator")
+      .doc("Threshold for number of rows to be spilled by sort merge join operator")
       .intConf
-      .createWithDefault(Int.MaxValue)
+      .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
+
+  val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
+    buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold")
+      .internal()
+      .doc("Threshold for number of rows guaranteed to be held in memory by the cartesian " +
+        "product operator")
+      .intConf
+      .createWithDefault(4096)

   val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD =
     buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold")
       .internal()
-      .doc("Threshold for number of rows buffered in cartesian product operator")
+      .doc("Threshold for number of rows to be spilled by cartesian product operator")
       .intConf
       .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)

@@ -1037,11 +1060,19 @@ class SQLConf extends Serializable with Logging {

   def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER)

+  def windowExecBufferInMemoryThreshold: Int = getConf(WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
+
   def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD)

+  def sortMergeJoinExecBufferInMemoryThreshold: Int =
+    getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
+
   def sortMergeJoinExecBufferSpillThreshold: Int =
     getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD)

+  def cartesianProductExecBufferInMemoryThreshold: Int =
+    getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
+
   def cartesianProductExecBufferSpillThreshold: Int =
     getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
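
For illustration, a session could tune the new pair of knobs like this — a hedged sketch: the keys come verbatim from the diff above, but as `.internal()` configs they are undocumented and subject to change, and `spark` is assumed to be an existing `SparkSession`:

```scala
// Keep the first 4096 window rows per partition in a plain array,
// but let UnsafeExternalSorter hold up to ~1M rows before force-spilling.
spark.conf.set("spark.sql.windowExec.buffer.in.memory.threshold", "4096")
spark.conf.set("spark.sql.windowExec.buffer.spill.threshold", (1 << 20).toString)
```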

sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala

Lines changed: 15 additions & 13 deletions
@@ -31,16 +31,16 @@ import org.apache.spark.storage.BlockManager
 import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}

 /**
- * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined
- * threshold of rows is reached.
+ * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array
+ * until [[numRowsInMemoryBufferThreshold]] is reached post which it will switch to a mode which
+ * would flush to disk after [[numRowsSpillThreshold]] is met (or before if there is
+ * excessive memory consumption). Setting these threshold involves following trade-offs:
  *
- * Setting spill threshold faces following trade-off:
- *
- * - If the spill threshold is too high, the in-memory array may occupy more memory than is
- *   available, resulting in OOM.
- * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes.
- *   This may lead to a performance regression compared to the normal case of using an
- *   [[ArrayBuffer]] or [[Array]].
+ * - If [[numRowsInMemoryBufferThreshold]] is too high, the in-memory array may occupy more memory
+ *   than is available, resulting in OOM.
+ * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to
+ *   excessive disk writes. This may lead to a performance regression compared to the normal case
+ *   of using an [[ArrayBuffer]] or [[Array]].
  */
 private[sql] class ExternalAppendOnlyUnsafeRowArray(
     taskMemoryManager: TaskMemoryManager,
@@ -49,21 +49,23 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
     taskContext: TaskContext,
     initialSize: Int,
     pageSizeBytes: Long,
+    numRowsInMemoryBufferThreshold: Int,
     numRowsSpillThreshold: Int) extends Logging {

-  def this(numRowsSpillThreshold: Int) {
+  def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) {
     this(
       TaskContext.get().taskMemoryManager(),
       SparkEnv.get.blockManager,
       SparkEnv.get.serializerManager,
       TaskContext.get(),
       1024,
       SparkEnv.get.memoryManager.pageSizeBytes,
+      numRowsInMemoryBufferThreshold,
       numRowsSpillThreshold)
   }

   private val initialSizeOfInMemoryBuffer =
-    Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold)
+    Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold)

   private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) {
     new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer)
@@ -102,11 +104,11 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
   }

   def add(unsafeRow: UnsafeRow): Unit = {
-    if (numRows < numRowsSpillThreshold) {
+    if (numRows < numRowsInMemoryBufferThreshold) {
       inMemoryBuffer += unsafeRow.copy()
     } else {
       if (spillableArray == null) {
-        logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " +
+        logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " +
           s"${classOf[UnsafeExternalSorter].getName}")

         // We will not sort the rows, so prefixComparator and recordComparator are null
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala

Lines changed: 8 additions & 4 deletions
@@ -35,11 +35,12 @@ class UnsafeCartesianRDD(
     left : RDD[UnsafeRow],
     right : RDD[UnsafeRow],
     numFieldsOfRight: Int,
+    inMemoryBufferThreshold: Int,
     spillThreshold: Int)
   extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {

   override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
-    val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+    val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold)

     val partition = split.asInstanceOf[CartesianPartition]
     rdd2.iterator(partition.s2, context).foreach(rowArray.add)
@@ -71,9 +72,12 @@ case class CartesianProductExec(
     val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]]
     val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]

-    val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold
-
-    val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold)
+    val pair = new UnsafeCartesianRDD(
+      leftResults,
+      rightResults,
+      right.output.size,
+      sqlContext.conf.cartesianProductExecBufferInMemoryThreshold,
+      sqlContext.conf.cartesianProductExecBufferSpillThreshold)
     pair.mapPartitionsWithIndexInternal { (index, iter) =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       val filtered = if (condition.isDefined) {
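
A hedged sketch of a query that exercises this path — a cross join plans as `CartesianProductExec`, so the thresholds above govern how much of the right side each partition buffers in memory (`Dataset.crossJoin` is real API; the table names are made up):

```scala
// Buffer only the first 1024 right-side rows per partition in a plain array.
spark.conf.set("spark.sql.cartesianProductExec.buffer.in.memory.threshold", "1024")
spark.table("a").crossJoin(spark.table("b")).count()
```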

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 21 additions & 3 deletions
@@ -130,9 +130,14 @@ case class SortMergeJoinExec(
     sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
   }

+  private def getInMemoryThreshold: Int = {
+    sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     val spillThreshold = getSpillThreshold
+    val inMemoryThreshold = getInMemoryThreshold
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       val boundCondition: (InternalRow) => Boolean = {
         condition.map { cond =>
@@ -158,6 +163,7 @@ case class SortMergeJoinExec(
             keyOrdering,
             RowIterator.fromScala(leftIter),
             RowIterator.fromScala(rightIter),
+            inMemoryThreshold,
             spillThreshold
           )
           private[this] val joinRow = new JoinedRow
@@ -201,6 +207,7 @@ case class SortMergeJoinExec(
             keyOrdering,
             streamedIter = RowIterator.fromScala(leftIter),
             bufferedIter = RowIterator.fromScala(rightIter),
+            inMemoryThreshold,
             spillThreshold
           )
           val rightNullRow = new GenericInternalRow(right.output.length)
@@ -214,6 +221,7 @@ case class SortMergeJoinExec(
             keyOrdering,
             streamedIter = RowIterator.fromScala(rightIter),
             bufferedIter = RowIterator.fromScala(leftIter),
+            inMemoryThreshold,
             spillThreshold
           )
           val leftNullRow = new GenericInternalRow(left.output.length)
@@ -247,6 +255,7 @@ case class SortMergeJoinExec(
             keyOrdering,
             RowIterator.fromScala(leftIter),
             RowIterator.fromScala(rightIter),
+            inMemoryThreshold,
             spillThreshold
           )
           private[this] val joinRow = new JoinedRow
@@ -281,6 +290,7 @@ case class SortMergeJoinExec(
             keyOrdering,
             RowIterator.fromScala(leftIter),
             RowIterator.fromScala(rightIter),
+            inMemoryThreshold,
             spillThreshold
           )
           private[this] val joinRow = new JoinedRow
@@ -322,6 +332,7 @@ case class SortMergeJoinExec(
             keyOrdering,
             RowIterator.fromScala(leftIter),
             RowIterator.fromScala(rightIter),
+            inMemoryThreshold,
             spillThreshold
           )
           private[this] val joinRow = new JoinedRow
@@ -420,8 +431,10 @@ case class SortMergeJoinExec(
     val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName

     val spillThreshold = getSpillThreshold
+    val inMemoryThreshold = getInMemoryThreshold

-    ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);")
+    ctx.addMutableState(clsName, matches,
+      s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);")
     // Copy the left keys as class members so they could be used in next function call.
     val matchedKeyVars = copyKeys(ctx, leftKeyVars)

@@ -626,14 +639,18 @@ case class SortMergeJoinExec(
 * @param streamedIter an input whose rows will be streamed.
 * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that
 *                     have the same join key.
+ * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by
+ *                          internal buffer
+ * @param spillThreshold Threshold for number of rows to be spilled by internal buffer
 */
 private[joins] class SortMergeJoinScanner(
     streamedKeyGenerator: Projection,
     bufferedKeyGenerator: Projection,
     keyOrdering: Ordering[InternalRow],
     streamedIter: RowIterator,
     bufferedIter: RowIterator,
-    bufferThreshold: Int) {
+    inMemoryThreshold: Int,
+    spillThreshold: Int) {
   private[this] var streamedRow: InternalRow = _
   private[this] var streamedRowKey: InternalRow = _
   private[this] var bufferedRow: InternalRow = _
@@ -644,7 +661,8 @@ private[joins] class SortMergeJoinScanner(
   */
   private[this] var matchJoinKey: InternalRow = _
   /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
-  private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold)
+  private[this] val bufferedMatches =
+    new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)

   // Initialization (note: do _not_ want to advance streamed here).
   advancedBufferedToRowWithNullFreeJoinKey()
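
A hedged sketch of driving this path from user code — forcing a sort-merge join and shrinking its buffer so the scanner's per-key match buffer spills, in the same spirit as the JoinSuite change below (table and column names are made up):

```scala
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")  // prefer SMJ over broadcast
spark.conf.set("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold", "0")
spark.conf.set("spark.sql.sortMergeJoinExec.buffer.spill.threshold", "1")
spark.table("t1").join(spark.table("t2"), "key").count()
```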

sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala

Lines changed: 3 additions & 1 deletion
@@ -282,6 +282,7 @@ case class WindowExec(
     // Unwrap the expressions and factories from the map.
     val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
     val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+    val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold
     val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold

     // Start processing.
@@ -312,7 +313,8 @@ case class WindowExec(
         val inputFields = child.output.length

         val buffer: ExternalAppendOnlyUnsafeRowArray =
-          new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
+
         var bufferIterator: Iterator[UnsafeRow] = _

         val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType))
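
A hedged sketch of a query that lands in this code: each window partition is buffered in the `ExternalAppendOnlyUnsafeRowArray` above, so with the new defaults the first 4096 rows of a partition stay in a plain array and only larger partitions graduate to `UnsafeExternalSorter` (`Window` and `row_number` are real API; the table and column names are made up):

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

val w = Window.partitionBy("userId").orderBy("ts")
spark.table("events").withColumn("rn", row_number().over(w)).count()
```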

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 2 additions & 1 deletion
@@ -665,7 +665,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {

   test("test SortMergeJoin (with spill)") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
-      "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") {
+      "spark.sql.sortMergeJoinExec.buffer.in.memory.threshold" -> "0",
+      "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "1") {

       assertSpilled(sparkContext, "inner join") {
         checkAnswer(

sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala

Lines changed: 5 additions & 2 deletions
@@ -67,7 +67,10 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark {
   benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
     var sum = 0L
     for (_ <- 0L until iterations) {
-      val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+      val array = new ExternalAppendOnlyUnsafeRowArray(
+        ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer,
+        numSpillThreshold)
+
       rows.foreach(x => array.add(x))

       val iterator = array.generateIterator()
@@ -143,7 +146,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark {
   benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
     var sum = 0L
     for (_ <- 0L until iterations) {
-      val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+      val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold)
       rows.foreach(x => array.add(x))

       val iterator = array.generateIterator()
