
Commit 0174149

Add cleanup behavior and cleanup tests for sort-based shuffle
This also required creating a BlockId subclass for shuffle index blocks so that the BlockManagers can report back their lists of blocks.
1 parent eb4ee0d commit 0174149
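For orientation (an editorial sketch, not part of the commit): with sort-based shuffle each map task leaves one data block and one index block on disk, and the cleanup added here has to delete both. The name formats below are taken from the BlockId.scala changes in this diff; the helper object itself is hypothetical.

// Hypothetical helper (illustration only): the two on-disk artifacts the new
// cleanup path removes for every completed map task of a sort-based shuffle.
object SortShuffleBlockNames {
  // Format of ShuffleBlockId.name with reduceId = 0, as used by removeShuffleBlocks
  def dataBlock(shuffleId: Int, mapId: Int): String =
    "shuffle_" + shuffleId + "_" + mapId + "_0"

  // Format of the new ShuffleIndexBlockId.name
  def indexBlock(shuffleId: Int, mapId: Int): String =
    dataBlock(shuffleId, mapId) + ".index"

  def main(args: Array[String]): Unit = {
    println(dataBlock(2, 5))   // shuffle_2_5_0
    println(indexBlock(2, 5))  // shuffle_2_5_0.index
  }
}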


4 files changed: +176 −46 lines


core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

3 additions, 0 deletions

@@ -126,6 +126,9 @@ private[spark] class SortShuffleWriter[K, V, C](
       out.close()
     }

+    // Register our map output with the ShuffleBlockManager, which handles cleaning it over time
+    blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions)
+
     mapStatus = new MapStatus(blockManager.blockManagerId,
       lengths.map(MapOutputTracker.compressSize))
   }

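The call above is needed because a sort-based map task writes its single output file itself instead of obtaining writers through forMapTask, so the ShuffleBlockManager would otherwise never learn which map outputs exist and could not clean them up later. A self-contained sketch of that bookkeeping pattern (illustrative names, not Spark's classes):

import scala.collection.mutable

// Illustrative sketch of the bookkeeping behind addCompletedMap: remember which map
// tasks finished for each shuffle so a later cleanup pass can enumerate their files.
class CompletedMapRegistry {
  private val completedMaps = mutable.HashMap.empty[Int, mutable.Set[Int]]

  def addCompletedMap(shuffleId: Int, mapId: Int): Unit = synchronized {
    completedMaps.getOrElseUpdate(shuffleId, mutable.Set.empty[Int]) += mapId
  }

  // Cleanup walks this set and deletes one data file and one ".index" file per map
  def completedMapsFor(shuffleId: Int): Set[Int] = synchronized {
    completedMaps.get(shuffleId).map(_.toSet).getOrElse(Set.empty)
  }
}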
core/src/main/scala/org/apache/spark/storage/BlockId.scala

9 additions, 0 deletions

@@ -59,6 +59,12 @@ case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
   def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
 }

+@DeveloperApi
+case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
+  extends BlockId {
+  def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
+}
+
 @DeveloperApi
 case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
   def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)

@@ -88,6 +94,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
   val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r

@@ -99,6 +106,8 @@ object BlockId {
       RDDBlockId(rddId.toInt, splitIndex.toInt)
     case SHUFFLE(shuffleId, mapId, reduceId) =>
       ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+    case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
+      ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
     case BROADCAST(broadcastId, field) =>
       BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
     case TASKRESULT(taskId) =>

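A note on why the new SHUFFLE_INDEX pattern can sit after SHUFFLE in BlockId.apply: a Scala regex used as a match pattern must cover the entire string, so a name ending in ".index" falls through the SHUFFLE case. A standalone sketch (the two regexes are copied from the diff above; the rest is illustrative):

// Standalone illustration of the parsing order in BlockId.apply (regexes copied
// from the diff above; the simplified result strings are for this sketch only).
object BlockNameParsing {
  val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
  val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r

  def describe(name: String): String = name match {
    // A regex pattern must match the whole string, so "shuffle_2_5_0.index"
    // does not match SHUFFLE and is caught by SHUFFLE_INDEX instead.
    case SHUFFLE(s, m, r)       => s"shuffle data block: shuffle=$s map=$m reduce=$r"
    case SHUFFLE_INDEX(s, m, r) => s"shuffle index block: shuffle=$s map=$m reduce=$r"
    case _                      => "unknown block name"
  }

  def main(args: Array[String]): Unit = {
    println(describe("shuffle_2_5_0"))        // shuffle data block
    println(describe("shuffle_2_5_0.index"))  // shuffle index block
  }
}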
core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala

28 additions, 1 deletion

@@ -28,6 +28,7 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
 import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.shuffle.sort.SortShuffleManager

 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {

@@ -58,6 +59,7 @@ private[spark] trait ShuffleWriterGroup {
  * each block stored in each file. In order to find the location of a shuffle block, we search the
  * files within a ShuffleFileGroups associated with the block's reducer.
  */
+// TODO: Factor this into a separate class for each ShuffleManager implementation
 private[spark]
 class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   def conf = blockManager.conf

@@ -67,6 +69,10 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   val consolidateShuffleFiles =
     conf.getBoolean("spark.shuffle.consolidateFiles", false)

+  // Are we using sort-based shuffle?
+  val sortBasedShuffle =
+    conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName
+
   private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024

   /**

@@ -91,6 +97,20 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   private val metadataCleaner =
     new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf)

+  /**
+   * Register a completed map without getting a ShuffleWriterGroup. Used by sort-based shuffle
+   * because it just writes a single file by itself.
+   */
+  def addCompletedMap(shuffleId: Int, mapId: Int, numBuckets: Int): Unit = {
+    shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
+    val shuffleState = shuffleStates(shuffleId)
+    shuffleState.completedMapTasks.add(mapId)
+  }
+
+  /**
+   * Get a ShuffleWriterGroup for the given map task, which will register it as complete
+   * when the writers are closed successfully
+   */
   def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
     new ShuffleWriterGroup {
       shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))

@@ -182,7 +202,14 @@
   private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
     shuffleStates.get(shuffleId) match {
       case Some(state) =>
-        if (consolidateShuffleFiles) {
+        if (sortBasedShuffle) {
+          // There's a single block ID for each map, plus an index file for it
+          for (mapId <- state.completedMapTasks) {
+            val blockId = new ShuffleBlockId(shuffleId, mapId, 0)
+            blockManager.diskBlockManager.getFile(blockId).delete()
+            blockManager.diskBlockManager.getFile(blockId.name + ".index").delete()
+          }
+        } else if (consolidateShuffleFiles) {
           for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
             file.delete()
           }

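The new cleanup branch is taken only when the application selects the sort-based manager by class name, which is what the sortBasedShuffle flag above checks. A minimal configuration sketch (assuming a Spark build from this revision is on the classpath):

import org.apache.spark.SparkConf

// Sketch: selecting sort-based shuffle by class name so that
// ShuffleBlockManager.sortBasedShuffle is true and removeShuffleBlocks deletes
// the single data file plus the ".index" file for each completed map.
object SortShuffleConfExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("sort-shuffle-example")
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
    println(conf.get("spark.shuffle.manager"))
  }
}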
core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala

136 additions, 45 deletions

@@ -34,16 +34,28 @@ import org.scalatest.time.SpanSugar._

 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId}
-
-class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
-
+import org.apache.spark.storage._
+import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.storage.BroadcastBlockId
+import org.apache.spark.storage.RDDBlockId
+import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.ShuffleIndexBlockId
+
+/**
+ * An abstract base class for context cleaner tests, which sets up a context with a config
+ * suitable for cleaner tests and provides some utility functions. Subclasses can use different
+ * config options, in particular, a different shuffle manager class
+ */
+abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[HashShuffleManager])
+  extends FunSuite with BeforeAndAfter with LocalSparkContext
+{
   implicit val defaultTimeout = timeout(10000 millis)
   val conf = new SparkConf()
     .setMaster("local[2]")
     .setAppName("ContextCleanerSuite")
     .set("spark.cleaner.referenceTracking.blocking", "true")
-    .set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+    .set("spark.shuffle.manager", shuffleManager.getName)

   before {
     sc = new SparkContext(conf)

@@ -56,6 +68,59 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
     }
   }

+  //------ Helper functions ------
+
+  protected def newRDD() = sc.makeRDD(1 to 10)
+  protected def newPairRDD() = newRDD().map(_ -> 1)
+  protected def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
+  protected def newBroadcast() = sc.broadcast(1 to 100)
+
+  protected def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
+    def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
+      rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
+        getAllDependencies(dep.rdd)
+      }
+    }
+    val rdd = newShuffleRDD()
+
+    // Get all the shuffle dependencies
+    val shuffleDeps = getAllDependencies(rdd)
+      .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
+      .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
+    (rdd, shuffleDeps)
+  }
+
+  protected def randomRdd() = {
+    val rdd: RDD[_] = Random.nextInt(3) match {
+      case 0 => newRDD()
+      case 1 => newShuffleRDD()
+      case 2 => newPairRDD.join(newPairRDD())
+    }
+    if (Random.nextBoolean()) rdd.persist()
+    rdd.count()
+    rdd
+  }
+
+  /** Run GC and make sure it actually has run */
+  private def runGC() {
+    val weakRef = new WeakReference(new Object())
+    val startTime = System.currentTimeMillis
+    System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
+    // Wait until a weak reference object has been GCed
+    while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+      System.gc()
+      Thread.sleep(200)
+    }
+  }
+
+  protected def cleaner = sc.cleaner.get
+}
+
+
+/**
+ * Basic ContextCleanerSuite, which uses sort-based shuffle
+ */
+class ContextCleanerSuite extends ContextCleanerSuiteBase {
   test("cleanup RDD") {
     val rdd = newRDD().persist()
     val collected = rdd.collect().toList

@@ -181,7 +246,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
       .setMaster("local-cluster[2, 1, 512]")
       .setAppName("ContextCleanerSuite")
       .set("spark.cleaner.referenceTracking.blocking", "true")
-      .set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+      .set("spark.shuffle.manager", shuffleManager.getName)
     sc = new SparkContext(conf2)

     val numRdds = 10

@@ -212,57 +277,82 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
       case _ => false
     }, askSlaves = true).isEmpty)
   }
+}

-  //------ Helper functions ------

-  private def newRDD() = sc.makeRDD(1 to 10)
-  private def newPairRDD() = newRDD().map(_ -> 1)
-  private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
-  private def newBroadcast() = sc.broadcast(1 to 100)
+/**
+ * A copy of the shuffle tests for sort-based shuffle
+ */
+class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[SortShuffleManager]) {
+  test("cleanup shuffle") {
+    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
+    val collected = rdd.collect().toList
+    val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))

-  private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
-    def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
-      rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
-        getAllDependencies(dep.rdd)
-      }
-    }
-    val rdd = newShuffleRDD()
+    // Explicit cleanup
+    shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true))
+    tester.assertCleanup()

-    // Get all the shuffle dependencies
-    val shuffleDeps = getAllDependencies(rdd)
-      .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
-      .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
-    (rdd, shuffleDeps)
+    // Verify that shuffles can be re-executed after cleaning up
+    assert(rdd.collect().toList.equals(collected))
   }

-  private def randomRdd() = {
-    val rdd: RDD[_] = Random.nextInt(3) match {
-      case 0 => newRDD()
-      case 1 => newShuffleRDD()
-      case 2 => newPairRDD.join(newPairRDD())
-    }
-    if (Random.nextBoolean()) rdd.persist()
+  test("automatically cleanup shuffle") {
+    var rdd = newShuffleRDD()
     rdd.count()
-    rdd
-  }

-  private def randomBroadcast() = {
-    sc.broadcast(Random.nextInt(Int.MaxValue))
+    // Test that GC does not cause shuffle cleanup due to a strong reference
+    val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+    runGC()
+    intercept[Exception] {
+      preGCTester.assertCleanup()(timeout(1000 millis))
+    }
+
+    // Test that GC causes shuffle cleanup after dereferencing the RDD
+    val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+    rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope
+    runGC()
+    postGCTester.assertCleanup()
   }

-  /** Run GC and make sure it actually has run */
-  private def runGC() {
-    val weakRef = new WeakReference(new Object())
-    val startTime = System.currentTimeMillis
-    System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
-    // Wait until a weak reference object has been GCed
-    while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
-      System.gc()
-      Thread.sleep(200)
+  test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
+    sc.stop()
+
+    val conf2 = new SparkConf()
+      .setMaster("local-cluster[2, 1, 512]")
+      .setAppName("ContextCleanerSuite")
+      .set("spark.cleaner.referenceTracking.blocking", "true")
+      .set("spark.shuffle.manager", shuffleManager.getName)
+    sc = new SparkContext(conf2)
+
+    val numRdds = 10
+    val numBroadcasts = 4 // Broadcasts are more costly
+    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+    val rddIds = sc.persistentRdds.keys.toSeq
+    val shuffleIds = 0 until sc.newShuffleId()
+    val broadcastIds = broadcastBuffer.map(_.id)
+
+    val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+    runGC()
+    intercept[Exception] {
+      preGCTester.assertCleanup()(timeout(1000 millis))
     }
-  }

-  private def cleaner = sc.cleaner.get
+    // Test that GC triggers the cleanup of all variables after the dereferencing them
+    val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+    broadcastBuffer.clear()
+    rddBuffer.clear()
+    runGC()
+    postGCTester.assertCleanup()
+
+    // Make sure the broadcasted task closure no longer exists after GC.
+    val taskClosureBroadcastId = broadcastIds.max + 1
+    assert(sc.env.blockManager.master.getMatchingBlockIds({
+      case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
+      case _ => false
+    }, askSlaves = true).isEmpty)
+  }
 }


@@ -420,6 +510,7 @@ class CleanerTester(
   private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = {
     blockManager.master.getMatchingBlockIds( _ match {
       case ShuffleBlockId(`shuffleId`, _, _) => true
+      case ShuffleIndexBlockId(`shuffleId`, _, _) => true
       case _ => false
     }, askSlaves = true)
   }
