23 changes: 22 additions & 1 deletion core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.Set
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._

-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkException}

 private[spark] object ClosureCleaner extends Logging {
   // Get an ASM class reader for a given class from the JAR that loaded it
@@ -108,6 +108,9 @@ private[spark] object ClosureCleaner extends Logging {
     val outerObjects = getOuterObjects(func)

     val accessedFields = Map[Class[_], Set[String]]()
+
+    getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
+
     for (cls <- outerClasses)
       accessedFields(cls) = Set[String]()
     for (cls <- func.getClass :: innerClasses)
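
Note (not part of the patch): ClosureCleaner.clean is invoked via SparkContext.clean on the closure passed to an RDD operation, so the scan added above rejects an illegal `return` at the point where the operation is defined, before any job is submitted. A minimal usage sketch, assuming a live SparkContext `sc` (the method name `bad` is hypothetical):

def bad(sc: org.apache.spark.SparkContext): Int = {
  val nums = sc.parallelize(1 to 4)
  // With this patch, the next line throws SparkException while the closure is
  // being cleaned, i.e. inside the call to map itself, not on an executor.
  nums.map { x => return 1; x * 2 }
  0
}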
@@ -180,6 +183,24 @@
   }
 }
+
+private[spark]
+class ReturnStatementFinder extends ClassVisitor(ASM4) {
+  override def visitMethod(access: Int, name: String, desc: String,
+      sig: String, exceptions: Array[String]): MethodVisitor = {
+    if (name.contains("apply")) {
+      new MethodVisitor(ASM4) {
+        override def visitTypeInsn(op: Int, tp: String) {
+          if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
+            throw new SparkException("Return statements aren't allowed in Spark closures")
+          }
+        }
+      }
+    } else {
+      new MethodVisitor(ASM4) {}
+    }
+  }
+}

 private[spark]
 class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
   override def visitMethod(access: Int, name: String, desc: String,
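Note (not part of the patch): the scan works because a `return` inside a Scala 2 closure is a non-local return. The compiler implements it by constructing and throwing scala.runtime.NonLocalReturnControl, which the enclosing method catches, so looking for a NEW of that class inside the closure's apply methods is sufficient to detect the pattern. A minimal sketch of what gets flagged (NonLocalReturnDemo and findFirstEven are illustrative names):

object NonLocalReturnDemo {
  // The `return` exits findFirstEven, not just the foreach body: scalac
  // compiles it to `throw new NonLocalReturnControl(key, Some(n))`, the very
  // NEW instruction that ReturnStatementFinder looks for.
  def findFirstEven(nums: List[Int]): Option[Int] = {
    nums.foreach { n =>
      if (n % 2 == 0) return Some(n)
    }
    None
  }
}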
39 changes: 38 additions & 1 deletion core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.util
 import org.scalatest.FunSuite

 import org.apache.spark.LocalSparkContext._
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkException}

 class ClosureCleanerSuite extends FunSuite {
   test("closures inside an object") {
@@ -50,6 +50,19 @@ class ClosureCleanerSuite extends FunSuite {
     val obj = new TestClassWithNesting(1)
     assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
   }
+
+  test("toplevel return statements in closures are identified at cleaning time") {
+    val ex = intercept[SparkException] {
+      TestObjectWithBogusReturns.run()
+    }
+
+    assert(ex.getMessage.contains("Return statements aren't allowed in Spark closures"))
+  }
+
+  test("return statements from named functions nested in closures don't raise exceptions") {
+    val result = TestObjectWithNestedReturns.run()
+    assert(result == 1)
+  }
 }

 // A non-serializable class we create in closures to make sure that we aren't
@@ -108,6 +121,30 @@ class TestClassWithoutFieldAccess {
   }
 }

+object TestObjectWithBogusReturns {
+  def run(): Int = {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val nums = sc.parallelize(Array(1, 2, 3, 4))
+      // this return is invalid since it will transfer control outside the closure
+      nums.map { x => return 1; x * 2 }
+      1
+    }
+  }
+}
+
+object TestObjectWithNestedReturns {
+  def run(): Int = {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val nums = sc.parallelize(Array(1, 2, 3, 4))
+      nums.map { x =>
+        // this return is fine since it will not transfer control outside the closure
+        def foo(): Int = { return 5; 1 }
+        foo()
+      }
+      1
+    }
+  }
+}
+
 object TestObjectWithNesting {
   def run(): Int = {
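
Note (not part of the patch): TestObjectWithNestedReturns passes because a `return` from a named method defined inside the closure is a local return. It compiles to an ordinary return instruction in foo's own bytecode, constructs no NonLocalReturnControl, and therefore never trips the scan. A minimal sketch of that safe pattern (LocalReturnDemo and doubled are illustrative names):

object LocalReturnDemo {
  def doubled(nums: List[Int]): List[Int] = {
    nums.map { n =>
      // This return only exits foo, so scalac emits a plain return
      // instruction rather than throwing NonLocalReturnControl.
      def foo(): Int = { return n * 2 }
      foo()
    }
  }
}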