@@ -21,6 +21,8 @@

import org.apache.avro.reflect.Nullable;

import org.apache.spark.TaskContext;
import org.apache.spark.TaskKilledException;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
@@ -226,6 +228,7 @@ public final class SortedIterator extends UnsafeSorterIterator implements Clonea
private long keyPrefix;
private int recordLength;
private long currentPageNumber;
private final TaskContext taskContext = TaskContext.get();

private SortedIterator(int numRecords, int offset) {
this.numRecords = numRecords;
@@ -256,6 +259,14 @@ public boolean hasNext() {

@Override
public void loadNext() {
// Kill the task in case it has been marked as killed. This logic is from
// InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
// to avoid performance overhead. This check is added here in `loadNext()` instead of in
// `hasNext()` because it's technically possible for the caller to be relying on
// `getNumRecords()` instead of `hasNext()` to know when to stop.
if (taskContext != null && taskContext.isInterrupted()) {
throw new TaskKilledException();
}
Contributor:
Won't this be in the internal tight loop for reading data?
If yes, dereferencing a volatile for each tuple processed is worrying.

Contributor Author:
We already have this in tight loops in the form of InterruptibleIterator wrapping all over the place.
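For reference, the wrapping pattern referred to here looks roughly like the sketch below (a simplified stand-in, not the exact Spark source; the class name is made up): every hasNext() call reads the task's volatile interrupted flag before delegating to the wrapped iterator.

import org.apache.spark.{TaskContext, TaskKilledException}

// Simplified sketch of an InterruptibleIterator-style wrapper: the volatile
// interrupted flag is read on every hasNext() call before delegating.
class KillCheckWrapper[T](context: TaskContext, delegate: Iterator[T]) extends Iterator[T] {
  override def hasNext: Boolean = {
    if (context.isInterrupted()) {
      throw new TaskKilledException
    }
    delegate.hasNext
  }

  override def next(): T = delegate.next()
}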

Contributor Author:
In an admittedly non-scientific benchmark, I tried running

import org.apache.spark._

sc.parallelize(1 to (1000 * 1000 * 1000), 1).mapPartitions { iter =>
	val tc = TaskContext.get()
	iter.map { x =>
		tc.isInterrupted()
		x + 1
	}
}.count()

a few times with and without the tc.isInterrupted() check, and there wasn't a measurable time difference in my environment. While I imagine that the volatile read could incur higher costs in certain circumstances, I think the overhead of all of the virtual function calls, iterator interfaces, etc. will mask any gains from optimizing this read.

Contributor:
As you mentioned, this is an unscientific microbenchmark :-)
Also, the isInterrupted() call was probably optimized away anyway?

Since we don't need guarantees on how soon an interruption is honoured, batching its application (if possible) would be better in sorting?

Yes, we do have InterruptibleIterator - unfortunately, we don't have a way to optimize that (afaik).
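To make the batching suggestion concrete: one option would be to dereference the volatile flag only once every N records. A minimal sketch, with a made-up helper name and interval (not part of this patch):

import org.apache.spark.{TaskContext, TaskKilledException}

// Hypothetical helper for the batching idea: read the volatile interrupted flag
// only once every `checkInterval` records, trading how promptly a kill is
// observed for fewer volatile reads in the tight loop.
class BatchedKillCheck(taskContext: TaskContext, checkInterval: Int = 1024) {
  private var recordsSinceCheck = 0

  // Call once per record, e.g. from loadNext().
  def onRecord(): Unit = {
    recordsSinceCheck += 1
    if (recordsSinceCheck >= checkInterval) {
      recordsSinceCheck = 0
      if (taskContext != null && taskContext.isInterrupted()) {
        throw new TaskKilledException
      }
    }
  }
}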

// This pointer points to a 4-byte record length, followed by the record's bytes
final long recordPointer = array.get(offset + position);
currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);
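To illustrate the comment added in loadNext() above: a caller that drives the iterator via getNumRecords() never calls hasNext(), so a check placed only in hasNext() could be skipped entirely. A minimal sketch of such a consumer, assuming a stand-in process callback (not code from this patch):

import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator

// Sketch of a getNumRecords()-driven consumer; `process` is a stand-in callback.
// hasNext() is never called, so only a check inside loadNext() would fire here.
def consume(iter: UnsafeSorterIterator, process: (AnyRef, Long, Int) => Unit): Unit = {
  var i = 0
  while (i < iter.getNumRecords()) {
    iter.loadNext()  // the interruption check added in this patch runs here
    process(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
    i += 1
  }
}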
@@ -22,6 +22,8 @@
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;

import org.apache.spark.TaskContext;
import org.apache.spark.TaskKilledException;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
import org.apache.spark.unsafe.Platform;
@@ -44,6 +46,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
private byte[] arr = new byte[1024 * 1024];
private Object baseObject = arr;
private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
private final TaskContext taskContext = TaskContext.get();

public UnsafeSorterSpillReader(
SerializerManager serializerManager,
@@ -73,6 +76,14 @@ public boolean hasNext() {

@Override
public void loadNext() throws IOException {
// Kill the task in case it has been marked as killed. This logic is from
// InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
// to avoid performance overhead. This check is added here in `loadNext()` instead of in
// `hasNext()` because it's technically possible for the caller to be relying on
// `getNumRecords()` instead of `hasNext()` to know when to stop.
if (taskContext != null && taskContext.isInterrupted()) {
throw new TaskKilledException();
}
recordLength = din.readInt();
keyPrefix = din.readLong();
if (recordLength > arr.length) {
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources

import scala.collection.mutable

import org.apache.spark.{Partition => RDDPartition, TaskContext}
import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{InputFileNameHolder, RDD}
import org.apache.spark.sql.SparkSession
@@ -88,7 +88,15 @@ class FileScanRDD(
private[this] var currentFile: PartitionedFile = null
private[this] var currentIterator: Iterator[Object] = null

def hasNext = (currentIterator != null && currentIterator.hasNext) || nextIterator()
def hasNext: Boolean = {
// Kill the task in case it has been marked as killed. This logic is from
// InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
// to avoid performance overhead.
if (context.isInterrupted()) {
throw new TaskKilledException
}
Contributor:
Why not move this to next() instead of hasNext()? The latter is not mandatory to be called, as seen here: #16252

Contributor Author:
This particular iterator already won't work unless hasNext() is called, since otherwise nobody will call nextIterator().

Contributor:
Yes, it looks like this iterator is already broken, and we are adding to that now.

(currentIterator != null && currentIterator.hasNext) || nextIterator()
}
def next() = {
val nextElement = currentIterator.next()
// TODO: we should have a better separation of row based and batch based scan, so that we
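To make the alternative raised in the thread above concrete, a hypothetical variant (not what this patch does) would also check at the top of next(), using the same context and currentIterator fields shown in the diff, so a caller that skips hasNext() still observes the kill; the row/batch accounting in the real next() is elided:

// Hypothetical variant: check in next() as well, relying on the enclosing
// iterator's `context` and `currentIterator` fields from the diff above.
def next(): AnyRef = {
  if (context.isInterrupted()) {
    throw new TaskKilledException
  }
  currentIterator.next()
}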
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal

import org.apache.commons.lang3.StringUtils

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.{Partition, SparkContext, TaskContext, TaskKilledException}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -541,6 +541,13 @@ private[jdbc] class JDBCRDD(
}

override def hasNext: Boolean = {
// Kill the task in case it has been marked as killed. This logic is from
// InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
// to avoid performance overhead and to minimize modified code since it's not easy to
// wrap this Iterator without re-indenting tons of code.
if (context.isInterrupted()) {
throw new TaskKilledException
}
if (!finished) {
if (!gotNext) {
nextValue = getNext()