Commit 19e33b4

Author: Andrew Or (committed)
Add tests for all public RDD APIs that take in closures
Tests should fail as of this commit because the issue hasn't been fixed yet.
1 parent 1fd6ed9 commit 19e33b4

File tree

2 files changed: +128 −5


core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala

Lines changed: 4 additions & 1 deletion
@@ -195,6 +195,9 @@ private[spark] object ClosureCleaner extends Logging {
   }
 }
 
+private[spark] class ReturnStatementInClosureException
+  extends SparkException("Return statements aren't allowed in Spark closures")
+
 private[spark]
 class ReturnStatementFinder extends ClassVisitor(ASM4) {
   override def visitMethod(access: Int, name: String, desc: String,
@@ -203,7 +206,7 @@ class ReturnStatementFinder extends ClassVisitor(ASM4) {
       new MethodVisitor(ASM4) {
         override def visitTypeInsn(op: Int, tp: String) {
           if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
-            throw new SparkException("Return statements aren't allowed in Spark closures")
+            throw new ReturnStatementInClosureException
           }
         }
       }
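
For context (not part of this commit): the bytecode check above works because the Scala compiler implements a `return` inside a lambda as a non-local return, which it compiles into constructing and throwing scala.runtime.NonLocalReturnControl. A minimal sketch of ordinary Scala code that produces exactly that NEW instruction, assuming nothing beyond the standard library:

object NonLocalReturnExample {
  // The `return` below does not merely exit the lambda passed to foreach; the compiler
  // rewrites it so the lambda throws scala.runtime.NonLocalReturnControl and firstEven
  // catches it. That NEW of NonLocalReturnControl is what ReturnStatementFinder scans for.
  def firstEven(xs: Seq[Int]): Option[Int] = {
    xs.foreach { x =>
      if (x % 2 == 0) return Some(x)
    }
    None
  }
}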

core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala

Lines changed: 124 additions & 4 deletions
@@ -17,10 +17,13 @@
 
 package org.apache.spark.util
 
+import java.io.NotSerializableException
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.LocalSparkContext._
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.{TaskContext, SparkContext, SparkException}
+import org.apache.spark.rdd.RDD
 
 class ClosureCleanerSuite extends FunSuite {
   test("closures inside an object") {
@@ -52,17 +55,63 @@ class ClosureCleanerSuite extends FunSuite {
   }
 
   test("toplevel return statements in closures are identified at cleaning time") {
-    val ex = intercept[SparkException] {
+    intercept[ReturnStatementInClosureException] {
       TestObjectWithBogusReturns.run()
     }
-
-    assert(ex.getMessage.contains("Return statements aren't allowed in Spark closures"))
   }
 
   test("return statements from named functions nested in closures don't raise exceptions") {
     val result = TestObjectWithNestedReturns.run()
     assert(result == 1)
   }
+
+  test("user provided closures are actually cleaned") {
+
+    // We use return statements as an indication that a closure is actually being cleaned
+    // We expect closure cleaner to find the return statements in the user provided closures
+    def expectCorrectException(body: => Unit): Unit = {
+      try {
+        body
+      } catch {
+        case rse: ReturnStatementInClosureException => // Success!
+        case e @ (_: NotSerializableException | _: SparkException) =>
+          fail(s"Expected ReturnStatementInClosureException, but got $e.\n" +
+            "This means the closure provided by user is not actually cleaned.")
+      }
+    }
+
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val rdd = sc.parallelize(1 to 10)
+      val pairRdd = rdd.map { i => (i, i) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testMap(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMap(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testFilter(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testSortBy(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testForeach(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartition(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testReduce(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testTreeReduce(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testFold(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testAggregate(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testTreeAggregate(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testCombineByKey(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testAggregateByKey(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testFoldByKey(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testReduceByKey(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testMapValues(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapValues(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testForeachAsync(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartitionAsync(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob1(sc) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob2(sc) }
+    }
+  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
@@ -180,3 +229,74 @@ class TestClassWithNesting(val y: Int) extends Serializable {
   }
 }
+
+/**
+ * Test whether closures passed in through public APIs are actually cleaned.
+ *
+ * We put a return statement in each of these closures as a mechanism to detect whether the
+ * ClosureCleaner actually cleaned our closure. If it did, then it would throw an appropriate
+ * exception explicitly complaining about the return statement. Otherwise, we know the
+ * ClosureCleaner did not actually clean our closure, in which case we should fail the test.
+ */
+private object TestUserClosuresActuallyCleaned {
+  def testMap(rdd: RDD[Int]): Unit = { rdd.map { _ => return; 0 }.count() }
+  def testFlatMap(rdd: RDD[Int]): Unit = { rdd.flatMap { _ => return; Seq() }.count() }
+  def testFilter(rdd: RDD[Int]): Unit = { rdd.filter { _ => return; true }.count() }
+  def testSortBy(rdd: RDD[Int]): Unit = { rdd.sortBy { _ => return; 1 }.count() }
+  def testKeyBy(rdd: RDD[Int]): Unit = { rdd.keyBy { _ => return; 1 }.count() }
+  def testGroupBy(rdd: RDD[Int]): Unit = { rdd.groupBy { _ => return; 1 }.count() }
+  def testMapPartitions(rdd: RDD[Int]): Unit = { rdd.mapPartitions { it => return; it }.count() }
+  def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
+    rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
+  }
+  def testZipPartitions2(rdd: RDD[Int]): Unit = {
+    rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
+  }
+  def testZipPartitions3(rdd: RDD[Int]): Unit = {
+    rdd.zipPartitions(rdd, rdd) { case (it1, it2, it3) => return; it1 }.count()
+  }
+  def testZipPartitions4(rdd: RDD[Int]): Unit = {
+    rdd.zipPartitions(rdd, rdd, rdd) { case (it1, it2, it3, it4) => return; it1 }.count()
+  }
+  def testForeach(rdd: RDD[Int]): Unit = { rdd.foreach { _ => return } }
+  def testForeachPartition(rdd: RDD[Int]): Unit = { rdd.foreachPartition { _ => return } }
+  def testReduce(rdd: RDD[Int]): Unit = { rdd.reduce { case (_, _) => return; 1 } }
+  def testTreeReduce(rdd: RDD[Int]): Unit = { rdd.treeReduce { case (_, _) => return; 1 } }
+  def testFold(rdd: RDD[Int]): Unit = { rdd.fold(0) { case (_, _) => return; 1 } }
+  def testAggregate(rdd: RDD[Int]): Unit = {
+    rdd.aggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 })
+  }
+  def testTreeAggregate(rdd: RDD[Int]): Unit = {
+    rdd.treeAggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 })
+  }
+
+  // Test pair RDD functions
+  def testCombineByKey(rdd: RDD[(Int, Int)]): Unit = {
+    rdd.combineByKey(
+      { _ => return; 1 }: Int => Int,
+      { case (_, _) => return; 1 }: (Int, Int) => Int,
+      { case (_, _) => return; 1 }: (Int, Int) => Int
+    ).count()
+  }
+  def testAggregateByKey(rdd: RDD[(Int, Int)]): Unit = {
+    rdd.aggregateByKey(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 }).count()
+  }
+  def testFoldByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.foldByKey(0) { case (_, _) => return; 1 } }
+  def testReduceByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.reduceByKey { case (_, _) => return; 1 } }
+  def testMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.mapValues { _ => return; 1 } }
+  def testFlatMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.flatMapValues { _ => return; Seq() } }
+
+  // Test async RDD actions
+  def testForeachAsync(rdd: RDD[Int]): Unit = { rdd.foreachAsync { _ => return } }
+  def testForeachPartitionAsync(rdd: RDD[Int]): Unit = { rdd.foreachPartitionAsync { _ => return } }
+
+  // Test SparkContext runJob
+  def testRunJob1(sc: SparkContext): Unit = {
+    val rdd = sc.parallelize(1 to 10, 10)
+    sc.runJob(rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1 } )
+  }
+  def testRunJob2(sc: SparkContext): Unit = {
+    val rdd = sc.parallelize(1 to 10, 10)
+    sc.runJob(rdd, { iter: Iterator[Int] => return; 1 } )
+  }
+}
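
The user-facing effect these tests pin down, sketched as hypothetical user code (not part of this commit; the method and variable names are illustrative only): a return statement inside an RDD closure is rejected with ReturnStatementInClosureException when the closure is cleaned at the point the operation is defined, rather than surfacing later as an opaque serialization or task failure.

// Hypothetical user code that the new exception targets.
def countLargeValues(sc: SparkContext): Long = {
  sc.parallelize(1 to 100).filter { i =>
    // Non-local return inside a Spark closure: ClosureCleaner detects this while cleaning
    // the filter predicate and throws ReturnStatementInClosureException.
    if (i > 50) return 0L
    i > 50
  }.count()
}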
