6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1676,7 +1676,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
-    runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
+    val cleanedFunc = clean(func)
+    runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions, allowLocal)
}

/**
@@ -1730,7 +1731,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
-    val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
+    val cleanedFunc = clean(func)
+    val result = dagScheduler.runApproximateJob(rdd, cleanedFunc, evaluator, callSite, timeout,
localProperties.get)
logInfo(
"Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s")
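
The reason these call sites clean func itself, rather than only the (TaskContext, Iterator[T]) adapter built around it, is that the cleaner inspects the closure object it is handed: it checks that the closure is serializable and that it contains no return statements, so problems surface at the call site instead of later when tasks are shipped to executors. Below is a rough, self-contained analogue of the serializability part of that check; it is not Spark code, and it uses plain Java serialization in place of Spark's closure serializer.

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical sketch, not part of this patch: it mimics the spirit of
// ClosureCleaner.ensureSerializable using plain Java serialization in place
// of Spark's closure serializer.
object SerializableClosureCheck {
  def ensureSerializable(func: AnyRef): Unit = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    out.writeObject(func) // throws NotSerializableException if func captured something non-serializable
  }

  class Handle // deliberately not Serializable

  def main(args: Array[String]): Unit = {
    val handle = new Handle
    val bad: Int => Int = x => handle.hashCode() + x // captures `handle`
    val good: Int => Int = x => x + 1

    ensureSerializable(good) // passes
    try {
      ensureSerializable(bad)
    } catch {
      case _: NotSerializableException =>
        println("caught the non-serializable capture at the call site")
    }
  }
}
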
core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -131,7 +131,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))

-    combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
+    // We will clean the combiner closure later in `combineByKey`
+    val cleanedSeqOp = self.context.clean(seqOp)
+    combineByKey[U]((v: V) => cleanedSeqOp(createZero(), v), cleanedSeqOp, combOp, partitioner)
}

/**
@@ -179,7 +181,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))

-    combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
+    val cleanedFunc = self.context.clean(func)
+    combineByKey[V]((v: V) => cleanedFunc(createZero(), v), cleanedFunc, cleanedFunc, partitioner)
Review comment (Member): Not sure if we need to clean func here; since func will be passed to combineByKey, combineByKey will clean it.

Reply (Contributor Author): This one is necessary because it's called inside another closure.

}

/**
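
To illustrate the point made in the thread above: here the user's function is only referenced from inside another closure, (v: V) => cleanedFunc(createZero(), v), so the downstream combineByKey only ever receives and cleans that wrapper and never inspects func itself. The sketch below is hypothetical (a stand-in clean in place of SparkContext#clean), but it shows the shape of the call chain; the same clean-then-wrap pattern recurs in the RDD.scala changes below.

// Hypothetical sketch, not part of this patch. `clean` stands in for
// SparkContext#clean, which only inspects the closure object it is handed.
object CleanInnerClosureSketch {
  def clean[F <: AnyRef](f: F): F = {
    println(s"cleaning ${f.getClass.getName}")
    f
  }

  // Shaped like PairRDDFunctions.foldByKey: the user's `func` is wrapped
  // before it ever reaches combineByKey.
  def foldByKeyLike[V](zeroValue: V, func: (V, V) => V): Unit = {
    val cleanedFunc = clean(func) // cleaned eagerly; combineByKey never sees `func` directly
    val createCombiner = (v: V) => cleanedFunc(zeroValue, v) // this wrapper is what combineByKey cleans
    combineByKeyLike(createCombiner, cleanedFunc, cleanedFunc)
  }

  def combineByKeyLike[V](
      createCombiner: V => V,
      mergeValue: (V, V) => V,
      mergeCombiners: (V, V) => V): Unit = {
    // combineByKey cleans the closures passed to it, but cleaning
    // `createCombiner` does not reach the `func` captured inside it.
    clean(createCombiner); clean(mergeValue); clean(mergeCombiners)
  }

  def main(args: Array[String]): Unit = {
    foldByKeyLike(0, (a: Int, b: Int) => a + b)
  }
}
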
20 changes: 14 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -678,9 +678,13 @@ abstract class RDD[T: ClassTag](
* should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
def mapPartitions[U: ClassTag](
-      f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope {
-    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
-    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
+      f: Iterator[T] => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+    val cleanedF = sc.clean(f)
+    new MapPartitionsRDD(
+      this,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
+      preservesPartitioning)
}

/**
@@ -693,8 +697,11 @@
def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = withScope {
-    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
-    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
+    val cleanedF = sc.clean(f)
+    new MapPartitionsRDD(
+      this,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
+      preservesPartitioning)
}

/**
@@ -1406,7 +1413,8 @@ abstract class RDD[T: ClassTag](
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: T => K): RDD[(K, T)] = withScope {
-    map(x => (f(x), x))
+    val cleanedF = sc.clean(f)
+    map(x => (cleanedF(x), x))
}

/** A private method for tests, to look at the contents of each partition */
core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -312,7 +312,9 @@ private[spark] object ClosureCleaner extends Logging {

private def ensureSerializable(func: AnyRef) {
try {
-      SparkEnv.get.closureSerializer.newInstance().serialize(func)
+      if (SparkEnv.get != null) {
+        SparkEnv.get.closureSerializer.newInstance().serialize(func)
+      }
} catch {
case ex: Exception => throw new SparkException("Task not serializable", ex)
}
@@ -347,14 +349,17 @@ }
}
}

private[spark] class ReturnStatementInClosureException
extends SparkException("Return statements aren't allowed in Spark closures")

private class ReturnStatementFinder extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
if (name.contains("apply")) {
new MethodVisitor(ASM4) {
override def visitTypeInsn(op: Int, tp: String) {
if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
throw new SparkException("Return statements aren't allowed in Spark closures")
throw new ReturnStatementInClosureException
}
}
}
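
A short standalone illustration of why the bytecode check above works: in Scala, a return inside a closure is compiled into constructing and throwing scala.runtime.NonLocalReturnControl, which the enclosing method catches to implement the non-local return. That NEW scala/runtime/NonLocalReturnControl instruction is exactly what ReturnStatementFinder looks for, and it is also the mechanism the new TestUserClosuresActuallyCleaned object in the test suite below relies on. Plain Scala, no Spark needed:

object NonLocalReturnDemo {
  def firstEven(xs: Seq[Int]): Option[Int] = {
    xs.foreach { x =>
      // This `return` becomes `throw new NonLocalReturnControl(...)` in bytecode.
      if (x % 2 == 0) return Some(x)
    }
    None
  }

  def main(args: Array[String]): Unit = {
    println(firstEven(Seq(1, 3, 4, 5))) // Some(4)
    println(firstEven(Seq(1, 3, 5)))    // None
  }
}
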
148 changes: 144 additions & 4 deletions core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -17,10 +17,14 @@

package org.apache.spark.util

import java.io.NotSerializableException

import org.scalatest.FunSuite

import org.apache.spark.LocalSparkContext._
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.{TaskContext, SparkContext, SparkException}
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.rdd.RDD

class ClosureCleanerSuite extends FunSuite {
test("closures inside an object") {
@@ -52,17 +56,66 @@ class ClosureCleanerSuite extends FunSuite {
}

test("toplevel return statements in closures are identified at cleaning time") {
-    val ex = intercept[SparkException] {
+    intercept[ReturnStatementInClosureException] {
TestObjectWithBogusReturns.run()
}

-    assert(ex.getMessage.contains("Return statements aren't allowed in Spark closures"))
}

test("return statements from named functions nested in closures don't raise exceptions") {
val result = TestObjectWithNestedReturns.run()
assert(result === 1)
}

test("user provided closures are actually cleaned") {

// We use return statements as an indication that a closure is actually being cleaned.
// We expect the closure cleaner to find the return statements in the user-provided closures.
def expectCorrectException(body: => Unit): Unit = {
try {
body
} catch {
case rse: ReturnStatementInClosureException => // Success!
case e @ (_: NotSerializableException | _: SparkException) =>
fail(s"Expected ReturnStatementInClosureException, but got $e.\n" +
"This means the closure provided by user is not actually cleaned.")
}
}

withSpark(new SparkContext("local", "test")) { sc =>
val rdd = sc.parallelize(1 to 10)
val pairRdd = rdd.map { i => (i, i) }
expectCorrectException { TestUserClosuresActuallyCleaned.testMap(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMap(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testFilter(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testSortBy(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testGroupBy(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testForeach(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartition(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testReduce(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testTreeReduce(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testFold(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testAggregate(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testTreeAggregate(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testCombineByKey(pairRdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testAggregateByKey(pairRdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testFoldByKey(pairRdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testReduceByKey(pairRdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testMapValues(pairRdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapValues(pairRdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testForeachAsync(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartitionAsync(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob1(sc) }
expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob2(sc) }
expectCorrectException { TestUserClosuresActuallyCleaned.testRunApproximateJob(sc) }
expectCorrectException { TestUserClosuresActuallyCleaned.testSubmitJob(sc) }
}
}
}

// A non-serializable class we create in closures to make sure that we aren't
@@ -187,3 +240,90 @@ class TestClassWithNesting(val y: Int) extends Serializable {
}
}
}

/**
* Test whether closures passed in through public APIs are actually cleaned.
*
* We put a return statement in each of these closures as a mechanism to detect whether the
* ClosureCleaner actually cleaned our closure. If it did, then it would throw an appropriate
* exception explicitly complaining about the return statement. Otherwise, we know the
* ClosureCleaner did not actually clean our closure, in which case we should fail the test.
*/
private object TestUserClosuresActuallyCleaned {
def testMap(rdd: RDD[Int]): Unit = { rdd.map { _ => return; 0 }.count() }
def testFlatMap(rdd: RDD[Int]): Unit = { rdd.flatMap { _ => return; Seq() }.count() }
def testFilter(rdd: RDD[Int]): Unit = { rdd.filter { _ => return; true }.count() }
def testSortBy(rdd: RDD[Int]): Unit = { rdd.sortBy { _ => return; 1 }.count() }
def testKeyBy(rdd: RDD[Int]): Unit = { rdd.keyBy { _ => return; 1 }.count() }
def testGroupBy(rdd: RDD[Int]): Unit = { rdd.groupBy { _ => return; 1 }.count() }
Review comment (Member): No test for testGroupBy.

Reply (Contributor Author): Oops, good catch. I just verified that all other test methods are called.

def testMapPartitions(rdd: RDD[Int]): Unit = { rdd.mapPartitions { it => return; it }.count() }
def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
}
def testZipPartitions2(rdd: RDD[Int]): Unit = {
rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
}
def testZipPartitions3(rdd: RDD[Int]): Unit = {
rdd.zipPartitions(rdd, rdd) { case (it1, it2, it3) => return; it1 }.count()
}
def testZipPartitions4(rdd: RDD[Int]): Unit = {
rdd.zipPartitions(rdd, rdd, rdd) { case (it1, it2, it3, it4) => return; it1 }.count()
}
def testForeach(rdd: RDD[Int]): Unit = { rdd.foreach { _ => return } }
def testForeachPartition(rdd: RDD[Int]): Unit = { rdd.foreachPartition { _ => return } }
def testReduce(rdd: RDD[Int]): Unit = { rdd.reduce { case (_, _) => return; 1 } }
def testTreeReduce(rdd: RDD[Int]): Unit = { rdd.treeReduce { case (_, _) => return; 1 } }
def testFold(rdd: RDD[Int]): Unit = { rdd.fold(0) { case (_, _) => return; 1 } }
def testAggregate(rdd: RDD[Int]): Unit = {
rdd.aggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 })
}
def testTreeAggregate(rdd: RDD[Int]): Unit = {
rdd.treeAggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 })
}

// Test pair RDD functions
def testCombineByKey(rdd: RDD[(Int, Int)]): Unit = {
rdd.combineByKey(
{ _ => return; 1 }: Int => Int,
{ case (_, _) => return; 1 }: (Int, Int) => Int,
{ case (_, _) => return; 1 }: (Int, Int) => Int
).count()
}
def testAggregateByKey(rdd: RDD[(Int, Int)]): Unit = {
rdd.aggregateByKey(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 }).count()
}
def testFoldByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.foldByKey(0) { case (_, _) => return; 1 } }
def testReduceByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.reduceByKey { case (_, _) => return; 1 } }
def testMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.mapValues { _ => return; 1 } }
def testFlatMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.flatMapValues { _ => return; Seq() } }

// Test async RDD actions
def testForeachAsync(rdd: RDD[Int]): Unit = { rdd.foreachAsync { _ => return } }
def testForeachPartitionAsync(rdd: RDD[Int]): Unit = { rdd.foreachPartitionAsync { _ => return } }

// Test SparkContext runJob
def testRunJob1(sc: SparkContext): Unit = {
val rdd = sc.parallelize(1 to 10, 10)
sc.runJob(rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1 } )
}
def testRunJob2(sc: SparkContext): Unit = {
val rdd = sc.parallelize(1 to 10, 10)
sc.runJob(rdd, { iter: Iterator[Int] => return; 1 } )
}
def testRunApproximateJob(sc: SparkContext): Unit = {
val rdd = sc.parallelize(1 to 10, 10)
val evaluator = new CountEvaluator(1, 0.5)
sc.runApproximateJob(
rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1L }, evaluator, 1000)
}
def testSubmitJob(sc: SparkContext): Unit = {
val rdd = sc.parallelize(1 to 10, 10)
sc.submitJob(
rdd,
{ _ => return; 1 }: Iterator[Int] => Int,
Seq.empty,
{ case (_, _) => return }: (Int, Int) => Unit,
{ return }
)
}
}