Skip to content

Commit d42000f

Browse files
author
Sital Kedia
committed
[SPARK-21113][CORE] Read ahead input stream to amortize disk IO cost in the Spill reader
1 parent 74a432d commit d42000f

File tree

5 files changed

+338
-13
lines changed

5 files changed

+338
-13
lines changed
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
package org.apache.spark.io;
2+
3+
import com.google.common.base.Preconditions;
4+
5+
import javax.annotation.concurrent.GuardedBy;
6+
import java.io.IOException;
7+
import java.io.InputStream;
8+
import java.nio.ByteBuffer;
9+
import java.util.concurrent.ExecutorService;
10+
import java.util.concurrent.Executors;
11+
import java.util.concurrent.locks.Condition;
12+
import java.util.concurrent.locks.Lock;
13+
import java.util.concurrent.locks.ReentrantLock;
14+
15+
/**
16+
* {@link InputStream} implementation which asynchronously reads ahead from the underlying input
17+
* stream when specified amount of data has been read from the current buffer. It does it by maintaining
18+
* two buffer - active buffer and read ahead buffer. Active buffer contains data which should be returned
19+
* when a read() call is issued. The read ahead buffer is used to asynchronously read from the underlying
20+
* input stream and once the current active buffer is exhausted, we flip the two buffers so that we can
21+
* start reading from the read ahead buffer without being blocked in disk I/O.
22+
*/
23+
public class ReadAheadInputStream extends InputStream {
24+
25+
private Lock stateChangeLock = new ReentrantLock();
26+
27+
@GuardedBy("stateChangeLock")
28+
private ByteBuffer activeBuffer;
29+
30+
@GuardedBy("stateChangeLock")
31+
private ByteBuffer readAheadBuffer;
32+
33+
@GuardedBy("stateChangeLock")
34+
private boolean endOfStream;
35+
36+
@GuardedBy("stateChangeLock")
37+
// true if async read is in progress
38+
private boolean isReadInProgress;
39+
40+
@GuardedBy("stateChangeLock")
41+
// true if read is aborted due to an exception in reading from underlying input stream.
42+
private boolean isReadAborted;
43+
44+
@GuardedBy("stateChangeLock")
45+
private Exception readException;
46+
47+
// If the remaining data size in the current buffer is below this threshold,
48+
// we issue an async read from the underlying input stream.
49+
private final int readAheadThresholdInBytes;
50+
51+
private final InputStream underlyingInputStream;
52+
53+
private final ExecutorService executorService = Executors.newSingleThreadExecutor();
54+
55+
private final Condition asyncReadComplete = stateChangeLock.newCondition();
56+
57+
private final byte[] oneByte = new byte[1];
58+
59+
public ReadAheadInputStream(InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) {
60+
Preconditions.checkArgument(bufferSizeInBytes > 0, "bufferSizeInBytes should be greater than 0");
61+
Preconditions.checkArgument(readAheadThresholdInBytes > 0 && readAheadThresholdInBytes < bufferSizeInBytes,
62+
"readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes" );
63+
activeBuffer = ByteBuffer.allocate(bufferSizeInBytes);
64+
readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes);
65+
this.readAheadThresholdInBytes = readAheadThresholdInBytes;
66+
this.underlyingInputStream = inputStream;
67+
activeBuffer.flip();
68+
readAheadBuffer.flip();
69+
}
70+
71+
private boolean isEndOfStream() {
72+
if(activeBuffer.remaining() == 0 && readAheadBuffer.remaining() == 0 && endOfStream) {
73+
return true;
74+
}
75+
return false;
76+
}
77+
78+
79+
private void readAsync(final ByteBuffer byteBuffer) throws IOException {
80+
stateChangeLock.lock();
81+
if (endOfStream || isReadInProgress) {
82+
stateChangeLock.unlock();
83+
return;
84+
}
85+
byteBuffer.position(0);
86+
byteBuffer.flip();
87+
isReadInProgress = true;
88+
stateChangeLock.unlock();
89+
executorService.execute(() -> {
90+
byte[] arr;
91+
stateChangeLock.lock();
92+
arr = byteBuffer.array();
93+
stateChangeLock.unlock();
94+
// Please note that it is safe to release the lock and read into the read ahead buffer
95+
// because either of following two conditions will hold - 1. The active buffer has
96+
// data available to read so the reader will not read from the read ahead buffer.
97+
// 2. The active buffer is exhausted, in that case the reader waits for this async
98+
// read to complete. So there is no race condition in both the situations.
99+
int nRead = 0;
100+
while (nRead == 0) {
101+
try {
102+
nRead = underlyingInputStream.read(arr);
103+
if (nRead < 0) {
104+
// We hit end of the underlying input stream
105+
break;
106+
}
107+
} catch (Exception e) {
108+
stateChangeLock.lock();
109+
// We hit a read exception, which should be propagated to the reader
110+
// in the next read() call.
111+
isReadAborted = true;
112+
readException = e;
113+
stateChangeLock.unlock();
114+
}
115+
}
116+
stateChangeLock.lock();
117+
if (nRead < 0) {
118+
endOfStream = true;
119+
}
120+
else {
121+
// fill the byte buffer
122+
byteBuffer.limit(nRead);
123+
}
124+
isReadInProgress = false;
125+
signalAsyncReadComplete();
126+
stateChangeLock.unlock();
127+
});
128+
}
129+
130+
private void signalAsyncReadComplete() {
131+
stateChangeLock.lock();
132+
try {
133+
asyncReadComplete.signalAll();
134+
} finally {
135+
stateChangeLock.unlock();
136+
}
137+
}
138+
139+
private void waitForAsyncReadComplete() {
140+
stateChangeLock.lock();
141+
try {
142+
asyncReadComplete.await();
143+
} catch (InterruptedException e) {
144+
} finally {
145+
stateChangeLock.unlock();
146+
}
147+
}
148+
149+
@Override
150+
public synchronized int read() throws IOException {
151+
int val = read(oneByte, 0, 1);
152+
if (val == -1) {
153+
return -1;
154+
}
155+
return oneByte[0] & 0xFF;
156+
}
157+
158+
@Override
159+
public synchronized int read(byte[] b, int offset, int len) throws IOException {
160+
stateChangeLock.lock();
161+
try {
162+
len = readInternal(b, offset, len);
163+
}
164+
finally {
165+
stateChangeLock.unlock();
166+
}
167+
return len;
168+
}
169+
170+
/**
171+
* Internal read function which should be called only from read() api. The assumption is that
172+
* the stateChangeLock is already acquired in the caller before calling this function.
173+
*/
174+
private int readInternal(byte[] b, int offset, int len) throws IOException {
175+
176+
if (offset < 0 || len < 0 || offset + len < 0 || offset + len > b.length) {
177+
throw new IndexOutOfBoundsException();
178+
}
179+
if (!activeBuffer.hasRemaining() && !isReadInProgress) {
180+
// This condition will only be triggered for the first time read is called.
181+
readAsync(activeBuffer);
182+
waitForAsyncReadComplete();
183+
}
184+
if (!activeBuffer.hasRemaining() && isReadInProgress) {
185+
waitForAsyncReadComplete();
186+
}
187+
188+
if (isReadAborted) {
189+
throw new IOException(readException);
190+
}
191+
if (isEndOfStream()) {
192+
return -1;
193+
}
194+
len = Math.min(len, activeBuffer.remaining());
195+
activeBuffer.get(b, offset, len);
196+
197+
if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) {
198+
readAsync(readAheadBuffer);
199+
}
200+
if (!activeBuffer.hasRemaining()) {
201+
ByteBuffer temp = activeBuffer;
202+
activeBuffer = readAheadBuffer;
203+
readAheadBuffer = temp;
204+
}
205+
return len;
206+
}
207+
208+
@Override
209+
public synchronized int available() throws IOException {
210+
stateChangeLock.lock();
211+
int val = activeBuffer.remaining() + readAheadBuffer.remaining();
212+
stateChangeLock.unlock();
213+
return val;
214+
}
215+
216+
@Override
217+
public synchronized long skip(long n) throws IOException {
218+
stateChangeLock.lock();
219+
long skipped;
220+
try {
221+
skipped = skipInternal(n);
222+
} finally {
223+
stateChangeLock.unlock();
224+
}
225+
return skipped;
226+
}
227+
228+
/**
229+
* Internal skip function which should be called only from skip() api. The assumption is that
230+
* the stateChangeLock is already acquired in the caller before calling this function.
231+
*/
232+
private long skipInternal(long n) throws IOException {
233+
if (n <= 0L) {
234+
return 0L;
235+
}
236+
if (isReadInProgress) {
237+
waitForAsyncReadComplete();
238+
}
239+
if (available() >= n) {
240+
// we can skip from the internal buffers
241+
int toSkip = (int)n;
242+
byte[] temp = new byte[toSkip];
243+
while (toSkip > 0) {
244+
int skippedBytes = read(temp, 0, toSkip);
245+
toSkip -= skippedBytes;
246+
}
247+
return n;
248+
}
249+
int skippedBytes = available();
250+
long toSkip = n - skippedBytes;
251+
activeBuffer.position(0);
252+
activeBuffer.flip();
253+
readAheadBuffer.position(0);
254+
readAheadBuffer.flip();
255+
long skippedFromInputStream = underlyingInputStream.skip(toSkip);
256+
return skippedBytes + skippedFromInputStream;
257+
}
258+
}

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,24 @@
1717

1818
package org.apache.spark.util.collection.unsafe.sort;
1919

20-
import java.io.*;
21-
2220
import com.google.common.io.ByteStreams;
2321
import com.google.common.io.Closeables;
24-
2522
import org.apache.spark.SparkEnv;
2623
import org.apache.spark.TaskContext;
2724
import org.apache.spark.io.NioBufferedFileInputStream;
25+
import org.apache.spark.io.ReadAheadInputStream;
2826
import org.apache.spark.serializer.SerializerManager;
2927
import org.apache.spark.storage.BlockId;
3028
import org.apache.spark.unsafe.Platform;
3129
import org.slf4j.Logger;
3230
import org.slf4j.LoggerFactory;
3331

32+
import java.io.Closeable;
33+
import java.io.DataInputStream;
34+
import java.io.File;
35+
import java.io.IOException;
36+
import java.io.InputStream;
37+
3438
/**
3539
* Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
3640
* of the file format).
@@ -73,7 +77,9 @@ public UnsafeSorterSpillReader(
7377
}
7478

7579
final InputStream bs =
76-
new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
80+
new ReadAheadInputStream(
81+
new NioBufferedFileInputStream(file, (int) bufferSizeBytes),
82+
(int)bufferSizeBytes, (int)bufferSizeBytes / 2);
7783
try {
7884
this.in = serializerManager.wrapStream(blockId, bs);
7985
this.din = new DataInputStream(this.in);

core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java renamed to core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@
3131
/**
3232
* Tests functionality of {@link NioBufferedFileInputStream}
3333
*/
34-
public class NioBufferedFileInputStreamSuite {
34+
public class GenericFileInputStreamSuite {
3535

3636
private byte[] randomBytes;
3737

38-
private File inputFile;
38+
protected File inputFile;
39+
40+
protected InputStream inputStream;
3941

4042
@Before
4143
public void setUp() throws IOException {
@@ -52,15 +54,13 @@ public void tearDown() {
5254

5355
@Test
5456
public void testReadOneByte() throws IOException {
55-
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
5657
for (int i = 0; i < randomBytes.length; i++) {
5758
assertEquals(randomBytes[i], (byte) inputStream.read());
5859
}
5960
}
6061

6162
@Test
6263
public void testReadMultipleBytes() throws IOException {
63-
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
6464
byte[] readBytes = new byte[8 * 1024];
6565
int i = 0;
6666
while (i < randomBytes.length) {
@@ -74,7 +74,6 @@ public void testReadMultipleBytes() throws IOException {
7474

7575
@Test
7676
public void testBytesSkipped() throws IOException {
77-
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
7877
assertEquals(1024, inputStream.skip(1024));
7978
for (int i = 1024; i < randomBytes.length; i++) {
8079
assertEquals(randomBytes[i], (byte) inputStream.read());
@@ -83,7 +82,6 @@ public void testBytesSkipped() throws IOException {
8382

8483
@Test
8584
public void testBytesSkippedAfterRead() throws IOException {
86-
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
8785
for (int i = 0; i < 1024; i++) {
8886
assertEquals(randomBytes[i], (byte) inputStream.read());
8987
}
@@ -95,7 +93,6 @@ public void testBytesSkippedAfterRead() throws IOException {
9593

9694
@Test
9795
public void testNegativeBytesSkippedAfterRead() throws IOException {
98-
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
9996
for (int i = 0; i < 1024; i++) {
10097
assertEquals(randomBytes[i], (byte) inputStream.read());
10198
}
@@ -111,7 +108,6 @@ public void testNegativeBytesSkippedAfterRead() throws IOException {
111108

112109
@Test
113110
public void testSkipFromFileChannel() throws IOException {
114-
InputStream inputStream = new NioBufferedFileInputStream(inputFile, 10);
115111
// Since the buffer is smaller than the skipped bytes, this will guarantee
116112
// we skip from underlying file channel.
117113
assertEquals(1024, inputStream.skip(1024));
@@ -128,7 +124,6 @@ public void testSkipFromFileChannel() throws IOException {
128124

129125
@Test
130126
public void testBytesSkippedAfterEOF() throws IOException {
131-
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
132127
assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
133128
assertEquals(-1, inputStream.read());
134129
}

0 commit comments

Comments
 (0)