diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index e2f48e5508af..d240d0c2c8cf 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -37,7 +37,7 @@
  */
 public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
   private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class);
-  private static final int DEFAULT_BUFFER_SIZE_BYTES = 1024 * 1024; // 1 MB
+  private static final int MIN_BUFFER_SIZE_BYTES = 1024 * 1024; // 1 MB
   private static final int MAX_BUFFER_SIZE_BYTES = 16777216; // 16 mb
 
   private InputStream in;
@@ -49,51 +49,73 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
   private int numRecords;
   private int numRecordsRemaining;
 
-  private byte[] arr = new byte[1024 * 1024];
+  private byte[] arr = new byte[0];
   private Object baseObject = arr;
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
   private final TaskContext taskContext = TaskContext.get();
+  private final SerializerManager serializerManager;
+  private final File file;
+  private final BlockId blockId;
 
   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
       File file,
       BlockId blockId) throws IOException {
     assert (file.length() > 0);
-    long bufferSizeBytes =
-        SparkEnv.get() == null ?
-            DEFAULT_BUFFER_SIZE_BYTES:
-            SparkEnv.get().conf().getSizeAsBytes("spark.unsafe.sorter.spill.reader.buffer.size",
-                DEFAULT_BUFFER_SIZE_BYTES);
-    if (bufferSizeBytes > MAX_BUFFER_SIZE_BYTES || bufferSizeBytes < DEFAULT_BUFFER_SIZE_BYTES) {
-      // fall back to a sane default value
-      logger.warn("Value of config \"spark.unsafe.sorter.spill.reader.buffer.size\" = {} not in " +
-          "allowed range [{}, {}). Falling back to default value : {} bytes", bufferSizeBytes,
-          DEFAULT_BUFFER_SIZE_BYTES, MAX_BUFFER_SIZE_BYTES, DEFAULT_BUFFER_SIZE_BYTES);
-      bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES;
+    this.serializerManager = serializerManager;
+    this.file = file;
+    this.blockId = blockId;
+    try (DataInputStream dis = new DataInputStream(new FileInputStream(this.file))) {
+      numRecords = numRecordsRemaining = dis.readInt();
     }
+  }
+
+  private InputStream getIn() throws IOException {
+    if (null == this.in) {
+      long bufferSizeBytes =
+          SparkEnv.get() == null ?
+              MIN_BUFFER_SIZE_BYTES:
+              SparkEnv.get().conf().getSizeAsBytes("spark.unsafe.sorter.spill.reader.buffer.size",
+                  MIN_BUFFER_SIZE_BYTES);
+      if (bufferSizeBytes > MAX_BUFFER_SIZE_BYTES || bufferSizeBytes < MIN_BUFFER_SIZE_BYTES) {
+        // fall back to a sane default value
+        logger.warn("Value of config \"spark.unsafe.sorter.spill.reader.buffer.size\" = {} not in " +
+            "allowed range [{}, {}). Falling back to default value : {} bytes", bufferSizeBytes,
+            MIN_BUFFER_SIZE_BYTES, MAX_BUFFER_SIZE_BYTES, MIN_BUFFER_SIZE_BYTES);
+        bufferSizeBytes = MIN_BUFFER_SIZE_BYTES;
+      }
 
-    final double readAheadFraction =
-        SparkEnv.get() == null ? 0.5 :
-            SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5);
-
-    final boolean readAheadEnabled = SparkEnv.get() != null &&
-        SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true);
-
-    final InputStream bs =
-        new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
-    try {
-      if (readAheadEnabled) {
-        this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs),
-            (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction));
-      } else {
-        this.in = serializerManager.wrapStream(blockId, bs);
+      final double readAheadFraction =
+          SparkEnv.get() == null ? 0.5 :
+              SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5);
+
+      final boolean readAheadEnabled = SparkEnv.get() != null &&
+          SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true);
+
+      final InputStream bs =
+          new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
+      try {
+        if (readAheadEnabled) {
+          this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs),
+              (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction));
+        } else {
+          this.in = serializerManager.wrapStream(blockId, bs);
+        }
+        int numRecords = new DataInputStream(this.in).readInt();
+        assert numRecords == this.numRecords;
+      } catch (IOException e) {
+        Closeables.close(bs, /* swallowIOException = */ true);
+        throw e;
      }
-      this.din = new DataInputStream(this.in);
-      numRecords = numRecordsRemaining = din.readInt();
-    } catch (IOException e) {
-      Closeables.close(bs, /* swallowIOException = */ true);
-      throw e;
     }
+    return this.in;
+  }
+
+  private DataInputStream getDin() throws IOException {
+    if (null == this.din) {
+      this.din = new DataInputStream(this.getIn());
+    }
+    return this.din;
   }
 
   @Override
@@ -116,13 +138,18 @@ public void loadNext() throws IOException {
     if (taskContext != null) {
       taskContext.killTaskIfInterrupted();
     }
-    recordLength = din.readInt();
-    keyPrefix = din.readLong();
-    if (recordLength > arr.length) {
-      arr = new byte[recordLength];
+    // Check whether the reader has been closed, to avoid reopening in and din.
+    if (!hasNext()) {
+      throw new IndexOutOfBoundsException("Cannot load next item when UnsafeSorterSpillReader is closed.");
+    }
+    recordLength = getDin().readInt();
+    keyPrefix = getDin().readLong();
+    int arrLength = Math.max(1024 * 1024, recordLength);
+    if (arrLength > arr.length) {
+      arr = new byte[arrLength];
       baseObject = arr;
     }
-    ByteStreams.readFully(in, arr, 0, recordLength);
+    ByteStreams.readFully(getIn(), arr, 0, recordLength);
     numRecordsRemaining--;
     if (numRecordsRemaining == 0) {
       close();
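
The net effect of the patch: the constructor now reads only the spill file's 4-byte record count, while the buffered (and optionally read-ahead) input stream is opened lazily on first access through getIn()/getDin(), and the record buffer arr starts empty and grows on demand in loadNext(). Below is a minimal, self-contained sketch of that lazy-initialization idiom; LazySpillReader and its file layout are hypothetical stand-ins rather than Spark classes, and the file is assumed to begin with a 4-byte record count, mirroring what the patched constructor reads eagerly.

// A sketch of the lazy-initialization idiom used by the patch above.
import java.io.BufferedInputStream;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;

final class LazySpillReader implements Closeable {
  private final File file;
  private final int numRecords;  // cheap: read eagerly in the constructor
  private DataInputStream din;   // expensive: opened only on first use

  LazySpillReader(File file) throws IOException {
    this.file = file;
    // Read only the small fixed header; no large read buffer is allocated yet.
    try (DataInputStream dis = new DataInputStream(new FileInputStream(file))) {
      this.numRecords = dis.readInt();
    }
  }

  int getNumRecords() {
    return numRecords;
  }

  private DataInputStream getDin() throws IOException {
    if (din == null) {
      // First access: open the buffered stream and skip past the header so
      // callers see only record payloads, as getIn() does in the patch.
      din = new DataInputStream(
          new BufferedInputStream(new FileInputStream(file), 1024 * 1024));
      int headerRecords = din.readInt();
      assert headerRecords == numRecords;
    }
    return din;
  }

  int readNextRecordLength() throws IOException {
    return getDin().readInt();
  }

  @Override
  public void close() throws IOException {
    if (din != null) {
      din.close();
    }
  }
}

As in the patch, the cheap metadata read happens eagerly so record counts are available immediately, while only a reader that is actually iterated pays for the megabyte-sized buffer and the open stream.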