
Commit 014a9f9

Andrew Or authored and rxin committed
[SPARK-9709] [SQL] Avoid starving unsafe operators that use sort
The issue is that a task may run multiple sorts, and the sorts run by the child operator (i.e. parent RDD) may acquire all available memory such that other sorts in the same task do not have enough to proceed. This manifests itself in an `IOException("Unable to acquire X bytes of memory")` thrown by `UnsafeExternalSorter`.

The solution is to reserve a page in each sorter in the chain before computing the child operator's (parent RDD's) partitions. This requires us to use a new special RDD that does some preparation before computing the parent's partitions.

Author: Andrew Or <[email protected]>

Closes #8011 from andrewor14/unsafe-starve-memory and squashes the following commits:

35b69a4 [Andrew Or] Simplify test
0b07782 [Andrew Or] Minor: update comments
5d5afdf [Andrew Or] Merge branch 'master' of github.com:apache/spark into unsafe-starve-memory
254032e [Andrew Or] Add tests
234acbd [Andrew Or] Reserve a page in sorter when preparing each partition
b889e08 [Andrew Or] MapPartitionsWithPreparationRDD
1 parent b878253 commit 014a9f9
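To make the failure mode concrete, here is a minimal, self-contained Scala sketch. The names (`FakeMemoryManager`, `FakeSorter`) and the four-page budget are invented for illustration, not Spark's actual classes, and spilling is omitted; the point is only that a fixed per-task budget plus eager upstream allocation starves a downstream sorter unless each sorter reserves a page at construction time, which is what this commit does.

```scala
// Stand-in types only (not Spark's real API): a fixed per-task budget handed
// out in pages. If the upstream sorter grabs every page before the downstream
// sorter reserves one, the downstream sorter fails with the same
// "Unable to acquire ... bytes of memory" symptom described above.
object StarvationSketch {

  // Stand-in for ShuffleMemoryManager: grants at most what is still available.
  final class FakeMemoryManager(var available: Long) {
    def tryToAcquire(bytes: Long): Long = {
      val granted = math.min(bytes, available)
      available -= granted
      granted
    }
    def release(bytes: Long): Unit = available += bytes
  }

  // Stand-in sorter: optionally reserves one page as soon as it is constructed.
  final class FakeSorter(mm: FakeMemoryManager, pageSize: Long, reserveUpFront: Boolean) {
    var pages = 0
    if (reserveUpFront) acquireNewPage()

    def acquireNewPage(): Unit = {
      val got = mm.tryToAcquire(pageSize)
      if (got < pageSize) {
        mm.release(got) // never keep a partial grant
        throw new java.io.IOException(s"Unable to acquire $pageSize bytes of memory")
      }
      pages += 1
    }

    // Simulates a sort big enough to consume the rest of the budget.
    def sortLotsOfRows(): Unit = while (mm.available >= pageSize) acquireNewPage()
  }

  def main(args: Array[String]): Unit = {
    val pageSize = 64L

    // Before the fix: the upstream sorter runs first and takes all four pages,
    // so constructing the downstream sorter fails.
    val mm1 = new FakeMemoryManager(4 * pageSize)
    new FakeSorter(mm1, pageSize, reserveUpFront = false).sortLotsOfRows()
    try new FakeSorter(mm1, pageSize, reserveUpFront = true)
    catch { case e: java.io.IOException => println(s"starved: ${e.getMessage}") }

    // After the fix: the downstream sorter reserves its page *before* the
    // upstream sorter computes, so both make progress.
    val mm2 = new FakeMemoryManager(4 * pageSize)
    val downstream = new FakeSorter(mm2, pageSize, reserveUpFront = true)
    new FakeSorter(mm2, pageSize, reserveUpFront = false).sortLotsOfRows()
    println(s"downstream holds ${downstream.pages} page(s)") // 1
  }
}
```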

File tree: 8 files changed (+184, -22 lines)


core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 29 additions & 14 deletions
```diff
@@ -138,6 +138,11 @@ private UnsafeExternalSorter(
       this.inMemSorter = existingInMemorySorter;
     }
 
+    // Acquire a new page as soon as we construct the sorter to ensure that we have at
+    // least one page to work with. Otherwise, other operators in the same task may starve
+    // this sorter (SPARK-9709).
+    acquireNewPage();
+
     // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
     // the end of the task. This is necessary to avoid memory leaks in when the downstream operator
     // does not fully consume the sorter's output (e.g. sort followed by limit).
@@ -343,22 +348,32 @@ private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
         throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
           pageSizeBytes + ")");
       } else {
-        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-        if (memoryAcquired < pageSizeBytes) {
-          shuffleMemoryManager.release(memoryAcquired);
-          spill();
-          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-          if (memoryAcquiredAfterSpilling != pageSizeBytes) {
-            shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
-            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
-          }
-        }
-        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-        currentPagePosition = currentPage.getBaseOffset();
-        freeSpaceInCurrentPage = pageSizeBytes;
-        allocatedPages.add(currentPage);
+        acquireNewPage();
+      }
+    }
+  }
+
+  /**
+   * Acquire a new page from the {@link ShuffleMemoryManager}.
+   *
+   * If there is not enough space to allocate the new page, spill all existing ones
+   * and try again. If there is still not enough space, report error to the caller.
+   */
+  private void acquireNewPage() throws IOException {
+    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+    if (memoryAcquired < pageSizeBytes) {
+      shuffleMemoryManager.release(memoryAcquired);
+      spill();
+      final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+      if (memoryAcquiredAfterSpilling != pageSizeBytes) {
+        shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+        throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
       }
     }
+    currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+    currentPagePosition = currentPage.getBaseOffset();
+    freeSpaceInCurrentPage = pageSizeBytes;
+    allocatedPages.add(currentPage);
   }
 
   /**
```

core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala

Lines changed: 3 additions & 0 deletions
```diff
@@ -21,6 +21,9 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.{Partition, TaskContext}
 
+/**
+ * An RDD that applies the provided function to every partition of the parent RDD.
+ */
 private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     prev: RDD[T],
     f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
```
core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala (new file)

Lines changed: 49 additions & 0 deletions

```diff
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, Partitioner, TaskContext}
+
+/**
+ * An RDD that applies a user provided function to every partition of the parent RDD, and
+ * additionally allows the user to prepare each partition before computing the parent partition.
+ */
+private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag](
+    prev: RDD[T],
+    preparePartition: () => M,
+    executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U],
+    preservesPartitioning: Boolean = false)
+  extends RDD[U](prev) {
+
+  override val partitioner: Option[Partitioner] = {
+    if (preservesPartitioning) firstParent[T].partitioner else None
+  }
+
+  override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+  /**
+   * Prepare a partition before computing it from its parent.
+   */
+  override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
+    val preparedArgument = preparePartition()
+    val parentIterator = firstParent[T].iterator(partition, context)
+    executePartition(context, partition.index, preparedArgument, parentIterator)
+  }
+}
```
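A hedged usage sketch of the new RDD follows. `PreparationExample` and its `StringBuilder` payload are invented for illustration; because the class is `private[spark]`, real callers (like `TungstenSort` below) live inside Spark itself. The key contract is the ordering in `compute`: `preparePartition()` runs before `firstParent[T].iterator(...)` is requested.

```scala
// Hypothetical example, placed under org.apache.spark.rdd since the class is
// private[spark]. `StringBuilder` stands in for a real per-partition resource
// such as a sorter; assumes a live SparkContext `sc` is passed in.
package org.apache.spark.rdd

import org.apache.spark.TaskContext

object PreparationExample {
  def run(sc: org.apache.spark.SparkContext): Array[String] = {
    val parent = sc.parallelize(1 to 4, numSlices = 2)

    // Runs before the parent partition is computed: reserve the resource now.
    val preparePartition = () => new StringBuilder("prepared")

    // Runs afterwards, with both the prepared resource and the parent rows.
    val executePartition = (
        context: TaskContext,
        index: Int,
        prepared: StringBuilder,
        rows: Iterator[Int]) => {
      Iterator(s"partition $index: ${prepared.toString} then ${rows.sum}")
    }

    new MapPartitionsWithPreparationRDD[String, Int, StringBuilder](
      parent, preparePartition, executePartition).collect()
  }
}
```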

core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -124,7 +124,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
   }
 }
 
-private object ShuffleMemoryManager {
+private[spark] object ShuffleMemoryManager {
   /**
    * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
    * of the memory pool and a safety factor since collections can sometimes grow bigger than
```

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java

Lines changed: 18 additions & 1 deletion
```diff
@@ -340,7 +340,8 @@ public void testPeakMemoryUsed() throws Exception {
     for (int i = 0; i < numRecordsPerPage * 10; i++) {
       insertNumber(sorter, i);
       newPeakMemory = sorter.getPeakMemoryUsedBytes();
-      if (i % numRecordsPerPage == 0) {
+      // The first page is pre-allocated on instantiation
+      if (i % numRecordsPerPage == 0 && i > 0) {
         // We allocated a new page for this record, so peak memory should change
         assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
       } else {
@@ -364,5 +365,21 @@ public void testPeakMemoryUsed() throws Exception {
     }
   }
 
+  @Test
+  public void testReservePageOnInstantiation() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
+    try {
+      assertEquals(1, sorter.getNumberOfAllocatedPages());
+      // Inserting a new record doesn't allocate more memory since we already have a page
+      long peakMemory = sorter.getPeakMemoryUsedBytes();
+      insertNumber(sorter, 100);
+      assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes());
+      assertEquals(1, sorter.getNumberOfAllocatedPages());
+    } finally {
+      sorter.cleanupResources();
+      assertSpillFilesWereCleanedUp();
+    }
+  }
+
 }
```

core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala (new file)

Lines changed: 60 additions & 0 deletions

```diff
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.collection.mutable
+
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext}
+
+class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext {
+
+  test("prepare called before parent partition is computed") {
+    sc = new SparkContext("local", "test")
+
+    // Have the parent partition push a number to the list
+    val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter =>
+      TestObject.things.append(20)
+      iter
+    }
+
+    // Push a different number during the prepare phase
+    val preparePartition = () => { TestObject.things.append(10) }
+
+    // Push yet another number during the execution phase
+    val executePartition = (
+        taskContext: TaskContext,
+        partitionIndex: Int,
+        notUsed: Unit,
+        parentIterator: Iterator[Int]) => {
+      TestObject.things.append(30)
+      TestObject.things.iterator
+    }
+
+    // Verify that the numbers are pushed in the order expected
+    val result = {
+      new MapPartitionsWithPreparationRDD[Int, Int, Unit](
+        parent, preparePartition, executePartition).collect()
+    }
+    assert(result === Array(10, 20, 30))
+  }
+
+}
+
+private object TestObject {
+  val things = new mutable.ListBuffer[Int]
+}
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -158,7 +158,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    */
   final def prepare(): Unit = {
     if (prepareCalled.compareAndSet(false, true)) {
-      doPrepare
+      doPrepare()
       children.foreach(_.prepare())
     }
   }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala

Lines changed: 23 additions & 5 deletions
```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.{InternalAccumulator, TaskContext}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
@@ -123,7 +123,12 @@ case class TungstenSort(
     val schema = child.schema
     val childOutput = child.output
     val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
-    child.execute().mapPartitions({ iter =>
+
+    /**
+     * Set up the sorter in each partition before computing the parent partition.
+     * This makes sure our sorter is not starved by other sorters used in the same task.
+     */
+    def preparePartition(): UnsafeExternalRowSorter = {
       val ordering = newOrdering(sortOrder, childOutput)
 
       // The comparator for comparing prefix
@@ -143,12 +148,25 @@ case class TungstenSort(
       if (testSpillFrequency > 0) {
         sorter.setTestSpillFrequency(testSpillFrequency)
       }
-      val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
-      val taskContext = TaskContext.get()
+      sorter
+    }
+
+    /** Compute a partition using the sorter already set up previously. */
+    def executePartition(
+        taskContext: TaskContext,
+        partitionIndex: Int,
+        sorter: UnsafeExternalRowSorter,
+        parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+      val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]])
       taskContext.internalMetricsToAccumulators(
         InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
       sortedIterator
-    }, preservesPartitioning = true)
+    }
+
+    // Note: we need to set up the external sorter in each partition before computing
+    // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709).
+    new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter](
+      child.execute(), preparePartition, executePartition, preservesPartitioning = true)
   }
 
 }
```
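The closing note in the diff above ("we cannot simply use `mapPartitions` here") is the crux: with plain `mapPartitions`, the parent iterator is constructed before the sort closure runs, so upstream operators get to allocate first. Below is a toy, Spark-free sketch of the ordering the new RDD enforces; all names are hypothetical stand-ins for the real operators.

```scala
// Toy model of MapPartitionsWithPreparationRDD.compute: reserve this
// operator's memory first, only then construct the parent iterator
// (which is where upstream operators do their own allocation).
object PrepareBeforeParent {
  val events = scala.collection.mutable.ListBuffer[String]()

  def parentIterator(): Iterator[Int] = {
    events += "parent operator allocates its pages" // upstream setup
    Iterator(3, 1, 2)
  }

  def compute(): Iterator[Int] = {
    events += "this sorter reserves a page"          // preparePartition()
    val parent = parentIterator()                    // firstParent.iterator(...)
    events += "this sorter sorts the parent rows"    // executePartition(...)
    parent.toSeq.sorted.iterator
  }

  def main(args: Array[String]): Unit = {
    println(compute().mkString(", ")) // 1, 2, 3
    events.foreach(println)           // reservation logged before parent setup
  }
}
```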
