Skip to content

Commit 1fdabf8

Browse files
author
Andrew Or
committed
[SPARK-7237] Many user provided closures are not actually cleaned
Note: ~140 lines are tests. In a nutshell, we never cleaned closures the user provided through the following operations: - sortBy - keyBy - mapPartitions - mapPartitionsWithIndex - aggregateByKey - foldByKey - foreachAsync - one of the aliases for runJob - runApproximateJob For more details on a reproduction and why they were not cleaned, please see [SPARK-7237](https://issues.apache.org/jira/browse/SPARK-7237). Author: Andrew Or <[email protected]> Closes apache#5787 from andrewor14/clean-more and squashes the following commits: 2f1f476 [Andrew Or] Merge branch 'master' of github.com:apache/spark into clean-more 7265865 [Andrew Or] Merge branch 'master' of github.com:apache/spark into clean-more df3caa3 [Andrew Or] Address comments 7a3cc80 [Andrew Or] Merge branch 'master' of github.com:apache/spark into clean-more 6498f44 [Andrew Or] Add missing test for groupBy e83699e [Andrew Or] Clean one more 8ac3074 [Andrew Or] Prevent NPE in tests when CC is used outside of an app 9ac5f9b [Andrew Or] Clean closures that are not currently cleaned 19e33b4 [Andrew Or] Add tests for all public RDD APIs that take in closures
1 parent d4cb38a commit 1fdabf8

File tree

5 files changed

+174
-16
lines changed

5 files changed

+174
-16
lines changed

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,7 +1676,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
16761676
partitions: Seq[Int],
16771677
allowLocal: Boolean
16781678
): Array[U] = {
1679-
runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
1679+
val cleanedFunc = clean(func)
1680+
runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions, allowLocal)
16801681
}
16811682

16821683
/**
@@ -1730,7 +1731,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
17301731
val callSite = getCallSite
17311732
logInfo("Starting job: " + callSite.shortForm)
17321733
val start = System.nanoTime
1733-
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
1734+
val cleanedFunc = clean(func)
1735+
val result = dagScheduler.runApproximateJob(rdd, cleanedFunc, evaluator, callSite, timeout,
17341736
localProperties.get)
17351737
logInfo(
17361738
"Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s")

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
131131
lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
132132
val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
133133

134-
combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
134+
// We will clean the combiner closure later in `combineByKey`
135+
val cleanedSeqOp = self.context.clean(seqOp)
136+
combineByKey[U]((v: V) => cleanedSeqOp(createZero(), v), cleanedSeqOp, combOp, partitioner)
135137
}
136138

137139
/**
@@ -179,7 +181,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
179181
lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
180182
val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
181183

182-
combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
184+
val cleanedFunc = self.context.clean(func)
185+
combineByKey[V]((v: V) => cleanedFunc(createZero(), v), cleanedFunc, cleanedFunc, partitioner)
183186
}
184187

185188
/**

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -678,9 +678,13 @@ abstract class RDD[T: ClassTag](
678678
* should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
679679
*/
680680
def mapPartitions[U: ClassTag](
681-
f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope {
682-
val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
683-
new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
681+
f: Iterator[T] => Iterator[U],
682+
preservesPartitioning: Boolean = false): RDD[U] = withScope {
683+
val cleanedF = sc.clean(f)
684+
new MapPartitionsRDD(
685+
this,
686+
(context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
687+
preservesPartitioning)
684688
}
685689

686690
/**
@@ -693,8 +697,11 @@ abstract class RDD[T: ClassTag](
693697
def mapPartitionsWithIndex[U: ClassTag](
694698
f: (Int, Iterator[T]) => Iterator[U],
695699
preservesPartitioning: Boolean = false): RDD[U] = withScope {
696-
val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
697-
new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
700+
val cleanedF = sc.clean(f)
701+
new MapPartitionsRDD(
702+
this,
703+
(context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
704+
preservesPartitioning)
698705
}
699706

700707
/**
@@ -1406,7 +1413,8 @@ abstract class RDD[T: ClassTag](
14061413
* Creates tuples of the elements in this RDD by applying `f`.
14071414
*/
14081415
def keyBy[K](f: T => K): RDD[(K, T)] = withScope {
1409-
map(x => (f(x), x))
1416+
val cleanedF = sc.clean(f)
1417+
map(x => (cleanedF(x), x))
14101418
}
14111419

14121420
/** A private method for tests, to look at the contents of each partition */

core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,9 @@ private[spark] object ClosureCleaner extends Logging {
312312

313313
private def ensureSerializable(func: AnyRef) {
314314
try {
315-
SparkEnv.get.closureSerializer.newInstance().serialize(func)
315+
if (SparkEnv.get != null) {
316+
SparkEnv.get.closureSerializer.newInstance().serialize(func)
317+
}
316318
} catch {
317319
case ex: Exception => throw new SparkException("Task not serializable", ex)
318320
}
@@ -347,14 +349,17 @@ private[spark] object ClosureCleaner extends Logging {
347349
}
348350
}
349351

352+
private[spark] class ReturnStatementInClosureException
353+
extends SparkException("Return statements aren't allowed in Spark closures")
354+
350355
private class ReturnStatementFinder extends ClassVisitor(ASM4) {
351356
override def visitMethod(access: Int, name: String, desc: String,
352357
sig: String, exceptions: Array[String]): MethodVisitor = {
353358
if (name.contains("apply")) {
354359
new MethodVisitor(ASM4) {
355360
override def visitTypeInsn(op: Int, tp: String) {
356361
if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
357-
throw new SparkException("Return statements aren't allowed in Spark closures")
362+
throw new ReturnStatementInClosureException
358363
}
359364
}
360365
}

core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala

Lines changed: 144 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717

1818
package org.apache.spark.util
1919

20+
import java.io.NotSerializableException
21+
2022
import org.scalatest.FunSuite
2123

2224
import org.apache.spark.LocalSparkContext._
23-
import org.apache.spark.{SparkContext, SparkException}
25+
import org.apache.spark.{TaskContext, SparkContext, SparkException}
26+
import org.apache.spark.partial.CountEvaluator
27+
import org.apache.spark.rdd.RDD
2428

2529
class ClosureCleanerSuite extends FunSuite {
2630
test("closures inside an object") {
@@ -52,17 +56,66 @@ class ClosureCleanerSuite extends FunSuite {
5256
}
5357

5458
test("toplevel return statements in closures are identified at cleaning time") {
55-
val ex = intercept[SparkException] {
59+
intercept[ReturnStatementInClosureException] {
5660
TestObjectWithBogusReturns.run()
5761
}
58-
59-
assert(ex.getMessage.contains("Return statements aren't allowed in Spark closures"))
6062
}
6163

6264
test("return statements from named functions nested in closures don't raise exceptions") {
6365
val result = TestObjectWithNestedReturns.run()
6466
assert(result === 1)
6567
}
68+
69+
test("user provided closures are actually cleaned") {
70+
71+
// We use return statements as an indication that a closure is actually being cleaned
72+
// We expect closure cleaner to find the return statements in the user provided closures
73+
def expectCorrectException(body: => Unit): Unit = {
74+
try {
75+
body
76+
} catch {
77+
case rse: ReturnStatementInClosureException => // Success!
78+
case e @ (_: NotSerializableException | _: SparkException) =>
79+
fail(s"Expected ReturnStatementInClosureException, but got $e.\n" +
80+
"This means the closure provided by user is not actually cleaned.")
81+
}
82+
}
83+
84+
withSpark(new SparkContext("local", "test")) { sc =>
85+
val rdd = sc.parallelize(1 to 10)
86+
val pairRdd = rdd.map { i => (i, i) }
87+
expectCorrectException { TestUserClosuresActuallyCleaned.testMap(rdd) }
88+
expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMap(rdd) }
89+
expectCorrectException { TestUserClosuresActuallyCleaned.testFilter(rdd) }
90+
expectCorrectException { TestUserClosuresActuallyCleaned.testSortBy(rdd) }
91+
expectCorrectException { TestUserClosuresActuallyCleaned.testGroupBy(rdd) }
92+
expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
93+
expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
94+
expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
95+
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
96+
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
97+
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }
98+
expectCorrectException { TestUserClosuresActuallyCleaned.testForeach(rdd) }
99+
expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartition(rdd) }
100+
expectCorrectException { TestUserClosuresActuallyCleaned.testReduce(rdd) }
101+
expectCorrectException { TestUserClosuresActuallyCleaned.testTreeReduce(rdd) }
102+
expectCorrectException { TestUserClosuresActuallyCleaned.testFold(rdd) }
103+
expectCorrectException { TestUserClosuresActuallyCleaned.testAggregate(rdd) }
104+
expectCorrectException { TestUserClosuresActuallyCleaned.testTreeAggregate(rdd) }
105+
expectCorrectException { TestUserClosuresActuallyCleaned.testCombineByKey(pairRdd) }
106+
expectCorrectException { TestUserClosuresActuallyCleaned.testAggregateByKey(pairRdd) }
107+
expectCorrectException { TestUserClosuresActuallyCleaned.testFoldByKey(pairRdd) }
108+
expectCorrectException { TestUserClosuresActuallyCleaned.testReduceByKey(pairRdd) }
109+
expectCorrectException { TestUserClosuresActuallyCleaned.testMapValues(pairRdd) }
110+
expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapValues(pairRdd) }
111+
expectCorrectException { TestUserClosuresActuallyCleaned.testForeachAsync(rdd) }
112+
expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartitionAsync(rdd) }
113+
expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob1(sc) }
114+
expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob2(sc) }
115+
expectCorrectException { TestUserClosuresActuallyCleaned.testRunApproximateJob(sc) }
116+
expectCorrectException { TestUserClosuresActuallyCleaned.testSubmitJob(sc) }
117+
}
118+
}
66119
}
67120

68121
// A non-serializable class we create in closures to make sure that we aren't
@@ -187,3 +240,90 @@ class TestClassWithNesting(val y: Int) extends Serializable {
187240
}
188241
}
189242
}
243+
244+
/**
245+
* Test whether closures passed in through public APIs are actually cleaned.
246+
*
247+
* We put a return statement in each of these closures as a mechanism to detect whether the
248+
* ClosureCleaner actually cleaned our closure. If it did, then it would throw an appropriate
249+
* exception explicitly complaining about the return statement. Otherwise, we know the
250+
* ClosureCleaner did not actually clean our closure, in which case we should fail the test.
251+
*/
252+
private object TestUserClosuresActuallyCleaned {
253+
def testMap(rdd: RDD[Int]): Unit = { rdd.map { _ => return; 0 }.count() }
254+
def testFlatMap(rdd: RDD[Int]): Unit = { rdd.flatMap { _ => return; Seq() }.count() }
255+
def testFilter(rdd: RDD[Int]): Unit = { rdd.filter { _ => return; true }.count() }
256+
def testSortBy(rdd: RDD[Int]): Unit = { rdd.sortBy { _ => return; 1 }.count() }
257+
def testKeyBy(rdd: RDD[Int]): Unit = { rdd.keyBy { _ => return; 1 }.count() }
258+
def testGroupBy(rdd: RDD[Int]): Unit = { rdd.groupBy { _ => return; 1 }.count() }
259+
def testMapPartitions(rdd: RDD[Int]): Unit = { rdd.mapPartitions { it => return; it }.count() }
260+
def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
261+
rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
262+
}
263+
def testZipPartitions2(rdd: RDD[Int]): Unit = {
264+
rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
265+
}
266+
def testZipPartitions3(rdd: RDD[Int]): Unit = {
267+
rdd.zipPartitions(rdd, rdd) { case (it1, it2, it3) => return; it1 }.count()
268+
}
269+
def testZipPartitions4(rdd: RDD[Int]): Unit = {
270+
rdd.zipPartitions(rdd, rdd, rdd) { case (it1, it2, it3, it4) => return; it1 }.count()
271+
}
272+
def testForeach(rdd: RDD[Int]): Unit = { rdd.foreach { _ => return } }
273+
def testForeachPartition(rdd: RDD[Int]): Unit = { rdd.foreachPartition { _ => return } }
274+
def testReduce(rdd: RDD[Int]): Unit = { rdd.reduce { case (_, _) => return; 1 } }
275+
def testTreeReduce(rdd: RDD[Int]): Unit = { rdd.treeReduce { case (_, _) => return; 1 } }
276+
def testFold(rdd: RDD[Int]): Unit = { rdd.fold(0) { case (_, _) => return; 1 } }
277+
def testAggregate(rdd: RDD[Int]): Unit = {
278+
rdd.aggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 })
279+
}
280+
def testTreeAggregate(rdd: RDD[Int]): Unit = {
281+
rdd.treeAggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 })
282+
}
283+
284+
// Test pair RDD functions
285+
def testCombineByKey(rdd: RDD[(Int, Int)]): Unit = {
286+
rdd.combineByKey(
287+
{ _ => return; 1 }: Int => Int,
288+
{ case (_, _) => return; 1 }: (Int, Int) => Int,
289+
{ case (_, _) => return; 1 }: (Int, Int) => Int
290+
).count()
291+
}
292+
def testAggregateByKey(rdd: RDD[(Int, Int)]): Unit = {
293+
rdd.aggregateByKey(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 }).count()
294+
}
295+
def testFoldByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.foldByKey(0) { case (_, _) => return; 1 } }
296+
def testReduceByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.reduceByKey { case (_, _) => return; 1 } }
297+
def testMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.mapValues { _ => return; 1 } }
298+
def testFlatMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.flatMapValues { _ => return; Seq() } }
299+
300+
// Test async RDD actions
301+
def testForeachAsync(rdd: RDD[Int]): Unit = { rdd.foreachAsync { _ => return } }
302+
def testForeachPartitionAsync(rdd: RDD[Int]): Unit = { rdd.foreachPartitionAsync { _ => return } }
303+
304+
// Test SparkContext runJob
305+
def testRunJob1(sc: SparkContext): Unit = {
306+
val rdd = sc.parallelize(1 to 10, 10)
307+
sc.runJob(rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1 } )
308+
}
309+
def testRunJob2(sc: SparkContext): Unit = {
310+
val rdd = sc.parallelize(1 to 10, 10)
311+
sc.runJob(rdd, { iter: Iterator[Int] => return; 1 } )
312+
}
313+
def testRunApproximateJob(sc: SparkContext): Unit = {
314+
val rdd = sc.parallelize(1 to 10, 10)
315+
val evaluator = new CountEvaluator(1, 0.5)
316+
sc.runApproximateJob(
317+
rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1L }, evaluator, 1000)
318+
}
319+
def testSubmitJob(sc: SparkContext): Unit = {
320+
val rdd = sc.parallelize(1 to 10, 10)
321+
sc.submitJob(
322+
rdd,
323+
{ _ => return; 1 }: Iterator[Int] => Int,
324+
Seq.empty,
325+
{ case (_, _) => return }: (Int, Int) => Unit,
326+
{ return }
327+
)
328+
}
329+
}

0 commit comments

Comments
 (0)