Skip to content
119 changes: 56 additions & 63 deletions core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.nio.ByteBuffer;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

Expand Down Expand Up @@ -78,9 +79,8 @@ public class ReadAheadInputStream extends InputStream {
// whether there is a read ahead task running,
private boolean isReading;

// If the remaining data size in the current buffer is below this threshold,
// we issue an async read from the underlying input stream.
private final int readAheadThresholdInBytes;
// whether there is a reader waiting for data.
private AtomicBoolean isWaiting = new AtomicBoolean(false);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just use volatile here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave it be - should compile to basically the same, and with using AtomicBoolean the intent seems more readable to me.


private final InputStream underlyingInputStream;

Expand All @@ -97,20 +97,13 @@ public class ReadAheadInputStream extends InputStream {
*
* @param inputStream The underlying input stream.
* @param bufferSizeInBytes The buffer size.
* @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead
* threshold, an async read is triggered.
*/
public ReadAheadInputStream(
InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) {
InputStream inputStream, int bufferSizeInBytes) {
Preconditions.checkArgument(bufferSizeInBytes > 0,
"bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes);
Preconditions.checkArgument(readAheadThresholdInBytes > 0 &&
readAheadThresholdInBytes < bufferSizeInBytes,
"readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, " +
"but the value is " + readAheadThresholdInBytes);
activeBuffer = ByteBuffer.allocate(bufferSizeInBytes);
readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes);
this.readAheadThresholdInBytes = readAheadThresholdInBytes;
this.underlyingInputStream = inputStream;
activeBuffer.flip();
readAheadBuffer.flip();
Expand Down Expand Up @@ -166,12 +159,17 @@ public void run() {
// in that case the reader waits for this async read to complete.
// So there is no race condition in both the situations.
int read = 0;
int off = 0, len = arr.length;
Throwable exception = null;
try {
while (true) {
read = underlyingInputStream.read(arr);
if (0 != read) break;
}
// try to fill the read ahead buffer.
// if a reader is waiting, possibly return early.
do {
read = underlyingInputStream.read(arr, off, len);
if (read <= 0) break;
off += read;
len -= read;
} while (len > 0 && !isWaiting.get());
} catch (Throwable ex) {
exception = ex;
if (ex instanceof Error) {
Expand All @@ -181,13 +179,12 @@ public void run() {
}
} finally {
stateChangeLock.lock();
readAheadBuffer.limit(off);
if (read < 0 || (exception instanceof EOFException)) {
endOfStream = true;
} else if (exception != null) {
readAborted = true;
readException = exception;
} else {
readAheadBuffer.limit(read);
}
readInProgress = false;
signalAsyncReadComplete();
Expand Down Expand Up @@ -230,7 +227,10 @@ private void signalAsyncReadComplete() {

private void waitForAsyncReadComplete() throws IOException {
stateChangeLock.lock();
isWaiting.set(true);
try {
// There is only one reader, and one writer, so the writer should signal only once,
// but a while loop checking the wake up condition is still needed to avoid spurious wakeups.
while (readInProgress) {
asyncReadComplete.await();
}
Expand All @@ -239,15 +239,21 @@ private void waitForAsyncReadComplete() throws IOException {
iio.initCause(e);
throw iio;
} finally {
isWaiting.set(false);
stateChangeLock.unlock();
}
checkReadException();
}

@Override
public int read() throws IOException {
byte[] oneByteArray = oneByte.get();
return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF;
if (activeBuffer.hasRemaining()) {
// short path - just get one byte.
return activeBuffer.get() & 0xFF;
} else {
byte[] oneByteArray = oneByte.get();
return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF;
}
}

@Override
Expand All @@ -258,54 +264,43 @@ public int read(byte[] b, int offset, int len) throws IOException {
if (len == 0) {
return 0;
}
stateChangeLock.lock();
try {
return readInternal(b, offset, len);
} finally {
stateChangeLock.unlock();
}
}

/**
* flip the active and read ahead buffer
*/
private void swapBuffers() {
ByteBuffer temp = activeBuffer;
activeBuffer = readAheadBuffer;
readAheadBuffer = temp;
}

/**
* Internal read function which should be called only from read() api. The assumption is that
* the stateChangeLock is already acquired in the caller before calling this function.
*/
private int readInternal(byte[] b, int offset, int len) throws IOException {
assert (stateChangeLock.isLocked());
if (!activeBuffer.hasRemaining()) {
waitForAsyncReadComplete();
if (readAheadBuffer.hasRemaining()) {
swapBuffers();
} else {
// The first read or activeBuffer is skipped.
readAsync();
// No remaining in active buffer - lock and switch to write ahead buffer.
stateChangeLock.lock();
try {
waitForAsyncReadComplete();
if (isEndOfStream()) {
return -1;
if (!readAheadBuffer.hasRemaining()) {
// The first read.
readAsync();
waitForAsyncReadComplete();
if (isEndOfStream()) {
return -1;
}
}
// Swap the newly read read ahead buffer in place of empty active buffer.
Copy link
Member

@kiszk kiszk Feb 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it good to use read-ahead instead of read ahead in comments for ease of reading?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other existing places in comments in the file use read ahead.

swapBuffers();
// After swapping buffers, trigger another async read for read ahead buffer.
readAsync();
} finally {
stateChangeLock.unlock();
}
} else {
checkReadException();
}
len = Math.min(len, activeBuffer.remaining());
activeBuffer.get(b, offset, len);

if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) {
readAsync();
}
return len;
}

/**
* flip the active and read ahead buffer
*/
private void swapBuffers() {
ByteBuffer temp = activeBuffer;
activeBuffer = readAheadBuffer;
readAheadBuffer = temp;
}

@Override
public int available() throws IOException {
stateChangeLock.lock();
Expand All @@ -323,6 +318,11 @@ public long skip(long n) throws IOException {
if (n <= 0L) {
return 0L;
}
if (n <= activeBuffer.remaining()) {
// Only skipping from active buffer is sufficient
activeBuffer.position((int) n + activeBuffer.position());
return n;
}
stateChangeLock.lock();
long skipped;
try {
Expand All @@ -346,21 +346,14 @@ private long skipInternal(long n) throws IOException {
if (available() >= n) {
// we can skip from the internal buffers
int toSkip = (int) n;
if (toSkip <= activeBuffer.remaining()) {
// Only skipping from active buffer is sufficient
activeBuffer.position(toSkip + activeBuffer.position());
if (activeBuffer.remaining() <= readAheadThresholdInBytes
&& !readAheadBuffer.hasRemaining()) {
readAsync();
}
return n;
}
// We need to skip from both active buffer and read ahead buffer
toSkip -= activeBuffer.remaining();
assert(toSkip > 0); // skipping from activeBuffer already handled.
activeBuffer.position(0);
activeBuffer.flip();
readAheadBuffer.position(toSkip + readAheadBuffer.position());
swapBuffers();
// Trigger async read to emptied read ahead buffer.
readAsync();
return n;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,15 @@ public UnsafeSorterSpillReader(
bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES;
}

final double readAheadFraction =
SparkEnv.get() == null ? 0.5 :
SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5);

// SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf
// regression for TPC-DS queries.
final boolean readAheadEnabled = SparkEnv.get() != null &&
SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false);
SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true);

final InputStream bs =
new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
try {
if (readAheadEnabled) {
this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs),
(int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction));
(int) bufferSizeBytes);
} else {
this.in = serializerManager.wrapStream(blockId, bs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public abstract class GenericFileInputStreamSuite {

protected File inputFile;

protected InputStream inputStream;
protected InputStream[] inputStreams;

@Before
public void setUp() throws IOException {
Expand All @@ -54,77 +54,91 @@ public void tearDown() {

@Test
public void testReadOneByte() throws IOException {
for (int i = 0; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
for (InputStream inputStream: inputStreams) {
for (int i = 0; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}
}

@Test
public void testReadMultipleBytes() throws IOException {
byte[] readBytes = new byte[8 * 1024];
int i = 0;
while (i < randomBytes.length) {
int read = inputStream.read(readBytes, 0, 8 * 1024);
for (int j = 0; j < read; j++) {
assertEquals(randomBytes[i], readBytes[j]);
i++;
for (InputStream inputStream: inputStreams) {
byte[] readBytes = new byte[8 * 1024];
int i = 0;
while (i < randomBytes.length) {
int read = inputStream.read(readBytes, 0, 8 * 1024);
for (int j = 0; j < read; j++) {
assertEquals(randomBytes[i], readBytes[j]);
i++;
}
}
}
}

@Test
public void testBytesSkipped() throws IOException {
assertEquals(1024, inputStream.skip(1024));
for (int i = 1024; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
for (InputStream inputStream: inputStreams) {
assertEquals(1024, inputStream.skip(1024));
for (int i = 1024; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}
}

@Test
public void testBytesSkippedAfterRead() throws IOException {
for (int i = 0; i < 1024; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
assertEquals(1024, inputStream.skip(1024));
for (int i = 2048; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
for (InputStream inputStream: inputStreams) {
for (int i = 0; i < 1024; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
assertEquals(1024, inputStream.skip(1024));
for (int i = 2048; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}
}

@Test
public void testNegativeBytesSkippedAfterRead() throws IOException {
for (int i = 0; i < 1024; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
// Skipping negative bytes should essential be a no-op
assertEquals(0, inputStream.skip(-1));
assertEquals(0, inputStream.skip(-1024));
assertEquals(0, inputStream.skip(Long.MIN_VALUE));
assertEquals(1024, inputStream.skip(1024));
for (int i = 2048; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
for (InputStream inputStream: inputStreams) {
for (int i = 0; i < 1024; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
// Skipping negative bytes should essential be a no-op
assertEquals(0, inputStream.skip(-1));
assertEquals(0, inputStream.skip(-1024));
assertEquals(0, inputStream.skip(Long.MIN_VALUE));
assertEquals(1024, inputStream.skip(1024));
for (int i = 2048; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}
}

@Test
public void testSkipFromFileChannel() throws IOException {
// Since the buffer is smaller than the skipped bytes, this will guarantee
// we skip from underlying file channel.
assertEquals(1024, inputStream.skip(1024));
for (int i = 1024; i < 2048; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
assertEquals(256, inputStream.skip(256));
assertEquals(256, inputStream.skip(256));
assertEquals(512, inputStream.skip(512));
for (int i = 3072; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
for (InputStream inputStream: inputStreams) {
// Since the buffer is smaller than the skipped bytes, this will guarantee
// we skip from underlying file channel.
assertEquals(1024, inputStream.skip(1024));
for (int i = 1024; i < 2048; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
assertEquals(256, inputStream.skip(256));
assertEquals(256, inputStream.skip(256));
assertEquals(512, inputStream.skip(512));
for (int i = 3072; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}
}

@Test
public void testBytesSkippedAfterEOF() throws IOException {
assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
assertEquals(-1, inputStream.read());
for (InputStream inputStream: inputStreams) {
assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
assertEquals(-1, inputStream.read());
}
}
}
Loading