From af65ef3e830d85a704889c16d0b229f5ab3966fa Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Tue, 31 Jul 2018 01:51:14 +0300 Subject: [PATCH] initial --- .../scala/org/apache/spark/TaskContext.scala | 5 +- .../spark/api/python/PythonRunner.scala | 2 +- .../spark/broadcast/TorrentBroadcast.scala | 2 +- .../org/apache/spark/rdd/HadoopRDD.scala | 2 +- .../scala/org/apache/spark/rdd/JdbcRDD.scala | 2 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 2 +- .../spark/rdd/ReliableCheckpointRDD.scala | 2 +- .../shuffle/BlockStoreShuffleReader.scala | 2 +- .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../spark/storage/memory/MemoryStore.scala | 2 +- .../apache/spark/util/ClosureCleaner.scala | 278 +++++++++++------- .../collection/ExternalAppendOnlyMap.scala | 2 +- .../spark/util/ClosureCleanerSuite.scala | 3 + .../spark/util/ClosureCleanerSuite2.scala | 53 +++- .../spark/sql/avro/AvroFileFormat.scala | 2 +- .../spark/sql/kafka010/KafkaSourceRDD.scala | 2 +- .../spark/streaming/kafka010/KafkaRDD.scala | 2 +- .../ml/source/libsvm/LibSVMRelation.scala | 2 +- .../TungstenAggregationIterator.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 4 +- .../columnar/InMemoryTableScanExec.scala | 2 +- .../execution/datasources/CodecStreams.scala | 2 +- .../execution/datasources/FileScanRDD.scala | 2 +- .../datasources/csv/CSVDataSource.scala | 2 +- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../datasources/json/JsonDataSource.scala | 2 +- .../datasources/orc/OrcFileFormat.scala | 4 +- .../parquet/ParquetFileFormat.scala | 4 +- .../datasources/text/TextFileFormat.scala | 2 +- .../datasources/v2/DataSourceRDD.scala | 2 +- .../spark/sql/execution/joins/HashJoin.scala | 2 +- .../joins/ShuffledHashJoinExec.scala | 2 +- .../python/AggregateInPandasExec.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 2 +- .../sql/execution/python/EvalPythonExec.scala | 2 +- .../python/PythonForeachWriter.scala | 2 +- .../execution/python/WindowInPandasExec.scala | 2 +- .../continuous/ContinuousCoalesceRDD.scala | 5 +- .../ContinuousQueuedDataReader.scala | 2 +- .../shuffle/ContinuousShuffleReadRDD.scala | 2 +- .../state/SymmetricHashJoinStateManager.scala | 2 +- .../execution/streaming/state/package.scala | 2 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 3 +- 43 files changed, 266 insertions(+), 161 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 69739745aa6cf..ceadf108c86cd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -123,7 +123,10 @@ abstract class TaskContext extends Serializable { * * Exceptions thrown by the listener will result in failure of the task. */ - def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = { + def addTaskCompletionListener[U](f: (TaskContext) => U): TaskContext = { + // Note that due to this scala bug: https://github.com/scala/bug/issues/11016, we need to make + // this function polymorphic for every scala version >= 2.12, otherwise an overloaded method + // resolution error occurs at compile time. 
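    // Illustrative sketch of the ambiguity (hypothetical call site, not part of this patch):
    //   context.addTaskCompletionListener { _ => writer.close() }
    // Under Scala 2.12 a function literal like this matches both the (TaskContext) => U
    // overload and, via SAM conversion, the overload taking a TaskCompletionListener, so the
    // call sites updated in this patch disambiguate with an explicit type argument:
    //   context.addTaskCompletionListener[Unit] { _ => writer.close() }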
addTaskCompletionListener(new TaskCompletionListener { override def onTaskCompletion(context: TaskContext): Unit = f(context) }) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index ebabedf950e39..7b31857588252 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -94,7 +94,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // Start a thread to feed the process input from our parent's iterator val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => writerThread.shutdownOnTaskCompletion() if (!reuseWorker || !released.get) { try { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index e125095cf4777..cbd49e070f2eb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -262,7 +262,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) val blockManager = SparkEnv.get.blockManager Option(TaskContext.get()) match { case Some(taskContext) => - taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId)) + taskContext.addTaskCompletionListener[Unit](_ => blockManager.releaseLock(blockId)) case None => // This should only happen on the driver, where broadcast variables may be accessed // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 44895abc7bd4d..3974580cfaa11 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -278,7 +278,7 @@ class HadoopRDD[K, V]( null } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener { context => + context.addTaskCompletionListener[Unit] { context => // Update the bytes read before closing is to make sure lingering bytesRead statistics in // this thread get correctly added. 
updateBytesRead() diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index aab46b8954bf7..56ef3e107a980 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -77,7 +77,7 @@ class JdbcRDD[T: ClassTag]( override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] { - context.addTaskCompletionListener{ context => closeIfNeeded() } + context.addTaskCompletionListener[Unit]{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index ff66a04859d10..2d66d25ba39fa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -214,7 +214,7 @@ class NewHadoopRDD[K, V]( } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener { context => + context.addTaskCompletionListener[Unit] { context => // Update the bytesRead before closing is to make sure lingering bytesRead statistics in // this thread get correctly added. updateBytesRead() diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index 979152b55f957..8273d8a9eb476 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -300,7 +300,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { val deserializeStream = serializer.deserializeStream(fileInputStream) // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(context => deserializeStream.close()) + context.addTaskCompletionListener[Unit](context => deserializeStream.close()) deserializeStream.asIterator.asInstanceOf[Iterator[T]] } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4103dfb10175e..74b0e0b3a741a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -104,7 +104,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) // Use completion callback to stop sorter if task was finished/cancelled. 
- context.addTaskCompletionListener(_ => { + context.addTaskCompletionListener[Unit](_ => { sorter.stop() }) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b31862323a895..00d01dd28afb5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -346,7 +346,7 @@ final class ShuffleBlockFetcherIterator( private[this] def initialize(): Unit = { // Add a task completion callback (called in both success case and failure case) to cleanup. - context.addTaskCompletionListener(_ => cleanup()) + context.addTaskCompletionListener[Unit](_ => cleanup()) // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 4cc5bcb7f9baf..06fd56e54d9c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -827,7 +827,7 @@ private[storage] class PartiallySerializedBlock[T]( // completion listener here in order to ensure that `unrolled.dispose()` is called at least once. // The dispose() method is idempotent, so it's safe to call it unconditionally. Option(TaskContext.get()).foreach { taskContext => - taskContext.addTaskCompletionListener { _ => + taskContext.addTaskCompletionListener[Unit] { _ => // When a task completes, its unroll memory will automatically be freed. Thus we do not call // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing. unrolledBuffer.dispose() diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 073d71c63b0c7..d8c840c356527 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.lang.invoke.SerializedLambda import scala.collection.mutable.{Map, Set, Stack} import scala.language.existentials @@ -33,6 +34,8 @@ import org.apache.spark.internal.Logging */ private[spark] object ClosureCleaner extends Logging { + private val isScala2_11 = scala.util.Properties.versionString.contains("2.11") + // Get an ASM class reader for a given class from the JAR that loaded it private[util] def getClassReader(cls: Class[_]): ClassReader = { // Copy data over, before delegating to ClassReader - else we can run out of open file handles. @@ -159,6 +162,42 @@ private[spark] object ClosureCleaner extends Logging { clean(closure, checkSerializable, cleanTransitively, Map.empty) } + /** + * Try to get a serialized Lambda from the closure. + * + * @param closure the closure to check. 
+ */ + private def getSerializedLambda(closure: AnyRef): Option[SerializedLambda] = { + if (isScala2_11) { + return None + } + val isClosureCandidate = + closure.getClass.isSynthetic && + closure + .getClass + .getInterfaces.exists(_.getName.equals("scala.Serializable")) + + if (isClosureCandidate) { + try { + Option(inspect(closure)) + } catch { + case e: Exception => + // no need to check if debug is enabled here the Spark + // logging api covers this. + logDebug("Closure is not a serialized lambda.", e) + None + } + } else { + None + } + } + + private def inspect(closure: AnyRef): SerializedLambda = { + val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + writeReplace.invoke(closure).asInstanceOf[java.lang.invoke.SerializedLambda] + } + /** * Helper method to clean the given closure in place. * @@ -206,7 +245,12 @@ private[spark] object ClosureCleaner extends Logging { cleanTransitively: Boolean, accessedFields: Map[Class[_], Set[String]]): Unit = { - if (!isClosure(func.getClass)) { + // most likely to be the case with 2.12, 2.13 + // so we check first + // non LMF-closures should be less frequent from now on + val lambdaFunc = getSerializedLambda(func) + + if (!isClosure(func.getClass) && lambdaFunc.isEmpty) { logDebug(s"Expected a closure; got ${func.getClass.getName}") return } @@ -218,118 +262,132 @@ private[spark] object ClosureCleaner extends Logging { return } - logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") - - // A list of classes that represents closures enclosed in the given one - val innerClasses = getInnerClosureClasses(func) - - // A list of enclosing objects and their respective classes, from innermost to outermost - // An outer object at a given index is of type outer class at the same index - val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) - - // For logging purposes only - val declaredFields = func.getClass.getDeclaredFields - val declaredMethods = func.getClass.getDeclaredMethods - - if (log.isDebugEnabled) { - logDebug(" + declared fields: " + declaredFields.size) - declaredFields.foreach { f => logDebug(" " + f) } - logDebug(" + declared methods: " + declaredMethods.size) - declaredMethods.foreach { m => logDebug(" " + m) } - logDebug(" + inner classes: " + innerClasses.size) - innerClasses.foreach { c => logDebug(" " + c.getName) } - logDebug(" + outer classes: " + outerClasses.size) - outerClasses.foreach { c => logDebug(" " + c.getName) } - logDebug(" + outer objects: " + outerObjects.size) - outerObjects.foreach { o => logDebug(" " + o) } - } + if (lambdaFunc.isEmpty) { + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") + + // A list of classes that represents closures enclosed in the given one + val innerClasses = getInnerClosureClasses(func) + + // A list of enclosing objects and their respective classes, from innermost to outermost + // An outer object at a given index is of type outer class at the same index + val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) + + // For logging purposes only + val declaredFields = func.getClass.getDeclaredFields + val declaredMethods = func.getClass.getDeclaredMethods + + if (log.isDebugEnabled) { + logDebug(s" + declared fields: ${declaredFields.size}") + declaredFields.foreach { f => logDebug(s" $f") } + logDebug(s" + declared methods: ${declaredMethods.size}") + declaredMethods.foreach { m => logDebug(s" $m") } + logDebug(s" + inner classes: ${innerClasses.size}") + innerClasses.foreach 
{ c => logDebug(s" ${c.getName}") } + logDebug(s" + outer classes: ${outerClasses.size}" ) + outerClasses.foreach { c => logDebug(s" ${c.getName}") } + logDebug(s" + outer objects: ${outerObjects.size}") + outerObjects.foreach { o => logDebug(s" $o") } + } - // Fail fast if we detect return statements in closures - getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) - - // If accessed fields is not populated yet, we assume that - // the closure we are trying to clean is the starting one - if (accessedFields.isEmpty) { - logDebug(s" + populating accessed fields because this is the starting closure") - // Initialize accessed fields with the outer classes first - // This step is needed to associate the fields to the correct classes later - initAccessedFields(accessedFields, outerClasses) - - // Populate accessed fields by visiting all fields and methods accessed by this and - // all of its inner closures. If transitive cleaning is enabled, this may recursively - // visits methods that belong to other classes in search of transitively referenced fields. - for (cls <- func.getClass :: innerClasses) { - getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + // Fail fast if we detect return statements in closures + getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) + + // If accessed fields is not populated yet, we assume that + // the closure we are trying to clean is the starting one + if (accessedFields.isEmpty) { + logDebug(" + populating accessed fields because this is the starting closure") + // Initialize accessed fields with the outer classes first + // This step is needed to associate the fields to the correct classes later + initAccessedFields(accessedFields, outerClasses) + + // Populate accessed fields by visiting all fields and methods accessed by this and + // all of its inner closures. If transitive cleaning is enabled, this may recursively + // visits methods that belong to other classes in search of transitively referenced fields. + for (cls <- func.getClass :: innerClasses) { + getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + } } - } - logDebug(s" + fields accessed by starting closure: " + accessedFields.size) - accessedFields.foreach { f => logDebug(" " + f) } - - // List of outer (class, object) pairs, ordered from outermost to innermost - // Note that all outer objects but the outermost one (first one in this list) must be closures - var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse - var parent: AnyRef = null - if (outerPairs.size > 0) { - val (outermostClass, outermostObject) = outerPairs.head - if (isClosure(outermostClass)) { - logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") - } else if (outermostClass.getName.startsWith("$line")) { - // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it - // as it may carray a lot of unnecessary information, e.g. hadoop conf, spark conf, etc. 
- logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + logDebug(s" + fields accessed by starting closure: " + accessedFields.size) + accessedFields.foreach { f => logDebug(" " + f) } + + // List of outer (class, object) pairs, ordered from outermost to innermost + // Note that all outer objects but the outermost one (first one in this list) must be closures + var outerPairs: List[(Class[_], AnyRef)] = outerClasses.zip(outerObjects).reverse + var parent: AnyRef = null + if (outerPairs.nonEmpty) { + val (outermostClass, outermostObject) = outerPairs.head + if (isClosure(outermostClass)) { + logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") + } else if (outermostClass.getName.startsWith("$line")) { + // SPARK-14558: if the outermost object is a REPL line object, we should clone + // and clean it as it may carray a lot of unnecessary information, + // e.g. hadoop conf, spark conf, etc. + logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + } else { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects. + // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). + logDebug(" + outermost object is not a closure or REPL line object," + + "so do not clone it: " + outerPairs.head) + parent = outermostObject // e.g. SparkContext + outerPairs = outerPairs.tail + } } else { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - // Note that we still need to keep around the outermost object itself because - // we need it to clone its child closure later (see below). - logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " + - outerPairs.head) - parent = outermostObject // e.g. SparkContext - outerPairs = outerPairs.tail + logDebug(" + there are no enclosing objects!") } - } else { - logDebug(" + there are no enclosing objects!") - } - // Clone the closure objects themselves, nulling out any fields that are not - // used in the closure we're working on or any of its inner closures. - for ((cls, obj) <- outerPairs) { - logDebug(s" + cloning the object $obj of class ${cls.getName}") - // We null out these unused references by cloning each object and then filling in all - // required fields from the original object. We need the parent here because the Java - // language specification requires the first constructor parameter of any closure to be - // its enclosing object. - val clone = cloneAndSetFields(parent, obj, cls, accessedFields) - - // If transitive cleaning is enabled, we recursively clean any enclosing closure using - // the already populated accessed fields map of the starting closure - if (cleanTransitively && isClosure(clone.getClass)) { - logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") - // No need to check serializable here for the outer closures because we're - // only interested in the serializability of the starting closure - clean(clone, checkSerializable = false, cleanTransitively, accessedFields) + // Clone the closure objects themselves, nulling out any fields that are not + // used in the closure we're working on or any of its inner closures. 
+ for ((cls, obj) <- outerPairs) { + logDebug(s" + cloning the object $obj of class ${cls.getName}") + // We null out these unused references by cloning each object and then filling in all + // required fields from the original object. We need the parent here because the Java + // language specification requires the first constructor parameter of any closure to be + // its enclosing object. + val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + + // If transitive cleaning is enabled, we recursively clean any enclosing closure using + // the already populated accessed fields map of the starting closure + if (cleanTransitively && isClosure(clone.getClass)) { + logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") + // No need to check serializable here for the outer closures because we're + // only interested in the serializability of the starting closure + clean(clone, checkSerializable = false, cleanTransitively, accessedFields) + } + parent = clone } - parent = clone - } - // Update the parent pointer ($outer) of this closure - if (parent != null) { - val field = func.getClass.getDeclaredField("$outer") - field.setAccessible(true) - // If the starting closure doesn't actually need our enclosing object, then just null it out - if (accessedFields.contains(func.getClass) && - !accessedFields(func.getClass).contains("$outer")) { - logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") - field.set(func, null) - } else { - // Update this closure's parent pointer to point to our enclosing object, - // which could either be a cloned closure or the original user object - field.set(func, parent) + // Update the parent pointer ($outer) of this closure + if (parent != null) { + val field = func.getClass.getDeclaredField("$outer") + field.setAccessible(true) + // If the starting closure doesn't actually need our enclosing object, then just null it out + if (accessedFields.contains(func.getClass) && + !accessedFields(func.getClass).contains("$outer")) { + logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") + field.set(func, null) + } else { + // Update this closure's parent pointer to point to our enclosing object, + // which could either be a cloned closure or the original user object + field.set(func, parent) + } } - } - logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + } else { + logDebug(s"Cleaning lambda: ${lambdaFunc.get.getImplMethodName}") + + // scalastyle:off classforname + val captClass = Class.forName(lambdaFunc.get.getCapturingClass.replace('/', '.'), + false, Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + // Fail fast if we detect return statements in closures + getClassReader(captClass) + .accept(new ReturnStatementFinder(Some(lambdaFunc.get.getImplMethodName)), 0) + logDebug(s" +++ Lambda closure (${lambdaFunc.get.getImplMethodName}) is now cleaned +++") + } if (checkSerializable) { ensureSerializable(func) @@ -366,14 +424,24 @@ private[spark] object ClosureCleaner extends Logging { private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") -private class ReturnStatementFinder extends ClassVisitor(ASM5) { +private class ReturnStatementFinder(targetMethodName: Option[String] = None) + extends ClassVisitor(ASM5) { override def visitMethod(access: Int, name: String, desc: String, 
sig: String, exceptions: Array[String]): MethodVisitor = { + // $anonfun$ covers Java 8 lambdas if (name.contains("apply") || name.contains("$anonfun$")) { + // A method with suffix "$adapted" will be generated in cases like + // { _:Int => return; Seq()} but not { _:Int => return; true} + // closure passed is $anonfun$t$1$adapted while actual code resides in $anonfun$s$1 + // visitor will see only $anonfun$s$1$adapted, so we remove the suffix, see + // https://github.com/scala/scala-dev/issues/109 + val isTargetMethod = targetMethodName.isEmpty || + name == targetMethodName.get || name == targetMethodName.get.stripSuffix("$adapted") + new MethodVisitor(ASM5) { override def visitTypeInsn(op: Int, tp: String) { - if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { + if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl") && isTargetMethod) { throw new ReturnStatementInClosureException } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 5c6dd45ec58e3..d83da0d126d89 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -565,7 +565,7 @@ class ExternalAppendOnlyMap[K, V, C]( } } - context.addTaskCompletionListener(context => cleanup()) + context.addTaskCompletionListener[Unit](context => cleanup()) } private[this] class SpillableIterator(var upstream: Iterator[(K, C)]) diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 9a19baee9569e..a0010f18c18a1 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -121,6 +121,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass { val n2 = 222 val s2 = "bbb" @@ -141,6 +142,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass2 { val n2 = 222 val s2 = "bbb" @@ -154,6 +156,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: multiple outer classes have the same parent class") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass2 { val innerObject = new TestAbstractClass2 { diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 278fada83d78c..96da8ec3b2a1c 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -145,6 +145,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get inner closure classes") { + assume(!ClosureCleanerSuite2.supportsLMFs) val closure1 = () => 1 val closure2 = () => { () => 1 } val closure3 = (i: Int) => { @@ -171,6 +172,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get outer classes and objects") { + assume(!ClosureCleanerSuite2.supportsLMFs) val 
localValue = someSerializableValue val closure1 = () => 1 val closure2 = () => localValue @@ -207,6 +209,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get outer classes and objects with nesting") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val test1 = () => { @@ -258,6 +261,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("find accessed fields") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val closure1 = () => 1 val closure2 = () => localValue @@ -296,6 +300,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("find accessed fields with nesting") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val test1 = () => { @@ -538,17 +543,22 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri // As before, this closure is neither serializable nor cleanable verifyCleaning(inner1, serializableBefore = false, serializableAfter = false) - // This closure is no longer serializable because it now has a pointer to the outer closure, - // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2. - // If we do not clean transitively, we will not null out this indirect reference. - verifyCleaning( - inner2, serializableBefore = false, serializableAfter = false, transitive = false) - - // If we clean transitively, we will find that method `a` does not actually reference the - // outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out - // the outer closure's parent pointer. This will make `inner2` serializable. - verifyCleaning( - inner2, serializableBefore = false, serializableAfter = true, transitive = true) + if (ClosureCleanerSuite2.supportsLMFs) { + verifyCleaning( + inner2, serializableBefore = true, serializableAfter = true) + } else { + // This closure is no longer serializable because it now has a pointer to the outer closure, + // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2. + // If we do not clean transitively, we will not null out this indirect reference. + verifyCleaning( + inner2, serializableBefore = false, serializableAfter = false, transitive = false) + + // If we clean transitively, we will find that method `a` does not actually reference the + // outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out + // the outer closure's parent pointer. This will make `inner2` serializable. + verifyCleaning( + inner2, serializableBefore = false, serializableAfter = true, transitive = true) + } } // Same as above, but with more levels of nesting @@ -565,4 +575,25 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri test6()()() } + test("verify nested non-LMF closures") { + assume(ClosureCleanerSuite2.supportsLMFs) + class A1(val f: Int => Int) + class A2(val f: Int => Int => Int) + class B extends A1(x => x*x) + class C extends A2(x => new B().f ) + val closure1 = new B().f + val closure2 = new C().f + // serializable already + verifyCleaning(closure1, serializableBefore = true, serializableAfter = true) + // brings in deps that can't be cleaned + verifyCleaning(closure2, serializableBefore = false, serializableAfter = false) + } +} + +object ClosureCleanerSuite2 { + // Scala 2.12 allows better interop with Java 8 via lambda syntax. 
This is supported + // by implementing FunctionN classes in Scala’s standard library as Single Abstract + // Method (SAM) types. Lambdas are implemented via the invokedynamic instruction and + // the use of the LambdaMetaFactory (LMF) mechanism. + val supportsLMFs = scala.util.Properties.versionString.contains("2.12") } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index e0159b9320276..657a2c20d828a 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -183,7 +183,7 @@ private[avro] class AvroFileFormat extends FileFormat // Ensure that the reader is closed even if the task fails or doesn't consume the entire // iterator of records. Option(TaskContext.get()).foreach { taskContext => - taskContext.addTaskCompletionListener { _ => + taskContext.addTaskCompletionListener[Unit] { _ => reader.close() } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 498e344ea39f4..53bd9a96d1d68 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -166,7 +166,7 @@ private[kafka010] class KafkaSourceRDD( } } // Release consumer, either by removing it or indicating we're no longer using it - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => underlying.closeIfNeeded() } underlying diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 3efc90fe466b2..4513dca44c7c6 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -237,7 +237,7 @@ private class KafkaRDDIterator[K, V]( cacheLoadFactor: Float ) extends Iterator[ConsumerRecord[K, V]] { - context.addTaskCompletionListener(_ => closeIfNeeded()) + context.addTaskCompletionListener[Unit](_ => closeIfNeeded()) val consumer = { KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 4e84ff044f55e..39dcd911a0814 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -154,7 +154,7 @@ private[libsvm] class LibSVMFileFormat (file: PartitionedFile) => { val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) val points = linesReader .map(_.toString.trim) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index c1911235f8df3..72505f7fac0c6 100644 ---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -372,7 +372,7 @@ class TungstenAggregationIterator( } } - TaskContext.get().addTaskCompletionListener(_ => { + TaskContext.get().addTaskCompletionListener[Unit](_ => { // At the end of the task, update the task's peak memory usage. Since we destroy // the map to create the sorter, their memory usages should not overlap, so it is safe // to just use the max of the two. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 7487564ed64da..501520c0e085e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -86,7 +86,7 @@ private[sql] object ArrowConverters { val root = VectorSchemaRoot.create(arrowSchema, allocator) val arrowWriter = ArrowWriter.create(root) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => root.close() allocator.close() } @@ -137,7 +137,7 @@ private[sql] object ArrowConverters { private var schemaRead = StructType(Seq.empty) private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => closeReader() allocator.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 6012aba1acbca..196d057c2de1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -97,7 +97,7 @@ case class InMemoryTableScanExec( columnarBatch.column(i).asInstanceOf[WritableColumnVector], columnarBatchSchema.fields(i).dataType, rowCount) } - taskContext.foreach(_.addTaskCompletionListener(_ => columnarBatch.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => columnarBatch.close())) columnarBatch } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index c0df6c779d7bd..9fddfad249e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -50,7 +50,7 @@ object CodecStreams { */ def createInputStreamWithCloseResource(config: Configuration, path: Path): InputStream = { val inputStream = createInputStream(config, path) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => inputStream.close())) inputStream } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 28c36b6020d33..99fc78ff3e49b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -214,7 +214,7 @@ class FileScanRDD( } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(_ => iterator.close()) + context.addTaskCompletionListener[Unit](_ => iterator.close()) iterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 82322df407521..b7b46c7c86a29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -214,7 +214,7 @@ object TextInputCSVDataSource extends CSVDataSource { caseSensitive: Boolean): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) linesReader.map { line => new String(line.getBytes, 0, line.getLength, parser.options.charset) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 1b3b17c75e756..16b493892e3be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -265,7 +265,7 @@ private[jdbc] class JDBCRDD( closed = true } - context.addTaskCompletionListener{ context => close() } + context.addTaskCompletionListener[Unit]{ context => close() } val inputMetrics = context.taskMetrics().inputMetrics val part = thePart.asInstanceOf[JDBCPartition] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 2fee2128ba1f9..d6c588894d7f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -130,7 +130,7 @@ object TextInputJsonDataSource extends JsonDataSource { parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) val textParser = parser.options.encoding .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text)) .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index df1cebed5bd0a..4574f8247af54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -205,7 +205,7 @@ class OrcFileFormat // There is a possibility that 
`initialize` and `initBatch` hit some errors (like OOM) // after opening a file. val iter = new RecordReaderIterator(batchReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( @@ -220,7 +220,7 @@ class OrcFileFormat val orcRecordReader = new OrcInputFormat[OrcStruct] .createRecordReader(fileSplit, taskAttemptContext) val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 2d4ac7686d4c4..283d7761d22d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -411,7 +411,7 @@ class ParquetFileFormat convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) val iter = new RecordReaderIterator(vectorizedReader) // SPARK-23457 Register a task completion lister before `initialization`. - taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) @@ -432,7 +432,7 @@ class ParquetFileFormat } val iter = new RecordReaderIterator(reader) // SPARK-23457 Register a task completion lister before `initialization`. 
- taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) reader.initialize(split, hadoopAttemptContext) val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 8661a5395ac44..268297148b522 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -120,7 +120,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } else { new HadoopFileWholeTextReader(file, confValue) } - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close())) if (requiredSchema.isEmpty) { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 7ea53424ae100..782829887c446 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -40,7 +40,7 @@ class DataSourceRDD[T: ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[T] = { val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition .createPartitionReader() - context.addTaskCompletionListener(_ => reader.close()) + context.addTaskCompletionListener[Unit](_ => reader.close()) val iter = new Iterator[T] { private[this] var valuePrepared = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 0396168d3f311..dab873bf9b9a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -214,7 +214,7 @@ trait HashJoin { } // At the end of the task, we update the avg hash probe. - TaskContext.get().addTaskCompletionListener(_ => + TaskContext.get().addTaskCompletionListener[Unit](_ => avgHashProbe.set(hashed.getAverageProbesPerLookup)) val resultProj = createResultProjection diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 897a4dae39f32..2b59ed6e4d16b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -57,7 +57,7 @@ case class ShuffledHashJoinExec( buildTime += (System.nanoTime() - start) / 1000000 buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. 
- context.addTaskCompletionListener(_ => relation.close()) + context.addTaskCompletionListener[Unit](_ => relation.close()) relation } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index d00f6f042d6e0..88c9c026928e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -125,7 +125,7 @@ case class AggregateInPandasExec( // combine input with output from Python. val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => queue.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index ca665652f204d..85b187159a3e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -131,7 +131,7 @@ class ArrowPythonRunner( private var schema: StructType = _ private var vectors: Array[ColumnVector] = _ - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => if (reader != null) { reader.close(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 860dc78c1dd1b..04c7dfdd4e204 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -97,7 +97,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil // combine input with output from Python. 
val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - context.addTaskCompletionListener { ctx => + context.addTaskCompletionListener[Unit] { ctx => queue.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index a58773122922f..f08f816cbcca9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -56,7 +56,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) override def open(partitionId: Long, version: Long): Boolean = { outputIterator // initialize everything - TaskContext.get.addTaskCompletionListener { _ => buffer.close() } + TaskContext.get.addTaskCompletionListener[Unit] { _ => buffer.close() } true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 628029b13a6c3..47bfbde56bb3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -142,7 +142,7 @@ case class WindowInPandasExec( // combine input with output from Python. val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => queue.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala index 10d8fc553fede..aec756c0eb2a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -40,7 +40,7 @@ case class ContinuousCoalesceRDDPartition( queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(endpointName, receiver) - TaskContext.get().addTaskCompletionListener { ctx => + TaskContext.get().addTaskCompletionListener[Unit] { ctx => env.stop(endpoint) } (receiver, endpoint) @@ -118,9 +118,8 @@ class ContinuousCoalesceRDD( } } - context.addTaskCompletionListener { ctx => + context.addTaskCompletionListener[Unit] { ctx => threadPool.shutdownNow() - () } part.writersInitialized = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index bfb87053db475..ec1dabd7da3e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -70,7 +70,7 @@ class ContinuousQueuedDataReader( dataReaderThread.setDaemon(true) dataReaderThread.start() - context.addTaskCompletionListener(_ => { + context.addTaskCompletionListener[Unit](_ => { this.close() }) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 518223f3cd008..9b13f6398d837 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -40,7 +40,7 @@ case class ContinuousShuffleReadPartition( queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(endpointName, receiver) - TaskContext.get().addTaskCompletionListener { ctx => + TaskContext.get().addTaskCompletionListener[Unit] { ctx => env.stop(endpoint) } (receiver, endpoint) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 6b386308c79fb..55d783e023246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -290,7 +290,7 @@ class SymmetricHashJoinStateManager( private val keyWithIndexToValue = new KeyWithIndexToValueStore() // Clean up any state store resources if necessary at the end of the task - Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ => abortIfNeeded() } } + Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } /** Helper trait for invoking common functionalities of a state store. */ private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 0b32327e51dbf..b6021438e902b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -61,7 +61,7 @@ package object state { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) val wrappedF = (store: StateStore, iter: Iterator[T]) => { // Abort the state store in case of error - TaskContext.get().addTaskCompletionListener(_ => { + TaskContext.get().addTaskCompletionListener[Unit](_ => { if (!store.hasCommitted) store.abort() }) cleanedF(store, iter) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 20090696ec3fc..de8085f07db19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -164,7 +164,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) + Option(TaskContext.get()) + .foreach(_.addTaskCompletionListener[Unit](_ => recordsIterator.close())) // Unwraps `OrcStruct`s to `UnsafeRow`s OrcFileFormat.unwrapOrcStructs(
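
For reviewers unfamiliar with the Scala 2.12 lambda encoding, here is a minimal, self-contained sketch of the writeReplace technique that the new getSerializedLambda/inspect helpers in ClosureCleaner rely on. The object and method names (LambdaInspect, serializedLambda) are illustrative only and not part of this patch.

import java.lang.invoke.SerializedLambda

object LambdaInspect {
  // Serializable Scala 2.12 lambdas are spun up by LambdaMetaFactory and carry a
  // synthetic writeReplace() that returns a SerializedLambda describing the lambda.
  // Reflectively invoking it exposes the implementation method and capturing class,
  // which is what ClosureCleaner uses to point ReturnStatementFinder at the right method.
  def serializedLambda(closure: AnyRef): Option[SerializedLambda] = {
    try {
      val writeReplace = closure.getClass.getDeclaredMethod("writeReplace")
      writeReplace.setAccessible(true)
      writeReplace.invoke(closure) match {
        case lambda: SerializedLambda => Some(lambda)
        case _ => None
      }
    } catch {
      // Pre-2.12 anonfun classes (and arbitrary objects) declare no writeReplace method.
      case _: NoSuchMethodException => None
    }
  }

  def main(args: Array[String]): Unit = {
    val f: Int => Int = x => x + 1
    serializedLambda(f).foreach { lambda =>
      // Prints the synthetic implementation method (e.g. "$anonfun$main$1")
      // and the class in which the lambda body actually lives.
      println(s"impl method: ${lambda.getImplMethodName}")
      println(s"capturing class: ${lambda.getCapturingClass.replace('/', '.')}")
    }
  }
}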