diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 7ebee9991220..00eb43291272 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1676,7 +1676,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
       partitions: Seq[Int],
       allowLocal: Boolean
       ): Array[U] = {
-    runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
+    val cleanedFunc = clean(func)
+    runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions, allowLocal)
   }
 
   /**
@@ -1730,7 +1731,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     val callSite = getCallSite
     logInfo("Starting job: " + callSite.shortForm)
     val start = System.nanoTime
-    val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
+    val cleanedFunc = clean(func)
+    val result = dagScheduler.runApproximateJob(rdd, cleanedFunc, evaluator, callSite, timeout,
       localProperties.get)
     logInfo(
       "Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s")
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 93d338fe0530..a6d5d2c94e17 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -131,7 +131,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
     val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
 
-    combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
+    // We will clean the combiner closure later in `combineByKey`
+    val cleanedSeqOp = self.context.clean(seqOp)
+    combineByKey[U]((v: V) => cleanedSeqOp(createZero(), v), cleanedSeqOp, combOp, partitioner)
   }
 
   /**
@@ -179,7 +181,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
     val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
 
-    combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
+    val cleanedFunc = self.context.clean(func)
+    combineByKey[V]((v: V) => cleanedFunc(createZero(), v), cleanedFunc, cleanedFunc, partitioner)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 7f7c7ed144eb..b3b60578c92e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -678,9 +678,13 @@ abstract class RDD[T: ClassTag](
    * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
   */
  def mapPartitions[U: ClassTag](
-      f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope {
-    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
-    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
+      f: Iterator[T] => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+    val cleanedF = sc.clean(f)
+    new MapPartitionsRDD(
+      this,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
+      preservesPartitioning)
   }
 
   /**
@@ -693,8 +697,11 @@ abstract class RDD[T: ClassTag](
   def mapPartitionsWithIndex[U: ClassTag](
       f: (Int, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = withScope {
-    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
-    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
+    val cleanedF = sc.clean(f)
+    new MapPartitionsRDD(
+      this,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
+      preservesPartitioning)
   }
 
   /**
@@ -1406,7 +1413,8 @@ abstract class RDD[T: ClassTag](
    * Creates tuples of the elements in this RDD by applying `f`.
    */
   def keyBy[K](f: T => K): RDD[(K, T)] = withScope {
-    map(x => (f(x), x))
+    val cleanedF = sc.clean(f)
+    map(x => (cleanedF(x), x))
   }
 
   /** A private method for tests, to look at the contents of each partition */
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 4ac0382d8081..19fe6cb9dee2 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -312,7 +312,9 @@ private[spark] object ClosureCleaner extends Logging {
 
   private def ensureSerializable(func: AnyRef) {
     try {
-      SparkEnv.get.closureSerializer.newInstance().serialize(func)
+      if (SparkEnv.get != null) {
+        SparkEnv.get.closureSerializer.newInstance().serialize(func)
+      }
     } catch {
       case ex: Exception => throw new SparkException("Task not serializable", ex)
     }
@@ -347,6 +349,9 @@ private[spark] object ClosureCleaner extends Logging {
   }
 }
 
+private[spark] class ReturnStatementInClosureException
+  extends SparkException("Return statements aren't allowed in Spark closures")
+
 private class ReturnStatementFinder extends ClassVisitor(ASM4) {
   override def visitMethod(access: Int, name: String, desc: String, sig: String,
       exceptions: Array[String]): MethodVisitor = {
@@ -354,7 +359,7 @@ private class ReturnStatementFinder extends ClassVisitor(ASM4) {
     new MethodVisitor(ASM4) {
       override def visitTypeInsn(op: Int, tp: String) {
         if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
-          throw new SparkException("Return statements aren't allowed in Spark closures")
+          throw new ReturnStatementInClosureException
         }
       }
     }
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index ff1bfe0774a2..446c3f24a74d 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -17,10 +17,14 @@
 
 package org.apache.spark.util
 
+import java.io.NotSerializableException
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.LocalSparkContext._
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.{TaskContext, SparkContext, SparkException}
+import org.apache.spark.partial.CountEvaluator
+import org.apache.spark.rdd.RDD
 
 class ClosureCleanerSuite extends FunSuite {
   test("closures inside an object") {
@@ -52,17 +56,66 @@ class ClosureCleanerSuite extends FunSuite {
   }
 
   test("toplevel return statements in closures are identified at cleaning time") {
-    val ex = intercept[SparkException] {
+    intercept[ReturnStatementInClosureException] {
       TestObjectWithBogusReturns.run()
     }
-
-    assert(ex.getMessage.contains("Return statements aren't allowed in Spark closures"))
   }
 
   test("return statements from named functions nested in closures don't raise exceptions") {
     val result = TestObjectWithNestedReturns.run()
     assert(result === 1)
   }
+
+  test("user provided closures are actually cleaned") {
+
+    // We use return statements as an indication that a closure is actually being cleaned
+    // We expect closure cleaner to find the return statements in the user provided closures
+    def expectCorrectException(body: => Unit): Unit = {
+      try {
+        body
+      } catch {
+        case rse: ReturnStatementInClosureException => // Success!
+        case e @ (_: NotSerializableException | _: SparkException) =>
+          fail(s"Expected ReturnStatementInClosureException, but got $e.\n" +
+            "This means the closure provided by user is not actually cleaned.")
+      }
+    }
+
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val rdd = sc.parallelize(1 to 10)
+      val pairRdd = rdd.map { i => (i, i) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testMap(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMap(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testFilter(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testSortBy(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testGroupBy(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testForeach(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartition(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testReduce(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testTreeReduce(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testFold(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testAggregate(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testTreeAggregate(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testCombineByKey(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testAggregateByKey(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testFoldByKey(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testReduceByKey(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testMapValues(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapValues(pairRdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testForeachAsync(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartitionAsync(rdd) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob1(sc) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob2(sc) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testRunApproximateJob(sc) }
+      expectCorrectException { TestUserClosuresActuallyCleaned.testSubmitJob(sc) }
+    }
+  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
@@ -187,3 +240,90 @@ class TestClassWithNesting(val y: Int) extends Serializable {
     }
   }
 }
+
+/**
+ * Test whether closures passed in through public APIs are actually cleaned.
+ *
+ * We put a return statement in each of these closures as a mechanism to detect whether the
+ * ClosureCleaner actually cleaned our closure. If it did, then it would throw an appropriate
+ * exception explicitly complaining about the return statement. Otherwise, we know the
+ * ClosureCleaner did not actually clean our closure, in which case we should fail the test.
+ */
+private object TestUserClosuresActuallyCleaned {
+  def testMap(rdd: RDD[Int]): Unit = { rdd.map { _ => return; 0 }.count() }
+  def testFlatMap(rdd: RDD[Int]): Unit = { rdd.flatMap { _ => return; Seq() }.count() }
+  def testFilter(rdd: RDD[Int]): Unit = { rdd.filter { _ => return; true }.count() }
+  def testSortBy(rdd: RDD[Int]): Unit = { rdd.sortBy { _ => return; 1 }.count() }
+  def testKeyBy(rdd: RDD[Int]): Unit = { rdd.keyBy { _ => return; 1 }.count() }
+  def testGroupBy(rdd: RDD[Int]): Unit = { rdd.groupBy { _ => return; 1 }.count() }
+  def testMapPartitions(rdd: RDD[Int]): Unit = { rdd.mapPartitions { it => return; it }.count() }
+  def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
+    rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
+  }
+  def testZipPartitions2(rdd: RDD[Int]): Unit = {
+    rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
+  }
+  def testZipPartitions3(rdd: RDD[Int]): Unit = {
+    rdd.zipPartitions(rdd, rdd) { case (it1, it2, it3) => return; it1 }.count()
+  }
+  def testZipPartitions4(rdd: RDD[Int]): Unit = {
+    rdd.zipPartitions(rdd, rdd, rdd) { case (it1, it2, it3, it4) => return; it1 }.count()
+  }
+  def testForeach(rdd: RDD[Int]): Unit = { rdd.foreach { _ => return } }
+  def testForeachPartition(rdd: RDD[Int]): Unit = { rdd.foreachPartition { _ => return } }
+  def testReduce(rdd: RDD[Int]): Unit = { rdd.reduce { case (_, _) => return; 1 } }
+  def testTreeReduce(rdd: RDD[Int]): Unit = { rdd.treeReduce { case (_, _) => return; 1 } }
+  def testFold(rdd: RDD[Int]): Unit = { rdd.fold(0) { case (_, _) => return; 1 } }
+  def testAggregate(rdd: RDD[Int]): Unit = {
+    rdd.aggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 })
+  }
+  def testTreeAggregate(rdd: RDD[Int]): Unit = {
+    rdd.treeAggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 })
+  }
+
+  // Test pair RDD functions
+  def testCombineByKey(rdd: RDD[(Int, Int)]): Unit = {
+    rdd.combineByKey(
+      { _ => return; 1 }: Int => Int,
+      { case (_, _) => return; 1 }: (Int, Int) => Int,
+      { case (_, _) => return; 1 }: (Int, Int) => Int
+    ).count()
+  }
+  def testAggregateByKey(rdd: RDD[(Int, Int)]): Unit = {
+    rdd.aggregateByKey(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 }).count()
+  }
+  def testFoldByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.foldByKey(0) { case (_, _) => return; 1 } }
+  def testReduceByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.reduceByKey { case (_, _) => return; 1 } }
+  def testMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.mapValues { _ => return; 1 } }
+  def testFlatMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.flatMapValues { _ => return; Seq() } }
+
+  // Test async RDD actions
+  def testForeachAsync(rdd: RDD[Int]): Unit = { rdd.foreachAsync { _ => return } }
+  def testForeachPartitionAsync(rdd: RDD[Int]): Unit = { rdd.foreachPartitionAsync { _ => return } }
+
+  // Test SparkContext runJob
+  def testRunJob1(sc: SparkContext): Unit = {
+    val rdd = sc.parallelize(1 to 10, 10)
+    sc.runJob(rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1 } )
+  }
+  def testRunJob2(sc: SparkContext): Unit = {
+    val rdd = sc.parallelize(1 to 10, 10)
+    sc.runJob(rdd, { iter: Iterator[Int] => return; 1 } )
+  }
+  def testRunApproximateJob(sc: SparkContext): Unit = {
+    val rdd = sc.parallelize(1 to 10, 10)
+    val evaluator = new CountEvaluator(1, 0.5)
+    sc.runApproximateJob(
+      rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1L }, evaluator, 1000)
+  }
+  def testSubmitJob(sc: SparkContext): Unit = {
+    val rdd = sc.parallelize(1 to 10, 10)
+    sc.submitJob(
+      rdd,
+      { _ => return; 1 }: Iterator[Int] => Int,
+      Seq.empty,
+      { case (_, _) => return }: (Int, Int) => Unit,
+      { return }
+    )
+  }
+}
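
A minimal, self-contained sketch (not part of this patch) of the detection trick the new suite relies on: a `return` inside a closure compiles to `scala.runtime.NonLocalReturnControl`, which `ReturnStatementFinder` flags at cleaning time, so any API that actually cleans its closure now fails fast with `ReturnStatementInClosureException` (a `SparkException`) instead of failing later during serialization or execution. The object and app names below are illustrative only.

import org.apache.spark.{SparkContext, SparkException}

object ReturnDetectionExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "closure-cleaning-demo")
    try {
      // If `map` cleans the user closure (as this patch ensures), the job is rejected at
      // cleaning time because of the return statement, before any task is launched.
      sc.parallelize(1 to 10).map { _ => return; 0 }.count()
    } catch {
      // ReturnStatementInClosureException is private[spark], so user code catches SparkException.
      case e: SparkException => println(s"Closure rejected at cleaning time: ${e.getMessage}")
    } finally {
      sc.stop()
    }
  }
}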