Skip to content

Commit 7539ae5

Browse files
juliuszsompolski authored and
cloud-fan committed
[SPARK-23366] Improve hot reading path in ReadAheadInputStream
## What changes were proposed in this pull request? `ReadAheadInputStream` was introduced in #18317 to optimize reading spill files from disk. However, profiles show that the hot path of reading small amounts of data (like readInt) is inefficient - it involves taking locks and performing multiple checks. Optimize locking: a lock is not needed when simply accessing the active buffer. Only lock when needing to swap buffers, trigger async reading, or get information about the async state. Optimize the short path for single-byte reads, which are used e.g. by the Java library's DataInputStream.readInt. The async reader used to call "read" only once on the underlying stream, which never filled the underlying buffer when it was wrapping an LZ4BlockInputStream. If the buffer was returned unfilled, the async reader would be triggered to fill the read ahead buffer on each call, because the reader would see that the active buffer is below the refill threshold all the time. However, always filling the full buffer could increase latency, so also add an `AtomicBoolean` flag that lets the async reader return early if there is a reader waiting for data. Remove `readAheadThresholdInBytes` and instead immediately trigger an async read when switching the buffers. This simplifies the code paths, especially the hot one, which then only has to check whether there is data available in the active buffer, without worrying about whether it needs to retrigger an async read. It seems to have a positive effect on performance. ## How was this patch tested? The problem was noticed as a regression in some workloads after upgrading to Spark 2.3. It was particularly visible on TPCDS Q95 running on instances with fast disks (i3 AWS instances). Running with profiling: * Spark 2.2 - 5.2-5.3 minutes, 9.5% in LZ4BlockInputStream.read * Spark 2.3 - 6.4-6.6 minutes, 31.1% in ReadAheadInputStream.read * Spark 2.3 + fix - 5.3-5.4 minutes, 13.3% in ReadAheadInputStream.read - very slightly slower, practically within noise.
We didn't see other regressions, and many workloads in general seem to be faster with Spark 2.3 (we did not investigate whether that is thanks to the async reader or unrelated). Author: Juliusz Sompolski <[email protected]> Closes #20555 from juliuszsompolski/SPARK-23366.
1 parent f38c760 commit 7539ae5

File tree

5 files changed

+133
-117
lines changed

5 files changed

+133
-117
lines changed

core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java

Lines changed: 56 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.nio.ByteBuffer;
2828
import java.util.concurrent.ExecutorService;
2929
import java.util.concurrent.TimeUnit;
30+
import java.util.concurrent.atomic.AtomicBoolean;
3031
import java.util.concurrent.locks.Condition;
3132
import java.util.concurrent.locks.ReentrantLock;
3233

@@ -78,9 +79,8 @@ public class ReadAheadInputStream extends InputStream {
7879
// whether there is a read ahead task running,
7980
private boolean isReading;
8081

81-
// If the remaining data size in the current buffer is below this threshold,
82-
// we issue an async read from the underlying input stream.
83-
private final int readAheadThresholdInBytes;
82+
// whether there is a reader waiting for data.
83+
private AtomicBoolean isWaiting = new AtomicBoolean(false);
8484

8585
private final InputStream underlyingInputStream;
8686

@@ -97,20 +97,13 @@ public class ReadAheadInputStream extends InputStream {
9797
*
9898
* @param inputStream The underlying input stream.
9999
* @param bufferSizeInBytes The buffer size.
100-
* @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead
101-
* threshold, an async read is triggered.
102100
*/
103101
public ReadAheadInputStream(
104-
InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) {
102+
InputStream inputStream, int bufferSizeInBytes) {
105103
Preconditions.checkArgument(bufferSizeInBytes > 0,
106104
"bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes);
107-
Preconditions.checkArgument(readAheadThresholdInBytes > 0 &&
108-
readAheadThresholdInBytes < bufferSizeInBytes,
109-
"readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, " +
110-
"but the value is " + readAheadThresholdInBytes);
111105
activeBuffer = ByteBuffer.allocate(bufferSizeInBytes);
112106
readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes);
113-
this.readAheadThresholdInBytes = readAheadThresholdInBytes;
114107
this.underlyingInputStream = inputStream;
115108
activeBuffer.flip();
116109
readAheadBuffer.flip();
@@ -166,12 +159,17 @@ public void run() {
166159
// in that case the reader waits for this async read to complete.
167160
// So there is no race condition in both the situations.
168161
int read = 0;
162+
int off = 0, len = arr.length;
169163
Throwable exception = null;
170164
try {
171-
while (true) {
172-
read = underlyingInputStream.read(arr);
173-
if (0 != read) break;
174-
}
165+
// try to fill the read ahead buffer.
166+
// if a reader is waiting, possibly return early.
167+
do {
168+
read = underlyingInputStream.read(arr, off, len);
169+
if (read <= 0) break;
170+
off += read;
171+
len -= read;
172+
} while (len > 0 && !isWaiting.get());
175173
} catch (Throwable ex) {
176174
exception = ex;
177175
if (ex instanceof Error) {
@@ -181,13 +179,12 @@ public void run() {
181179
}
182180
} finally {
183181
stateChangeLock.lock();
182+
readAheadBuffer.limit(off);
184183
if (read < 0 || (exception instanceof EOFException)) {
185184
endOfStream = true;
186185
} else if (exception != null) {
187186
readAborted = true;
188187
readException = exception;
189-
} else {
190-
readAheadBuffer.limit(read);
191188
}
192189
readInProgress = false;
193190
signalAsyncReadComplete();
@@ -230,7 +227,10 @@ private void signalAsyncReadComplete() {
230227

231228
private void waitForAsyncReadComplete() throws IOException {
232229
stateChangeLock.lock();
230+
isWaiting.set(true);
233231
try {
232+
// There is only one reader, and one writer, so the writer should signal only once,
233+
// but a while loop checking the wake up condition is still needed to avoid spurious wakeups.
234234
while (readInProgress) {
235235
asyncReadComplete.await();
236236
}
@@ -239,15 +239,21 @@ private void waitForAsyncReadComplete() throws IOException {
239239
iio.initCause(e);
240240
throw iio;
241241
} finally {
242+
isWaiting.set(false);
242243
stateChangeLock.unlock();
243244
}
244245
checkReadException();
245246
}
246247

247248
@Override
248249
public int read() throws IOException {
249-
byte[] oneByteArray = oneByte.get();
250-
return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF;
250+
if (activeBuffer.hasRemaining()) {
251+
// short path - just get one byte.
252+
return activeBuffer.get() & 0xFF;
253+
} else {
254+
byte[] oneByteArray = oneByte.get();
255+
return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF;
256+
}
251257
}
252258

253259
@Override
@@ -258,54 +264,43 @@ public int read(byte[] b, int offset, int len) throws IOException {
258264
if (len == 0) {
259265
return 0;
260266
}
261-
stateChangeLock.lock();
262-
try {
263-
return readInternal(b, offset, len);
264-
} finally {
265-
stateChangeLock.unlock();
266-
}
267-
}
268267

269-
/**
270-
* flip the active and read ahead buffer
271-
*/
272-
private void swapBuffers() {
273-
ByteBuffer temp = activeBuffer;
274-
activeBuffer = readAheadBuffer;
275-
readAheadBuffer = temp;
276-
}
277-
278-
/**
279-
* Internal read function which should be called only from read() api. The assumption is that
280-
* the stateChangeLock is already acquired in the caller before calling this function.
281-
*/
282-
private int readInternal(byte[] b, int offset, int len) throws IOException {
283-
assert (stateChangeLock.isLocked());
284268
if (!activeBuffer.hasRemaining()) {
285-
waitForAsyncReadComplete();
286-
if (readAheadBuffer.hasRemaining()) {
287-
swapBuffers();
288-
} else {
289-
// The first read or activeBuffer is skipped.
290-
readAsync();
269+
// No remaining in active buffer - lock and switch to write ahead buffer.
270+
stateChangeLock.lock();
271+
try {
291272
waitForAsyncReadComplete();
292-
if (isEndOfStream()) {
293-
return -1;
273+
if (!readAheadBuffer.hasRemaining()) {
274+
// The first read.
275+
readAsync();
276+
waitForAsyncReadComplete();
277+
if (isEndOfStream()) {
278+
return -1;
279+
}
294280
}
281+
// Swap the newly read read ahead buffer in place of empty active buffer.
295282
swapBuffers();
283+
// After swapping buffers, trigger another async read for read ahead buffer.
284+
readAsync();
285+
} finally {
286+
stateChangeLock.unlock();
296287
}
297-
} else {
298-
checkReadException();
299288
}
300289
len = Math.min(len, activeBuffer.remaining());
301290
activeBuffer.get(b, offset, len);
302291

303-
if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) {
304-
readAsync();
305-
}
306292
return len;
307293
}
308294

295+
/**
296+
* flip the active and read ahead buffer
297+
*/
298+
private void swapBuffers() {
299+
ByteBuffer temp = activeBuffer;
300+
activeBuffer = readAheadBuffer;
301+
readAheadBuffer = temp;
302+
}
303+
309304
@Override
310305
public int available() throws IOException {
311306
stateChangeLock.lock();
@@ -323,6 +318,11 @@ public long skip(long n) throws IOException {
323318
if (n <= 0L) {
324319
return 0L;
325320
}
321+
if (n <= activeBuffer.remaining()) {
322+
// Only skipping from active buffer is sufficient
323+
activeBuffer.position((int) n + activeBuffer.position());
324+
return n;
325+
}
326326
stateChangeLock.lock();
327327
long skipped;
328328
try {
@@ -346,21 +346,14 @@ private long skipInternal(long n) throws IOException {
346346
if (available() >= n) {
347347
// we can skip from the internal buffers
348348
int toSkip = (int) n;
349-
if (toSkip <= activeBuffer.remaining()) {
350-
// Only skipping from active buffer is sufficient
351-
activeBuffer.position(toSkip + activeBuffer.position());
352-
if (activeBuffer.remaining() <= readAheadThresholdInBytes
353-
&& !readAheadBuffer.hasRemaining()) {
354-
readAsync();
355-
}
356-
return n;
357-
}
358349
// We need to skip from both active buffer and read ahead buffer
359350
toSkip -= activeBuffer.remaining();
351+
assert(toSkip > 0); // skipping from activeBuffer already handled.
360352
activeBuffer.position(0);
361353
activeBuffer.flip();
362354
readAheadBuffer.position(toSkip + readAheadBuffer.position());
363355
swapBuffers();
356+
// Trigger async read to emptied read ahead buffer.
364357
readAsync();
365358
return n;
366359
} else {

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,15 @@ public UnsafeSorterSpillReader(
7272
bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES;
7373
}
7474

75-
final double readAheadFraction =
76-
SparkEnv.get() == null ? 0.5 :
77-
SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5);
78-
79-
// SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf
80-
// regression for TPC-DS queries.
8175
final boolean readAheadEnabled = SparkEnv.get() != null &&
82-
SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false);
76+
SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true);
8377

8478
final InputStream bs =
8579
new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
8680
try {
8781
if (readAheadEnabled) {
8882
this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs),
89-
(int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction));
83+
(int) bufferSizeBytes);
9084
} else {
9185
this.in = serializerManager.wrapStream(blockId, bs);
9286
}

core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public abstract class GenericFileInputStreamSuite {
3737

3838
protected File inputFile;
3939

40-
protected InputStream inputStream;
40+
protected InputStream[] inputStreams;
4141

4242
@Before
4343
public void setUp() throws IOException {
@@ -54,77 +54,91 @@ public void tearDown() {
5454

5555
@Test
5656
public void testReadOneByte() throws IOException {
57-
for (int i = 0; i < randomBytes.length; i++) {
58-
assertEquals(randomBytes[i], (byte) inputStream.read());
57+
for (InputStream inputStream: inputStreams) {
58+
for (int i = 0; i < randomBytes.length; i++) {
59+
assertEquals(randomBytes[i], (byte) inputStream.read());
60+
}
5961
}
6062
}
6163

6264
@Test
6365
public void testReadMultipleBytes() throws IOException {
64-
byte[] readBytes = new byte[8 * 1024];
65-
int i = 0;
66-
while (i < randomBytes.length) {
67-
int read = inputStream.read(readBytes, 0, 8 * 1024);
68-
for (int j = 0; j < read; j++) {
69-
assertEquals(randomBytes[i], readBytes[j]);
70-
i++;
66+
for (InputStream inputStream: inputStreams) {
67+
byte[] readBytes = new byte[8 * 1024];
68+
int i = 0;
69+
while (i < randomBytes.length) {
70+
int read = inputStream.read(readBytes, 0, 8 * 1024);
71+
for (int j = 0; j < read; j++) {
72+
assertEquals(randomBytes[i], readBytes[j]);
73+
i++;
74+
}
7175
}
7276
}
7377
}
7478

7579
@Test
7680
public void testBytesSkipped() throws IOException {
77-
assertEquals(1024, inputStream.skip(1024));
78-
for (int i = 1024; i < randomBytes.length; i++) {
79-
assertEquals(randomBytes[i], (byte) inputStream.read());
81+
for (InputStream inputStream: inputStreams) {
82+
assertEquals(1024, inputStream.skip(1024));
83+
for (int i = 1024; i < randomBytes.length; i++) {
84+
assertEquals(randomBytes[i], (byte) inputStream.read());
85+
}
8086
}
8187
}
8288

8389
@Test
8490
public void testBytesSkippedAfterRead() throws IOException {
85-
for (int i = 0; i < 1024; i++) {
86-
assertEquals(randomBytes[i], (byte) inputStream.read());
87-
}
88-
assertEquals(1024, inputStream.skip(1024));
89-
for (int i = 2048; i < randomBytes.length; i++) {
90-
assertEquals(randomBytes[i], (byte) inputStream.read());
91+
for (InputStream inputStream: inputStreams) {
92+
for (int i = 0; i < 1024; i++) {
93+
assertEquals(randomBytes[i], (byte) inputStream.read());
94+
}
95+
assertEquals(1024, inputStream.skip(1024));
96+
for (int i = 2048; i < randomBytes.length; i++) {
97+
assertEquals(randomBytes[i], (byte) inputStream.read());
98+
}
9199
}
92100
}
93101

94102
@Test
95103
public void testNegativeBytesSkippedAfterRead() throws IOException {
96-
for (int i = 0; i < 1024; i++) {
97-
assertEquals(randomBytes[i], (byte) inputStream.read());
98-
}
99-
// Skipping negative bytes should essentially be a no-op
100-
assertEquals(0, inputStream.skip(-1));
101-
assertEquals(0, inputStream.skip(-1024));
102-
assertEquals(0, inputStream.skip(Long.MIN_VALUE));
103-
assertEquals(1024, inputStream.skip(1024));
104-
for (int i = 2048; i < randomBytes.length; i++) {
105-
assertEquals(randomBytes[i], (byte) inputStream.read());
104+
for (InputStream inputStream: inputStreams) {
105+
for (int i = 0; i < 1024; i++) {
106+
assertEquals(randomBytes[i], (byte) inputStream.read());
107+
}
108+
// Skipping negative bytes should essentially be a no-op
109+
assertEquals(0, inputStream.skip(-1));
110+
assertEquals(0, inputStream.skip(-1024));
111+
assertEquals(0, inputStream.skip(Long.MIN_VALUE));
112+
assertEquals(1024, inputStream.skip(1024));
113+
for (int i = 2048; i < randomBytes.length; i++) {
114+
assertEquals(randomBytes[i], (byte) inputStream.read());
115+
}
106116
}
107117
}
108118

109119
@Test
110120
public void testSkipFromFileChannel() throws IOException {
111-
// Since the buffer is smaller than the skipped bytes, this will guarantee
112-
// we skip from underlying file channel.
113-
assertEquals(1024, inputStream.skip(1024));
114-
for (int i = 1024; i < 2048; i++) {
115-
assertEquals(randomBytes[i], (byte) inputStream.read());
116-
}
117-
assertEquals(256, inputStream.skip(256));
118-
assertEquals(256, inputStream.skip(256));
119-
assertEquals(512, inputStream.skip(512));
120-
for (int i = 3072; i < randomBytes.length; i++) {
121-
assertEquals(randomBytes[i], (byte) inputStream.read());
121+
for (InputStream inputStream: inputStreams) {
122+
// Since the buffer is smaller than the skipped bytes, this will guarantee
123+
// we skip from underlying file channel.
124+
assertEquals(1024, inputStream.skip(1024));
125+
for (int i = 1024; i < 2048; i++) {
126+
assertEquals(randomBytes[i], (byte) inputStream.read());
127+
}
128+
assertEquals(256, inputStream.skip(256));
129+
assertEquals(256, inputStream.skip(256));
130+
assertEquals(512, inputStream.skip(512));
131+
for (int i = 3072; i < randomBytes.length; i++) {
132+
assertEquals(randomBytes[i], (byte) inputStream.read());
133+
}
122134
}
123135
}
124136

125137
@Test
126138
public void testBytesSkippedAfterEOF() throws IOException {
127-
assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
128-
assertEquals(-1, inputStream.read());
139+
for (InputStream inputStream: inputStreams) {
140+
assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
141+
assertEquals(-1, inputStream.read());
142+
}
129143
}
130144
}

0 commit comments

Comments
 (0)