|
17 | 17 |
|
18 | 18 | package org.apache.spark.util |
19 | 19 |
|
| 20 | +import java.io.NotSerializableException |
| 21 | + |
20 | 22 | import org.scalatest.FunSuite |
21 | 23 |
|
22 | 24 | import org.apache.spark.LocalSparkContext._ |
23 | | -import org.apache.spark.{SparkContext, SparkException} |
| 25 | +import org.apache.spark.{SparkContext, SparkException, TaskContext}
| 26 | +import org.apache.spark.partial.CountEvaluator |
| 27 | +import org.apache.spark.rdd.RDD |
24 | 28 |
|
25 | 29 | class ClosureCleanerSuite extends FunSuite { |
26 | 30 | test("closures inside an object") { |
@@ -52,17 +56,66 @@ class ClosureCleanerSuite extends FunSuite { |
52 | 56 | } |
53 | 57 |
|
54 | 58 | test("toplevel return statements in closures are identified at cleaning time") { |
55 | | - val ex = intercept[SparkException] { |
| 59 | + intercept[ReturnStatementInClosureException] { |
56 | 60 | TestObjectWithBogusReturns.run() |
57 | 61 | } |
58 | | - |
59 | | - assert(ex.getMessage.contains("Return statements aren't allowed in Spark closures")) |
60 | 62 | } |
61 | 63 |
|
62 | 64 | test("return statements from named functions nested in closures don't raise exceptions") { |
63 | 65 | val result = TestObjectWithNestedReturns.run() |
64 | 66 | assert(result === 1) |
65 | 67 | } |
| 68 | + |
| 69 | + test("user provided closures are actually cleaned") { |
| 70 | + |
| 71 | + // We use return statements as an indication that a closure is actually being cleaned.
| 72 | + // We expect the ClosureCleaner to find the return statements in the user-provided closures.
| 73 | + def expectCorrectException(body: => Unit): Unit = { |
| 74 | + try { |
| 75 | + body |
| 76 | + } catch { |
| 77 | + case rse: ReturnStatementInClosureException => // Success! |
| 78 | + case e @ (_: NotSerializableException | _: SparkException) => |
| 79 | + fail(s"Expected ReturnStatementInClosureException, but got $e.\n" + |
| 80 | + "This means the closure provided by user is not actually cleaned.") |
| 81 | + } |
| 82 | + } |
| 83 | + |
| 84 | + withSpark(new SparkContext("local", "test")) { sc => |
| 85 | + val rdd = sc.parallelize(1 to 10) |
| 86 | + val pairRdd = rdd.map { i => (i, i) } |
| 87 | + expectCorrectException { TestUserClosuresActuallyCleaned.testMap(rdd) } |
| 88 | + expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMap(rdd) } |
| 89 | + expectCorrectException { TestUserClosuresActuallyCleaned.testFilter(rdd) } |
| 90 | + expectCorrectException { TestUserClosuresActuallyCleaned.testSortBy(rdd) } |
| 91 | + expectCorrectException { TestUserClosuresActuallyCleaned.testGroupBy(rdd) } |
| 92 | + expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) } |
| 93 | + expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) } |
| 94 | + expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) } |
| 95 | + expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) } |
| 96 | + expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) } |
| 97 | + expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) } |
| 98 | + expectCorrectException { TestUserClosuresActuallyCleaned.testForeach(rdd) } |
| 99 | + expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartition(rdd) } |
| 100 | + expectCorrectException { TestUserClosuresActuallyCleaned.testReduce(rdd) } |
| 101 | + expectCorrectException { TestUserClosuresActuallyCleaned.testTreeReduce(rdd) } |
| 102 | + expectCorrectException { TestUserClosuresActuallyCleaned.testFold(rdd) } |
| 103 | + expectCorrectException { TestUserClosuresActuallyCleaned.testAggregate(rdd) } |
| 104 | + expectCorrectException { TestUserClosuresActuallyCleaned.testTreeAggregate(rdd) } |
| 105 | + expectCorrectException { TestUserClosuresActuallyCleaned.testCombineByKey(pairRdd) } |
| 106 | + expectCorrectException { TestUserClosuresActuallyCleaned.testAggregateByKey(pairRdd) } |
| 107 | + expectCorrectException { TestUserClosuresActuallyCleaned.testFoldByKey(pairRdd) } |
| 108 | + expectCorrectException { TestUserClosuresActuallyCleaned.testReduceByKey(pairRdd) } |
| 109 | + expectCorrectException { TestUserClosuresActuallyCleaned.testMapValues(pairRdd) } |
| 110 | + expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapValues(pairRdd) } |
| 111 | + expectCorrectException { TestUserClosuresActuallyCleaned.testForeachAsync(rdd) } |
| 112 | + expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartitionAsync(rdd) } |
| 113 | + expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob1(sc) } |
| 114 | + expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob2(sc) } |
| 115 | + expectCorrectException { TestUserClosuresActuallyCleaned.testRunApproximateJob(sc) } |
| 116 | + expectCorrectException { TestUserClosuresActuallyCleaned.testSubmitJob(sc) } |
| 117 | + } |
| 118 | + } |
66 | 119 | } |
67 | 120 |
|
68 | 121 | // A non-serializable class we create in closures to make sure that we aren't |
@@ -187,3 +240,90 @@ class TestClassWithNesting(val y: Int) extends Serializable { |
187 | 240 | } |
188 | 241 | } |
189 | 242 | } |
| 243 | + |
| 244 | +/** |
| 245 | + * Test whether closures passed in through public APIs are actually cleaned. |
| 246 | + * |
| 247 | + * We put a return statement in each of these closures as a mechanism to detect whether the
| 248 | + * ClosureCleaner actually cleaned our closure. If it did, it throws a
| 249 | + * ReturnStatementInClosureException complaining about the return statement. Otherwise, we know
| 250 | + * the ClosureCleaner did not actually clean our closure, in which case we should fail the test.
| 251 | + */ |
| 252 | +private object TestUserClosuresActuallyCleaned { |
| 253 | + def testMap(rdd: RDD[Int]): Unit = { rdd.map { _ => return; 0 }.count() } |
| 254 | + def testFlatMap(rdd: RDD[Int]): Unit = { rdd.flatMap { _ => return; Seq() }.count() } |
| 255 | + def testFilter(rdd: RDD[Int]): Unit = { rdd.filter { _ => return; true }.count() } |
| 256 | + def testSortBy(rdd: RDD[Int]): Unit = { rdd.sortBy { _ => return; 1 }.count() } |
| 257 | + def testKeyBy(rdd: RDD[Int]): Unit = { rdd.keyBy { _ => return; 1 }.count() } |
| 258 | + def testGroupBy(rdd: RDD[Int]): Unit = { rdd.groupBy { _ => return; 1 }.count() } |
| 259 | + def testMapPartitions(rdd: RDD[Int]): Unit = { rdd.mapPartitions { it => return; it }.count() } |
| 260 | + def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = { |
| 261 | + rdd.mapPartitionsWithIndex { (_, it) => return; it }.count() |
| 262 | + } |
| 263 | + def testZipPartitions2(rdd: RDD[Int]): Unit = { |
| 264 | + rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count() |
| 265 | + } |
| 266 | + def testZipPartitions3(rdd: RDD[Int]): Unit = { |
| 267 | + rdd.zipPartitions(rdd, rdd) { case (it1, it2, it3) => return; it1 }.count() |
| 268 | + } |
| 269 | + def testZipPartitions4(rdd: RDD[Int]): Unit = { |
| 270 | + rdd.zipPartitions(rdd, rdd, rdd) { case (it1, it2, it3, it4) => return; it1 }.count() |
| 271 | + } |
| 272 | + def testForeach(rdd: RDD[Int]): Unit = { rdd.foreach { _ => return } } |
| 273 | + def testForeachPartition(rdd: RDD[Int]): Unit = { rdd.foreachPartition { _ => return } } |
| 274 | + def testReduce(rdd: RDD[Int]): Unit = { rdd.reduce { case (_, _) => return; 1 } } |
| 275 | + def testTreeReduce(rdd: RDD[Int]): Unit = { rdd.treeReduce { case (_, _) => return; 1 } } |
| 276 | + def testFold(rdd: RDD[Int]): Unit = { rdd.fold(0) { case (_, _) => return; 1 } } |
| 277 | + def testAggregate(rdd: RDD[Int]): Unit = { |
| 278 | + rdd.aggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 }) |
| 279 | + } |
| 280 | + def testTreeAggregate(rdd: RDD[Int]): Unit = { |
| 281 | + rdd.treeAggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 }) |
| 282 | + } |
| 283 | + |
| 284 | + // Test pair RDD functions |
| 285 | + def testCombineByKey(rdd: RDD[(Int, Int)]): Unit = { |
| 286 | + rdd.combineByKey( |
| 287 | + { _ => return; 1 }: Int => Int, |
| 288 | + { case (_, _) => return; 1 }: (Int, Int) => Int, |
| 289 | + { case (_, _) => return; 1 }: (Int, Int) => Int |
| 290 | + ).count() |
| 291 | + } |
| 292 | + def testAggregateByKey(rdd: RDD[(Int, Int)]): Unit = { |
| 293 | + rdd.aggregateByKey(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 }).count() |
| 294 | + } |
| 295 | + def testFoldByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.foldByKey(0) { case (_, _) => return; 1 } } |
| 296 | + def testReduceByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.reduceByKey { case (_, _) => return; 1 } } |
| 297 | + def testMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.mapValues { _ => return; 1 } } |
| 298 | + def testFlatMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.flatMapValues { _ => return; Seq() } } |
| 299 | + |
| 300 | + // Test async RDD actions |
| 301 | + def testForeachAsync(rdd: RDD[Int]): Unit = { rdd.foreachAsync { _ => return } } |
| 302 | + def testForeachPartitionAsync(rdd: RDD[Int]): Unit = { rdd.foreachPartitionAsync { _ => return } } |
| 303 | + |
| 304 | + // Test SparkContext runJob |
| 305 | + def testRunJob1(sc: SparkContext): Unit = { |
| 306 | + val rdd = sc.parallelize(1 to 10, 10) |
| 307 | + sc.runJob(rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1 } ) |
| 308 | + } |
| 309 | + def testRunJob2(sc: SparkContext): Unit = { |
| 310 | + val rdd = sc.parallelize(1 to 10, 10) |
| 311 | + sc.runJob(rdd, { iter: Iterator[Int] => return; 1 } ) |
| 312 | + } |
| 313 | + def testRunApproximateJob(sc: SparkContext): Unit = { |
| 314 | + val rdd = sc.parallelize(1 to 10, 10) |
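| | + // The evaluator's parameters are arbitrary; this test only cares whether the user closure gets cleaned.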
| 315 | + val evaluator = new CountEvaluator(1, 0.5) |
| 316 | + sc.runApproximateJob( |
| 317 | + rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1L }, evaluator, 1000) |
| 318 | + } |
| 319 | + def testSubmitJob(sc: SparkContext): Unit = { |
| 320 | + val rdd = sc.parallelize(1 to 10, 10) |
| 321 | + sc.submitJob( |
| 322 | + rdd, |
| 323 | + { _ => return; 1 }: Iterator[Int] => Int, |
| 324 | + Seq.empty, |
| 325 | + { case (_, _) => return }: (Int, Int) => Unit, |
| 326 | + { return } |
| 327 | + ) |
| 328 | + } |
| 329 | +} |
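
For context, here is a minimal sketch (not part of this change) of the detection mechanism described in the comment above: cleaning a closure that contains a return statement should fail fast with a ReturnStatementInClosureException, before the closure is serialized or shipped anywhere. The object name is hypothetical, and the sketch assumes it lives in the org.apache.spark.util package, since ClosureCleaner is private[spark].

package org.apache.spark.util

// Hypothetical illustration only; not part of the suite above.
private object ReturnDetectionSketch {
  def run(): Int = {
    // The `return` below is the marker the cleaner is expected to find.
    val closure = (x: Int) => { return 0; x + 1 }
    // Invoke the cleaner directly, skipping the serializability check so no SparkEnv is needed.
    // This is expected to throw ReturnStatementInClosureException and never reach the final line.
    ClosureCleaner.clean(closure, checkSerializable = false)
    1
  }
}

Calling ReturnDetectionSketch.run() should therefore throw rather than return 1, which is the same signal the expectCorrectException helper above looks for when closures with return statements are routed through the public RDD and SparkContext APIs.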