From 6606910d2d386f36addc18173edc053b08d4df1c Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 23 Oct 2017 07:03:32 +0000
Subject: [PATCH 1/7] ClosureCleaner should fill referenced superclass fields.

---
 .../apache/spark/util/ClosureCleaner.scala | 67 ++++++++++++++++---
 .../spark/util/ClosureCleanerSuite.scala   | 27 ++++++++
 2 files changed, 84 insertions(+), 10 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 48a1d7b84b61..6a49e64df756 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -91,6 +91,52 @@ private[spark] object ClosureCleaner extends Logging {
     (seen - obj.getClass).toList
   }
 
+  /** Initializes the accessed fields for outer classes and their superclasses. */
+  private def initAccessedFields(
+      accessedFields: Map[Class[_], Set[String]],
+      outerClasses: Seq[Class[_]]): Unit = {
+    for (cls <- outerClasses) {
+      accessedFields(cls) = Set.empty[String]
+
+      var superClass = cls.getSuperclass()
+      while (superClass != null) {
+        accessedFields(superClass) = Set.empty[String]
+        superClass = superClass.getSuperclass()
+      }
+    }
+  }
+
+  /** Sets the accessed fields for the given class on the clone, based on the given object. */
+  private def setAccessedFields(
+      outerClass: Class[_],
+      clone: AnyRef,
+      obj: AnyRef,
+      accessedFields: Map[Class[_], Set[String]]): Unit = {
+    for (fieldName <- accessedFields(outerClass)) {
+      val field = outerClass.getDeclaredField(fieldName)
+      field.setAccessible(true)
+      val value = field.get(obj)
+      field.set(clone, value)
+    }
+  }
+
+  /** Clones a given object and sets the accessed fields on the cloned object. */
+  private def cloneAndSetFields(
+      parent: AnyRef,
+      obj: AnyRef,
+      outerClass: Class[_],
+      accessedFields: Map[Class[_], Set[String]]): AnyRef = {
+    val clone = instantiateClass(outerClass, parent)
+    setAccessedFields(outerClass, clone, obj, accessedFields)
+
+    var superClass = outerClass.getSuperclass()
+    while (superClass != null) {
+      setAccessedFields(superClass, clone, obj, accessedFields)
+      superClass = superClass.getSuperclass()
+    }
+    clone
+  }
+
   /**
    * Clean the given closure in place.
    *
@@ -202,9 +248,8 @@ private[spark] object ClosureCleaner extends Logging {
       logDebug(s" + populating accessed fields because this is the starting closure")
       // Initialize accessed fields with the outer classes first
       // This step is needed to associate the fields with the correct classes later
-      for (cls <- outerClasses) {
-        accessedFields(cls) = Set.empty[String]
-      }
+      initAccessedFields(accessedFields, outerClasses)
+
       // Populate accessed fields by visiting all fields and methods accessed by this and
       // all of its inner closures. If transitive cleaning is enabled, this may recursively
       // visit methods that belong to other classes in search of transitively referenced fields.
@@ -250,13 +295,8 @@ private[spark] object ClosureCleaner extends Logging {
         // required fields from the original object. We need the parent here because the Java
         // language specification requires the first constructor parameter of any closure to be
         // its enclosing object.
-        val clone = instantiateClass(cls, parent)
-        for (fieldName <- accessedFields(cls)) {
-          val field = cls.getDeclaredField(fieldName)
-          field.setAccessible(true)
-          val value = field.get(obj)
-          field.set(clone, value)
-        }
+        val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
+
         // If transitive cleaning is enabled, we recursively clean any enclosing closure using
         // the already populated accessed fields map of the starting closure
         if (cleanTransitively && isClosure(clone.getClass)) {
@@ -397,6 +437,13 @@ private[util] class FieldAccessFinder(
             visitedMethods += m
             ClosureCleaner.getClassReader(cl).accept(
               new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+
+            var superClass = cl.getSuperclass()
+            while (superClass != null) {
+              ClosureCleaner.getClassReader(superClass).accept(
+                new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+              superClass = superClass.getSuperclass()
+            }
           }
         }
       }
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 4920b7ee8bfb..16d8a93702a0 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -119,6 +119,18 @@ class ClosureCleanerSuite extends SparkFunSuite {
   test("createNullValue") {
     new TestCreateNullValue().run()
   }
+
+  test("SPARK-22328: ClosureCleaner misses referenced superclass fields") {
+    val concreteObject = () => new TestAbstractClass {
+      val n2 = 222
+      val s2 = "bbb"
+      val d2 = 2.0d
+      def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ =>
+        (n1, n2, s1, s2, d1, d2)
+      }.collect()
+    }
+    assert(concreteObject().run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
+  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
@@ -377,3 +389,18 @@ class TestCreateNullValue {
     nestedClosure()
   }
 }
+
+abstract class TestAbstractClass extends Serializable {
+  val n1 = 111
+  val s1 = "aaa"
+  protected val d1 = 1.0d
+
+  def rdd(sc: SparkContext): RDD[Int] = sc.parallelize(1 to 1)
+
+  def run(): Seq[(Int, Int, String, String, Double, Double)] = {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      body(rdd(sc))
+    }
+  }
+  def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)]
+}

From 5ac7540fea210848dc8e1a30b51607a9bf5b0354 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 23 Oct 2017 09:00:13 +0000
Subject: [PATCH 2/7] Add another test.
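
Case 1 cleans a closure defined inside a method of the anonymous
subclass. This test covers the other path: getData returns the closure
itself, which is then handed to map(), so the anonymous subclass instance
is the enclosing object of the starting closure, while the referenced
vals n1, s1 and d1 are declared one level up, on TestAbstractClass2. A
minimal sketch of the shape, with illustrative names (Base, obj) that are
not part of the patch:

    abstract class Base extends Serializable {
      val n1 = 111                    // declared on the superclass
      def getData: Int => Int         // returns a closure over this instance
    }
    val obj: Base = new Base {
      val n2 = 222                    // declared on the anonymous subclass
      def getData: Int => Int = _ => n1 + n2
    }
    // sc.parallelize(1 to 1).map(obj.getData): before this fix, the cleaner
    // registered accessed fields only for the anonymous subclass, so the
    // cloned enclosing object kept n2 but lost n1 (an unset Int field reads
    // as 0), and the task computed 222 instead of 333.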

---
 .../spark/util/ClosureCleanerSuite.scala | 22 ++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 16d8a93702a0..f7efdefc149f 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -120,7 +120,7 @@ class ClosureCleanerSuite extends SparkFunSuite {
     new TestCreateNullValue().run()
   }
 
-  test("SPARK-22328: ClosureCleaner misses referenced superclass fields") {
+  test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") {
     val concreteObject = () => new TestAbstractClass {
       val n2 = 222
       val s2 = "bbb"
@@ -131,6 +131,20 @@ class ClosureCleanerSuite extends SparkFunSuite {
     }
     assert(concreteObject().run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
   }
+
+  test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") {
+    val fn = () => new TestAbstractClass2 {
+      val n2 = 222
+      val s2 = "bbb"
+      val d2 = 2.0d
+      def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2)
+    }
+    val concreteObject = fn()
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val rdd = sc.parallelize(1 to 1).map(concreteObject.getData)
+      assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
+    }
+  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
@@ -404,3 +418,9 @@ abstract class TestAbstractClass extends Serializable {
   }
   def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)]
 }
+
+abstract class TestAbstractClass2 extends Serializable {
+  val n1 = 111
+  val s1 = "aaa"
+  val d1 = 1.0d
+}

From da747ca3085cff679dcba90ef7f001b1630943f2 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 23 Oct 2017 13:44:03 +0000
Subject: [PATCH 3/7] Refactoring tests.

---
 .../spark/util/ClosureCleanerSuite.scala | 25 ++++++++++----------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index f7efdefc149f..4681645ff58c 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -121,25 +121,32 @@ class ClosureCleanerSuite extends SparkFunSuite {
   }
 
   test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") {
-    val concreteObject = () => new TestAbstractClass {
+    val concreteObject = new TestAbstractClass {
       val n2 = 222
       val s2 = "bbb"
       val d2 = 2.0d
+
+      def run(): Seq[(Int, Int, String, String, Double, Double)] = {
+        withSpark(new SparkContext("local", "test")) { sc =>
+          val rdd = sc.parallelize(1 to 1)
+          body(rdd)
+        }
+      }
+
       def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ =>
         (n1, n2, s1, s2, d1, d2)
       }.collect()
     }
-    assert(concreteObject().run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
+    assert(concreteObject.run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
   }
 
   test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") {
-    val fn = () => new TestAbstractClass2 {
+    val concreteObject = new TestAbstractClass2 {
      val n2 = 222
       val s2 = "bbb"
       val d2 = 2.0d
       def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2)
     }
-    val concreteObject = fn()
     withSpark(new SparkContext("local", "test")) { sc =>
       val rdd = sc.parallelize(1 to 1).map(concreteObject.getData)
       assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
@@ -409,18 +416,12 @@ abstract class TestAbstractClass extends Serializable {
   val s1 = "aaa"
   protected val d1 = 1.0d
 
-  def rdd(sc: SparkContext): RDD[Int] = sc.parallelize(1 to 1)
-
-  def run(): Seq[(Int, Int, String, String, Double, Double)] = {
-    withSpark(new SparkContext("local", "test")) { sc =>
-      body(rdd(sc))
-    }
-  }
+  def run(): Seq[(Int, Int, String, String, Double, Double)]
   def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)]
 }
 
 abstract class TestAbstractClass2 extends Serializable {
   val n1 = 111
   val s1 = "aaa"
-  val d1 = 1.0d
+  protected val d1 = 1.0d
 }

From 5d7efd14c0baba3e3f41258fcf6dc44f2976450a Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 24 Oct 2017 00:24:47 +0000
Subject: [PATCH 4/7] Address comment.
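
PATCH 1 used the same two-step shape at three call sites: handle the
class itself, then walk its superclasses in a second loop. Since the
class and its superclasses get identical treatment, each site collapses
into a single do/while over the whole hierarchy. The shared pattern,
written as a hypothetical helper for illustration (the patch keeps the
loop inlined at each call site instead of factoring it out):

    def foreachClassInHierarchy(startClass: Class[_])(visit: Class[_] => Unit): Unit = {
      var currentClass: Class[_] = startClass
      do {
        visit(currentClass)
        currentClass = currentClass.getSuperclass()
      } while (currentClass != null)
    }

    // e.g. seeding the accessed-fields map for one outer class:
    foreachClassInHierarchy(cls) { c => accessedFields(c) = Set.empty[String] }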

---
 .../apache/spark/util/ClosureCleaner.scala | 36 +++++++++----------
 1 file changed, 16 insertions(+), 20 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 6a49e64df756..c50347a2d9a9 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -96,13 +96,11 @@ private[spark] object ClosureCleaner extends Logging {
       accessedFields: Map[Class[_], Set[String]],
       outerClasses: Seq[Class[_]]): Unit = {
     for (cls <- outerClasses) {
-      accessedFields(cls) = Set.empty[String]
-
-      var superClass = cls.getSuperclass()
-      while (superClass != null) {
-        accessedFields(superClass) = Set.empty[String]
-        superClass = superClass.getSuperclass()
-      }
+      var currentClass = cls
+      do {
+        accessedFields(currentClass) = Set.empty[String]
+        currentClass = currentClass.getSuperclass()
+      } while (currentClass != null)
     }
   }
 
@@ -127,13 +125,13 @@ private[spark] object ClosureCleaner extends Logging {
       outerClass: Class[_],
       accessedFields: Map[Class[_], Set[String]]): AnyRef = {
     val clone = instantiateClass(outerClass, parent)
-    setAccessedFields(outerClass, clone, obj, accessedFields)
 
-    var superClass = outerClass.getSuperclass()
-    while (superClass != null) {
-      setAccessedFields(superClass, clone, obj, accessedFields)
-      superClass = superClass.getSuperclass()
-    }
+    var currentClass = outerClass
+    do {
+      setAccessedFields(currentClass, clone, obj, accessedFields)
+      currentClass = currentClass.getSuperclass()
+    } while (currentClass != null)
+
     clone
   }
 
@@ -435,15 +433,13 @@ private[util] class FieldAccessFinder(
           if (!visitedMethods.contains(m)) {
             // Keep track of visited methods to avoid potential infinite cycles
             visitedMethods += m
-            ClosureCleaner.getClassReader(cl).accept(
-              new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
 
-            var superClass = cl.getSuperclass()
-            while (superClass != null) {
-              ClosureCleaner.getClassReader(superClass).accept(
+            var currentClass = cl
+            do {
+              ClosureCleaner.getClassReader(currentClass).accept(
                 new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
-              superClass = superClass.getSuperclass()
-            }
+              currentClass = currentClass.getSuperclass()
+            } while (currentClass != null)
           }
         }
       }

From de5cbde1a2d337d545733b6b29568e418b9c4cfa Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 24 Oct 2017 14:35:27 +0000
Subject: [PATCH 5/7] Add a test.
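
The new test nests several anonymous subclasses of TestAbstractClass2, so
multiple classes in the outer chain share one parent. That case matters
because accessedFields is keyed by Class[_]: each subclass's hierarchy
walk reaches the same TestAbstractClass2 entry, and fields referenced
through any of the subclasses accumulate there. An illustrative sketch of
the resulting map state, using the mutable Map/Set ClosureCleaner works
with:

    import scala.collection.mutable.{Map, Set}

    val accessedFields = Map[Class[_], Set[String]]()
    accessedFields(classOf[TestAbstractClass2]) = Set.empty[String]
    accessedFields(classOf[TestAbstractClass2]) += "s1"  // read by innerObject2.getData
    accessedFields(classOf[TestAbstractClass2]) += "d1"  // read by innerObject2.getData
    // Each anonymous subclass still gets its own entry for the fields it
    // declares or overrides (n3, s3, d3, the overridden n1, ...).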

---
 .../spark/util/ClosureCleanerSuite.scala | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 4681645ff58c..9a19baee9569 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -152,6 +152,30 @@ class ClosureCleanerSuite extends SparkFunSuite {
       assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
     }
   }
+
+  test("SPARK-22328: multiple outer classes have the same parent class") {
+    val concreteObject = new TestAbstractClass2 {
+
+      val innerObject = new TestAbstractClass2 {
+        override val n1 = 222
+        override val s1 = "bbb"
+      }
+
+      val innerObject2 = new TestAbstractClass2 {
+        override val n1 = 444
+        val n3 = 333
+        val s3 = "ccc"
+        val d3 = 3.0d
+
+        def getData: Int => (Int, Int, String, String, Double, Double, Int, String) =
+          _ => (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1)
+      }
+    }
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData)
+      assert(rdd.collect() === Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb")))
+    }
+  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't

From 4d8f91e8917fd42644eecff0327e6e5dcc2f93b1 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 25 Oct 2017 06:57:01 +0000
Subject: [PATCH 6/7] Address comment.

---
 .../main/scala/org/apache/spark/util/ClosureCleaner.scala | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index c50347a2d9a9..1611dd2082a5 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -97,10 +97,12 @@ private[spark] object ClosureCleaner extends Logging {
       outerClasses: Seq[Class[_]]): Unit = {
     for (cls <- outerClasses) {
       var currentClass = cls
-      do {
+      assert(currentClass != null, "The outer class can't be null.")
+
+      while (currentClass != null) {
        accessedFields(currentClass) = Set.empty[String]
         currentClass = currentClass.getSuperclass()
-      } while (currentClass != null)
+      }
     }
   }
 

From e26d093bd14c79f26903206104da6aa57a32d613 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 26 Oct 2017 01:37:43 +0000
Subject: [PATCH 7/7] Address comments.
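
Apply the change PATCH 6 made to initAccessedFields to the two remaining
hierarchy walks: replace do/while with an explicit precondition check
plus a plain while loop. The two forms behave identically here because
the starting class is never null; the assert turns that implicit
assumption into a checked, documented one. A sketch of the equivalence
(startClass and visit are illustrative names):

    var c: Class[_] = startClass
    do {
      visit(c)
      c = c.getSuperclass()
    } while (c != null)

    // ...is equivalent, for non-null startClass, to:
    var c2: Class[_] = startClass
    assert(c2 != null, "The outer class can't be null.")
    while (c2 != null) {
      visit(c2)
      c2 = c2.getSuperclass()
    }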

---
 .../scala/org/apache/spark/util/ClosureCleaner.scala | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 1611dd2082a5..dfece5dd0670 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -129,10 +129,12 @@ private[spark] object ClosureCleaner extends Logging {
     val clone = instantiateClass(outerClass, parent)
 
     var currentClass = outerClass
-    do {
+    assert(currentClass != null, "The outer class can't be null.")
+
+    while (currentClass != null) {
       setAccessedFields(currentClass, clone, obj, accessedFields)
       currentClass = currentClass.getSuperclass()
-    } while (currentClass != null)
+    }
 
     clone
   }
@@ -437,11 +439,13 @@ private[util] class FieldAccessFinder(
             visitedMethods += m
 
             var currentClass = cl
-            do {
+            assert(currentClass != null, "The outer class can't be null.")
+
+            while (currentClass != null) {
               ClosureCleaner.getClassReader(currentClass).accept(
                 new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
               currentClass = currentClass.getSuperclass()
-            } while (currentClass != null)
+            }
           }
         }
       }
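
Taken together, the series makes every step of outer-object cleaning
hierarchy-aware: registering accessed fields, scanning bytecode for field
accesses, and copying field values into clones all walk getSuperclass()
to the top of the chain. The crux, extracted into a self-contained sketch
(copyAccessedFields is a hypothetical name; the real logic lives in
setAccessedFields and cloneAndSetFields above):

    import scala.collection.mutable

    def copyAccessedFields(
        obj: AnyRef,
        clone: AnyRef,
        accessedFields: mutable.Map[Class[_], mutable.Set[String]]): Unit = {
      var currentClass: Class[_] = obj.getClass
      assert(currentClass != null, "The outer class can't be null.")

      while (currentClass != null) {
        // getDeclaredField only sees fields declared on this exact class,
        // which is why the walk over the hierarchy is needed at all.
        for (fieldName <- accessedFields.getOrElse(currentClass, mutable.Set.empty[String])) {
          val field = currentClass.getDeclaredField(fieldName)
          field.setAccessible(true)
          field.set(clone, field.get(obj))
        }
        currentClass = currentClass.getSuperclass()
      }
    }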