
Commit 16ffadc

willbrxin authored and committed
SPARK-571: forbid return statements in cleaned closures
This patch checks top-level closure arguments to `ClosureCleaner.clean` for `return` statements and raises an exception if it finds any. This is mainly a user-friendliness addition, since programs with return statements in closure arguments will currently fail upon RDD actions with a less-than-intuitive error message.

Author: William Benton <[email protected]>

Closes apache#717 from willb/spark-571 and squashes the following commits:

c41eb7d [William Benton] Another test case for SPARK-571
30c42f4 [William Benton] Stylistic cleanups
559b16b [William Benton] Stylistic cleanups from review
de13b79 [William Benton] Style fixes
295b6a5 [William Benton] Forbid return statements in closure arguments.
b017c47 [William Benton] Added a test for SPARK-571
1 parent 52d9052 commit 16ffadc
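
For context on the check this patch adds: a `return` inside a Scala closure is a non-local return that tries to exit the enclosing method, and scalac implements it by allocating and throwing a `scala.runtime.NonLocalReturnControl`, which the enclosing method catches. When the closure later runs inside an RDD action, the catching frame is not on the stack, so the control exception escapes with an unhelpful message. A minimal sketch of the desugaring in plain Scala (the object and method names are illustrative, not part of the patch):

import scala.runtime.NonLocalReturnControl

object NonLocalReturnSketch {
  // A non-local return written the natural way: exits original(), not the lambda.
  def original(nums: Seq[Int]): Int = {
    nums.foreach { x => if (x > 2) return x }
    -1
  }

  // Approximately what the compiler generates for original():
  def desugaredEquivalent(nums: Seq[Int]): Int = {
    val key = new AnyRef
    try {
      // the `return` becomes `new NonLocalReturnControl(...)` plus a throw --
      // exactly the allocation that the new ReturnStatementFinder scans for
      nums.foreach { x => if (x > 2) throw new NonLocalReturnControl(key, x) }
      -1
    } catch {
      case e: NonLocalReturnControl[Int @unchecked] if e.key eq key => e.value
    }
  }
}

Run locally, the two methods behave identically; run remotely, only the throwing half of the machinery ships with the closure, which is why the cleaner can and should reject the pattern up front.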

File tree

2 files changed: +60 −2 lines changed


core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala

Lines changed: 22 additions & 1 deletion

@@ -25,7 +25,7 @@ import scala.collection.mutable.Set
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
 
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkException}
 
 private[spark] object ClosureCleaner extends Logging {
   // Get an ASM class reader for a given class from the JAR that loaded it
@@ -108,6 +108,9 @@ private[spark] object ClosureCleaner extends Logging {
     val outerObjects = getOuterObjects(func)
 
     val accessedFields = Map[Class[_], Set[String]]()
+
+    getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
+
     for (cls <- outerClasses)
       accessedFields(cls) = Set[String]()
     for (cls <- func.getClass :: innerClasses)
@@ -180,6 +183,24 @@ private[spark] object ClosureCleaner extends Logging {
   }
 }
 
+private[spark]
+class ReturnStatementFinder extends ClassVisitor(ASM4) {
+  override def visitMethod(access: Int, name: String, desc: String,
+      sig: String, exceptions: Array[String]): MethodVisitor = {
+    if (name.contains("apply")) {
+      new MethodVisitor(ASM4) {
+        override def visitTypeInsn(op: Int, tp: String) {
+          if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
+            throw new SparkException("Return statements aren't allowed in Spark closures")
+          }
+        }
+      }
+    } else {
+      new MethodVisitor(ASM4) {}
+    }
+  }
+}
+
 private[spark]
 class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
   override def visitMethod(access: Int, name: String, desc: String,
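
The detection itself is ordinary bytecode scanning and can be reproduced outside Spark. A rough standalone sketch with plain ASM follows (this assumes the `org.ow2.asm` library on the classpath rather than the reflectasm-shaded copy Spark uses, and `ReturnFinder`/`check` are names of my own, not from the patch):

import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Opcodes}

// Visits every method of a class and fails if any of them allocates
// scala.runtime.NonLocalReturnControl, i.e. contains a non-local return.
class ReturnFinder extends ClassVisitor(Opcodes.ASM4) {
  override def visitMethod(access: Int, name: String, desc: String,
      sig: String, exceptions: Array[String]): MethodVisitor = {
    new MethodVisitor(Opcodes.ASM4) {
      override def visitTypeInsn(op: Int, tp: String): Unit = {
        if (op == Opcodes.NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
          throw new IllegalArgumentException(s"non-local return in method $name")
        }
      }
    }
  }
}

object ReturnFinder {
  // Loads a class's bytecode from its own class loader and runs the visitor,
  // mirroring what getClassReader + accept do in ClosureCleaner.
  def check(cls: Class[_]): Unit = {
    val in = cls.getClassLoader.getResourceAsStream(
      cls.getName.replace('.', '/') + ".class")
    try new ClassReader(in).accept(new ReturnFinder, 0)
    finally in.close()
  }
}

Note that the real patch only inspects methods whose names contain "apply", since those are the entry points of Scala function objects; the sketch above scans everything.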

core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala

Lines changed: 38 additions & 1 deletion

@@ -20,7 +20,7 @@ package org.apache.spark.util
 import org.scalatest.FunSuite
 
 import org.apache.spark.LocalSparkContext._
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkException}
 
 class ClosureCleanerSuite extends FunSuite {
   test("closures inside an object") {
@@ -50,6 +50,19 @@ class ClosureCleanerSuite extends FunSuite {
     val obj = new TestClassWithNesting(1)
     assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
   }
+
+  test("toplevel return statements in closures are identified at cleaning time") {
+    val ex = intercept[SparkException] {
+      TestObjectWithBogusReturns.run()
+    }
+
+    assert(ex.getMessage.contains("Return statements aren't allowed in Spark closures"))
+  }
+
+  test("return statements from named functions nested in closures don't raise exceptions") {
+    val result = TestObjectWithNestedReturns.run()
+    assert(result == 1)
+  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
@@ -108,6 +121,30 @@ class TestClassWithoutFieldAccess {
   }
 }
 
+object TestObjectWithBogusReturns {
+  def run(): Int = {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val nums = sc.parallelize(Array(1, 2, 3, 4))
+      // this return is invalid since it will transfer control outside the closure
+      nums.map {x => return 1 ; x * 2}
+      1
+    }
+  }
+}
+
+object TestObjectWithNestedReturns {
+  def run(): Int = {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val nums = sc.parallelize(Array(1, 2, 3, 4))
+      nums.map {x =>
+        // this return is fine since it will not transfer control outside the closure
+        def foo(): Int = { return 5; 1 }
+        foo()
+      }
+      1
+    }
+  }
+}
+
 object TestObjectWithNesting {
   def run(): Int = {
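
The two fixtures pin down exactly which returns the visitor flags: a top-level `return` in a closure is non-local and allocates a `NonLocalReturnControl`, while a `return` inside a named function defined within the closure is local and compiles to an ordinary return instruction. A small plain-Scala contrast (no Spark required; `outer` is an illustrative name):

object ReturnKindsSketch {
  def outer(xs: Seq[Int]): Int = {
    // Non-local: targets outer(), so inside the lambda it becomes a thrown
    // NonLocalReturnControl -- the pattern ReturnStatementFinder rejects.
    xs.foreach { x => if (x < 0) return -1 }

    // Local: `return 5` exits foo() via a plain return instruction, so the
    // closure's bytecode never allocates NonLocalReturnControl and passes.
    xs.map { x =>
      def foo(): Int = { return 5; 1 }
      foo() + x
    }.sum
  }
}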
