@@ -34,16 +34,28 @@ import org.scalatest.time.SpanSugar._
3434
3535import org .apache .spark .SparkContext ._
3636import org .apache .spark .rdd .RDD
37- import org .apache .spark .storage .{BlockId , BroadcastBlockId , RDDBlockId , ShuffleBlockId }
38-
39- class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
40-
37+ import org .apache .spark .storage ._
38+ import org .apache .spark .shuffle .hash .HashShuffleManager
39+ import org .apache .spark .shuffle .sort .SortShuffleManager
40+ import org .apache .spark .storage .BroadcastBlockId
41+ import org .apache .spark .storage .RDDBlockId
42+ import org .apache .spark .storage .ShuffleBlockId
43+ import org .apache .spark .storage .ShuffleIndexBlockId
44+
45+ /**
46+ * An abstract base class for context cleaner tests, which sets up a context with a config
47+ * suitable for cleaner tests and provides some utility functions. Subclasses can use different
48+ * config options, in particular, a different shuffle manager class.
49+ */
50+ abstract class ContextCleanerSuiteBase (val shuffleManager : Class [_] = classOf [HashShuffleManager ])
51+ extends FunSuite with BeforeAndAfter with LocalSparkContext
52+ {
4153 implicit val defaultTimeout = timeout(10000 millis)
4254 val conf = new SparkConf ()
4355 .setMaster(" local[2]" )
4456 .setAppName(" ContextCleanerSuite" )
4557 .set(" spark.cleaner.referenceTracking.blocking" , " true" )
46- .set(" spark.shuffle.manager" , " org.apache.spark.shuffle.hash.HashShuffleManager " )
58+ .set(" spark.shuffle.manager" , shuffleManager.getName )
4759
4860 before {
4961 sc = new SparkContext (conf)
@@ -56,6 +68,59 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
5668 }
5769 }
5870
71+ // ------ Helper functions ------
72+
73+ protected def newRDD () = sc.makeRDD(1 to 10 )
74+ protected def newPairRDD () = newRDD().map(_ -> 1 )
75+ protected def newShuffleRDD () = newPairRDD().reduceByKey(_ + _)
76+ protected def newBroadcast () = sc.broadcast(1 to 100 )
77+
78+ protected def newRDDWithShuffleDependencies (): (RDD [_], Seq [ShuffleDependency [_, _, _]]) = {
79+ def getAllDependencies (rdd : RDD [_]): Seq [Dependency [_]] = {
80+ rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
81+ getAllDependencies(dep.rdd)
82+ }
83+ }
84+ val rdd = newShuffleRDD()
85+
86+ // Get all the shuffle dependencies
87+ val shuffleDeps = getAllDependencies(rdd)
88+ .filter(_.isInstanceOf [ShuffleDependency [_, _, _]])
89+ .map(_.asInstanceOf [ShuffleDependency [_, _, _]])
90+ (rdd, shuffleDeps)
91+ }
92+
93+ protected def randomRdd () = {
94+ val rdd : RDD [_] = Random .nextInt(3 ) match {
95+ case 0 => newRDD()
96+ case 1 => newShuffleRDD()
97+ case 2 => newPairRDD.join(newPairRDD())
98+ }
99+ if (Random .nextBoolean()) rdd.persist()
100+ rdd.count()
101+ rdd
102+ }
103+
104+ /** Run GC and make sure it actually has run */
105+ private def runGC () {
106+ val weakRef = new WeakReference (new Object ())
107+ val startTime = System .currentTimeMillis
108+ System .gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
109+ // Wait until a weak reference object has been GCed
110+ while (System .currentTimeMillis - startTime < 10000 && weakRef.get != null ) {
111+ System .gc()
112+ Thread .sleep(200 )
113+ }
114+ }
115+
116+ protected def cleaner = sc.cleaner.get
117+ }
118+
119+
120+ /**
121+ * Basic ContextCleanerSuite, which uses the default hash-based shuffle manager
122+ */
123+ class ContextCleanerSuite extends ContextCleanerSuiteBase {
59124 test(" cleanup RDD" ) {
60125 val rdd = newRDD().persist()
61126 val collected = rdd.collect().toList
@@ -181,7 +246,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
181246 .setMaster(" local-cluster[2, 1, 512]" )
182247 .setAppName(" ContextCleanerSuite" )
183248 .set(" spark.cleaner.referenceTracking.blocking" , " true" )
184- .set(" spark.shuffle.manager" , " org.apache.spark.shuffle.hash.HashShuffleManager " )
249+ .set(" spark.shuffle.manager" , shuffleManager.getName )
185250 sc = new SparkContext (conf2)
186251
187252 val numRdds = 10
@@ -212,57 +277,82 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
212277 case _ => false
213278 }, askSlaves = true ).isEmpty)
214279 }
280+ }
215281
216- // ------ Helper functions ------
217282
218- private def newRDD () = sc.makeRDD(1 to 10 )
219- private def newPairRDD () = newRDD().map(_ -> 1 )
220- private def newShuffleRDD () = newPairRDD().reduceByKey(_ + _)
221- private def newBroadcast () = sc.broadcast(1 to 100 )
283+ /**
284+ * A copy of the shuffle tests for sort-based shuffle
285+ */
286+ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase (classOf [SortShuffleManager ]) {
287+ test(" cleanup shuffle" ) {
288+ val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
289+ val collected = rdd.collect().toList
290+ val tester = new CleanerTester (sc, shuffleIds = shuffleDeps.map(_.shuffleId))
222291
223- private def newRDDWithShuffleDependencies (): (RDD [_], Seq [ShuffleDependency [_, _, _]]) = {
224- def getAllDependencies (rdd : RDD [_]): Seq [Dependency [_]] = {
225- rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
226- getAllDependencies(dep.rdd)
227- }
228- }
229- val rdd = newShuffleRDD()
292+ // Explicit cleanup
293+ shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true ))
294+ tester.assertCleanup()
230295
231- // Get all the shuffle dependencies
232- val shuffleDeps = getAllDependencies(rdd)
233- .filter(_.isInstanceOf [ShuffleDependency [_, _, _]])
234- .map(_.asInstanceOf [ShuffleDependency [_, _, _]])
235- (rdd, shuffleDeps)
296+ // Verify that shuffles can be re-executed after cleaning up
297+ assert(rdd.collect().toList.equals(collected))
236298 }
237299
238- private def randomRdd () = {
239- val rdd : RDD [_] = Random .nextInt(3 ) match {
240- case 0 => newRDD()
241- case 1 => newShuffleRDD()
242- case 2 => newPairRDD.join(newPairRDD())
243- }
244- if (Random .nextBoolean()) rdd.persist()
300+ test(" automatically cleanup shuffle" ) {
301+ var rdd = newShuffleRDD()
245302 rdd.count()
246- rdd
247- }
248303
249- private def randomBroadcast () = {
250- sc.broadcast(Random .nextInt(Int .MaxValue ))
304+ // Test that GC does not cause shuffle cleanup due to a strong reference
305+ val preGCTester = new CleanerTester (sc, shuffleIds = Seq (0 ))
306+ runGC()
307+ intercept[Exception ] {
308+ preGCTester.assertCleanup()(timeout(1000 millis))
309+ }
310+
311+ // Test that GC causes shuffle cleanup after dereferencing the RDD
312+ val postGCTester = new CleanerTester (sc, shuffleIds = Seq (0 ))
313+ rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope
314+ runGC()
315+ postGCTester.assertCleanup()
251316 }
252317
253- /** Run GC and make sure it actually has run */
254- private def runGC () {
255- val weakRef = new WeakReference (new Object ())
256- val startTime = System .currentTimeMillis
257- System .gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
258- // Wait until a weak reference object has been GCed
259- while (System .currentTimeMillis - startTime < 10000 && weakRef.get != null ) {
260- System .gc()
261- Thread .sleep(200 )
318+ test(" automatically cleanup RDD + shuffle + broadcast in distributed mode" ) {
319+ sc.stop()
320+
321+ val conf2 = new SparkConf ()
322+ .setMaster(" local-cluster[2, 1, 512]" )
323+ .setAppName(" ContextCleanerSuite" )
324+ .set(" spark.cleaner.referenceTracking.blocking" , " true" )
325+ .set(" spark.shuffle.manager" , shuffleManager.getName)
326+ sc = new SparkContext (conf2)
327+
328+ val numRdds = 10
329+ val numBroadcasts = 4 // Broadcasts are more costly
330+ val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
331+ val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
332+ val rddIds = sc.persistentRdds.keys.toSeq
333+ val shuffleIds = 0 until sc.newShuffleId()
334+ val broadcastIds = broadcastBuffer.map(_.id)
335+
336+ val preGCTester = new CleanerTester (sc, rddIds, shuffleIds, broadcastIds)
337+ runGC()
338+ intercept[Exception ] {
339+ preGCTester.assertCleanup()(timeout(1000 millis))
262340 }
263- }
264341
265- private def cleaner = sc.cleaner.get
342+ // Test that GC triggers the cleanup of all variables after dereferencing them
343+ val postGCTester = new CleanerTester (sc, rddIds, shuffleIds, broadcastIds)
344+ broadcastBuffer.clear()
345+ rddBuffer.clear()
346+ runGC()
347+ postGCTester.assertCleanup()
348+
349+ // Make sure the broadcasted task closure no longer exists after GC.
350+ val taskClosureBroadcastId = broadcastIds.max + 1
351+ assert(sc.env.blockManager.master.getMatchingBlockIds({
352+ case BroadcastBlockId (`taskClosureBroadcastId`, _) => true
353+ case _ => false
354+ }, askSlaves = true ).isEmpty)
355+ }
266356}
267357
268358
@@ -420,6 +510,7 @@ class CleanerTester(
420510 private def getShuffleBlocks (shuffleId : Int ): Seq [BlockId ] = {
421511 blockManager.master.getMatchingBlockIds( _ match {
422512 case ShuffleBlockId (`shuffleId`, _, _) => true
513+ case ShuffleIndexBlockId (`shuffleId`, _, _) => true
423514 case _ => false
424515 }, askSlaves = true )
425516 }
0 commit comments