diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index b51737158098..2bd756fc13f6 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -21,6 +21,8 @@
 
 import org.apache.avro.reflect.Nullable;
 
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -226,6 +228,7 @@ public final class SortedIterator extends UnsafeSorterIterator implements Clonea
     private long keyPrefix;
     private int recordLength;
     private long currentPageNumber;
+    private final TaskContext taskContext = TaskContext.get();
 
     private SortedIterator(int numRecords, int offset) {
       this.numRecords = numRecords;
@@ -256,6 +259,14 @@ public boolean hasNext() {
 
     @Override
     public void loadNext() {
+      // Kill the task in case it has been marked as killed. This logic is from
+      // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+      // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+      // `hasNext()` because it's technically possible for the caller to be relying on
+      // `getNumRecords()` instead of `hasNext()` to know when to stop.
+      if (taskContext != null && taskContext.isInterrupted()) {
+        throw new TaskKilledException();
+      }
       // This pointer points to a 4-byte record length, followed by the record's bytes
       final long recordPointer = array.get(offset + position);
       currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 1d588c37c5db..a3f04dea0d3c 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -22,6 +22,8 @@
 import com.google.common.io.ByteStreams;
 import com.google.common.io.Closeables;
 
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.unsafe.Platform;
@@ -44,6 +46,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
+  private final TaskContext taskContext = TaskContext.get();
 
   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
@@ -73,6 +76,14 @@ public boolean hasNext() {
 
   @Override
   public void loadNext() throws IOException {
+    // Kill the task in case it has been marked as killed. This logic is from
+    // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+    // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+    // `hasNext()` because it's technically possible for the caller to be relying on
+    // `getNumRecords()` instead of `hasNext()` to know when to stop.
+    if (taskContext != null && taskContext.isInterrupted()) {
+      throw new TaskKilledException();
+    }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
     if (recordLength > arr.length) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 1314c94d42cf..b4deaf1c05bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import scala.collection.mutable
 
-import org.apache.spark.{Partition => RDDPartition, TaskContext}
+import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{InputFileNameHolder, RDD}
 import org.apache.spark.sql.SparkSession
@@ -88,7 +88,15 @@ class FileScanRDD(
     private[this] var currentFile: PartitionedFile = null
     private[this] var currentIterator: Iterator[Object] = null
 
-    def hasNext = (currentIterator != null && currentIterator.hasNext) || nextIterator()
+    def hasNext: Boolean = {
+      // Kill the task in case it has been marked as killed. This logic is from
+      // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+      // to avoid performance overhead.
+      if (context.isInterrupted()) {
+        throw new TaskKilledException
+      }
+      (currentIterator != null && currentIterator.hasNext) || nextIterator()
+    }
     def next() = {
       val nextElement = currentIterator.next()
       // TODO: we should have a better separation of row based and batch based scan, so that we
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 95058cc84209..a0afe3443cf4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.{Partition, SparkContext, TaskContext, TaskKilledException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -541,6 +541,13 @@
   }
 
   override def hasNext: Boolean = {
+    // Kill the task in case it has been marked as killed. This logic is from
+    // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+    // to avoid performance overhead and to minimize modified code since it's not easy to
+    // wrap this Iterator without re-indenting tons of code.
+    if (context.isInterrupted()) {
+      throw new TaskKilledException
+    }
     if (!finished) {
       if (!gotNext) {
         nextValue = getNext()
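All four hunks above inline the same cancellation check rather than wrapping each iterator in InterruptibleIterator. The following is a minimal standalone sketch of that shared pattern, not code from the patch; KillableIterator is a hypothetical name used purely for illustration, and the null guard mirrors the UnsafeSorter changes (FileScanRDD and JDBCRDD already hold a non-null TaskContext, so they skip it).

import org.apache.spark.{TaskContext, TaskKilledException}

// Hypothetical iterator showing the kill-check pattern inlined at each call
// site in the patch above.
class KillableIterator[T](underlying: Iterator[T]) extends Iterator[T] {
  // TaskContext.get() returns null when no task is running (e.g. when the
  // iterator is exercised directly in a local unit test), hence the guard.
  private val taskContext = TaskContext.get()

  override def hasNext: Boolean = {
    // Fail fast once the task has been marked as killed, instead of letting
    // a long-running scan keep consuming CPU, memory, and I/O.
    if (taskContext != null && taskContext.isInterrupted()) {
      throw new TaskKilledException
    }
    underlying.hasNext
  }

  override def next(): T = underlying.next()
}

Inlining the check this way trades a small amount of duplicated code for avoiding an extra virtual call per record, which matters in these hot per-record loops.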