[SPARK-18928] Check TaskContext.isInterrupted() in FileScanRDD, JDBCRDD & UnsafeSorter #16340
**UnsafeSorterSpillReader.java**

```diff
@@ -23,6 +23,8 @@
 import com.google.common.io.Closeables;
 
 import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
@@ -51,6 +53,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
+  private final TaskContext taskContext = TaskContext.get();
 
   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
```

**Contributor (Author)**, on `private final TaskContext taskContext = TaskContext.get();`: I didn't want to change the constructor, hence this pattern.
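As a hedged illustration of that design choice (the class name below is a hypothetical stand-in, not the actual `UnsafeSorterSpillReader`): `TaskContext.get()` reads a thread-local that the executor sets on the thread running a task, so a field initializer can capture the context without adding a constructor parameter; on the driver, or on any thread not running a task, it returns null, which is why the null guard is needed.

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

// Hypothetical reader sketching the capture-at-construction pattern from the diff above.
class SpillChunkReader {
  // Captured once, on the thread that constructs the reader. Inside a running
  // task this is that task's TaskContext; on the driver it is null.
  private val taskContext: TaskContext = TaskContext.get()

  def loadNext(): Unit = {
    // Same guard as in the diff: skip the check when no task context exists.
    if (taskContext != null && taskContext.isInterrupted()) {
      throw new TaskKilledException
    }
    // ... read the next record from the spill file ...
  }
}
```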
```diff
@@ -94,6 +97,14 @@ public boolean hasNext() {
 
   @Override
   public void loadNext() throws IOException {
+    // Kill the task in case it has been marked as killed. This logic is from
+    // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+    // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+    // `hasNext()` because it's technically possible for the caller to be relying on
+    // `getNumRecords()` instead of `hasNext()` to know when to stop.
+    if (taskContext != null && taskContext.isInterrupted()) {
+      throw new TaskKilledException();
+    }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
     if (recordLength > arr.length) {
```

**Contributor (Author)**, on the `taskContext != null` check: TaskContext can be null in case this is used on the driver outside of the context of a specific task.
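For context, the wrapper whose logic is being inlined above looks roughly like this (a paraphrased sketch of `org.apache.spark.InterruptibleIterator` from around this era of Spark; the real source may differ in details). Inlining the same check into `loadNext()` avoids allocating the wrapper and the extra virtual calls per record on the hot path.

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

// Sketch: InterruptibleIterator wraps a delegate iterator and checks the
// task's kill flag on every hasNext() call.
class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
  extends Iterator[T] {

  def hasNext: Boolean = {
    // One kill-flag check per record; throwing TaskKilledException is how a
    // running task cooperatively terminates after being marked as killed.
    if (context.isInterrupted()) {
      throw new TaskKilledException
    } else {
      delegate.hasNext
    }
  }

  def next(): T = delegate.next()
}
```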
**JDBCRDD.scala**

```diff
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -301,6 +301,7 @@ private[jdbc] class JDBCRDD(
     rs = stmt.executeQuery()
     val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
 
-    CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
+    CompletionIterator[InternalRow, Iterator[InternalRow]](
+      new InterruptibleIterator(context, rowsIterator), close())
   }
 }
```

**Contributor (Author):** I suppose I could also have added the check into …
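To see how the two wrappers compose, here is a self-contained sketch with simplified stand-ins for Spark's classes (everything below is modeled for illustration; nothing is imported from Spark). The interruption check runs per row via `hasNext()`, while the completion callback still fires once the rows are exhausted, mirroring how `close()` is attached above.

```scala
import java.util.concurrent.atomic.AtomicBoolean

// Simplified stand-ins for Spark's TaskContext and TaskKilledException.
class FakeTaskContext { val interrupted = new AtomicBoolean(false) }
class FakeTaskKilledException extends RuntimeException("task killed")

// Per-row kill check, like InterruptibleIterator.
class Interruptible[T](ctx: FakeTaskContext, delegate: Iterator[T]) extends Iterator[T] {
  def hasNext: Boolean = {
    if (ctx.interrupted.get()) throw new FakeTaskKilledException
    delegate.hasNext
  }
  def next(): T = delegate.next()
}

// Runs a cleanup function once the delegate is exhausted, like CompletionIterator.
class Completion[T](delegate: Iterator[T], onDone: () => Unit) extends Iterator[T] {
  private var done = false
  def hasNext: Boolean = {
    val more = delegate.hasNext
    if (!more && !done) { done = true; onDone() } // e.g. close the JDBC statement
    more
  }
  def next(): T = delegate.next()
}

object KillDemo {
  def main(args: Array[String]): Unit = {
    val ctx = new FakeTaskContext
    val rows = new Completion(
      new Interruptible(ctx, Iterator.range(0, 1000000)),
      () => println("close() called"))

    println(rows.next())      // 0: rows flow through both wrappers
    ctx.interrupted.set(true) // simulate the task being marked as killed
    try rows.hasNext
    catch { case _: FakeTaskKilledException => println("kill observed on next check") }
  }
}
```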