From 08f9aa5a1dde58cf45f75b67ee88a9b978c3486c Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 19 Dec 2016 11:27:26 -0800
Subject: [PATCH 1/3] Fix FileScanRDD interruption.

---
 .../sql/execution/datasources/FileScanRDD.scala | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index e753cd962a1a..dced53613638 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -21,7 +21,7 @@ import java.io.IOException
 
 import scala.collection.mutable
 
-import org.apache.spark.{Partition => RDDPartition, TaskContext}
+import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{InputFileBlockHolder, RDD}
 import org.apache.spark.sql.SparkSession
@@ -99,7 +99,15 @@ class FileScanRDD(
       private[this] var currentFile: PartitionedFile = null
       private[this] var currentIterator: Iterator[Object] = null
 
-      def hasNext: Boolean = (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      def hasNext: Boolean = {
+        // Kill the task in case it has been marked as killed. This logic is from
+        // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+        // to avoid performance overhead.
+        if (context.isInterrupted()) {
+          throw new TaskKilledException
+        }
+        (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      }
       def next(): Object = {
         val nextElement = currentIterator.next()
         // TODO: we should have a better separation of row based and batch based scan, so that we

From 2d43d5a501a162384d45eedf188ffed178e96351 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 19 Dec 2016 11:28:40 -0800
Subject: [PATCH 2/3] Fix for JDBCRDD interruption.

---
 .../spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index d5b11e7bec0b..2bdc43254133 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -301,6 +301,7 @@ private[jdbc] class JDBCRDD(
     rs = stmt.executeQuery()
     val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
 
-    CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
+    CompletionIterator[InternalRow, Iterator[InternalRow]](
+      new InterruptibleIterator(context, rowsIterator), close())
   }
 }

From 236efe580b52ed7ec6ce9e36f9b814147ebad99d Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 19 Dec 2016 13:22:07 -0800
Subject: [PATCH 3/3] Make UnsafeSorterIterator interruptible.

---
 .../collection/unsafe/sort/UnsafeInMemorySorter.java | 11 +++++++++++
 .../unsafe/sort/UnsafeSorterSpillReader.java | 11 +++++++++++
 2 files changed, 22 insertions(+)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 252a35ec6bdf..5b42843717e9 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -22,6 +22,8 @@
 
 import org.apache.avro.reflect.Nullable;
 
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -253,6 +255,7 @@ public final class SortedIterator extends UnsafeSorterIterator implements Clonea
     private long keyPrefix;
     private int recordLength;
     private long currentPageNumber;
+    private final TaskContext taskContext = TaskContext.get();
 
     private SortedIterator(int numRecords, int offset) {
       this.numRecords = numRecords;
@@ -283,6 +286,14 @@ public boolean hasNext() {
 
     @Override
     public void loadNext() {
+      // Kill the task in case it has been marked as killed. This logic is from
+      // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+      // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+      // `hasNext()` because it's technically possible for the caller to be relying on
+      // `getNumRecords()` instead of `hasNext()` to know when to stop.
+      if (taskContext != null && taskContext.isInterrupted()) {
+        throw new TaskKilledException();
+      }
       // This pointer points to a 4-byte record length, followed by the record's bytes
       final long recordPointer = array.get(offset + position);
       currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index a658e5eb47b7..b6323c624b7b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -23,6 +23,8 @@
 import com.google.common.io.Closeables;
 
 import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
@@ -51,6 +53,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
+  private final TaskContext taskContext = TaskContext.get();
 
   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
@@ -94,6 +97,14 @@ public boolean hasNext() {
 
   @Override
   public void loadNext() throws IOException {
+    // Kill the task in case it has been marked as killed. This logic is from
+    // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+    // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+    // `hasNext()` because it's technically possible for the caller to be relying on
+    // `getNumRecords()` instead of `hasNext()` to know when to stop.
+    if (taskContext != null && taskContext.isInterrupted()) {
+      throw new TaskKilledException();
+    }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
     if (recordLength > arr.length) {