diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e2059cec132d2..cf5f05a7010e0 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -623,6 +623,7 @@ public UnsafeSorterIterator getIterator(int startIndex) throws IOException {
       return iter;
     } else {
       LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
+      logger.debug("number of spillWriters: {}", spillWriters.size());
       int i = 0;
       for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
         if (i + spillWriter.recordsSpilled() > startIndex) {
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 9521ab86a12d5..4df75de9fc637 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -46,7 +46,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
   // Variables that change with every record read:
   private int recordLength;
   private long keyPrefix;
-  private int numRecords;
+  private final int numRecords;
   private int numRecordsRemaining;
 
   private byte[] arr = new byte[1024 * 1024];
@@ -54,6 +54,11 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
   private final TaskContext taskContext = TaskContext.get();
 
+  private final long buffSize;
+  private final File file;
+  private final BlockId blockId;
+  private final SerializerManager serializerManager;
+
   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
       File file,
@@ -72,12 +77,27 @@ public UnsafeSorterSpillReader(
       bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES;
     }
 
+    // There is no need to hold the file open until records actually need to be loaded;
+    // deferring the open partially mitigates the "too many open files" issue.
+    try (InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
+         DataInputStream dataIn = new DataInputStream(serializerManager.wrapStream(blockId, bs))) {
+      this.numRecords = dataIn.readInt();
+      this.numRecordsRemaining = numRecords;
+    }
+
+    this.buffSize = bufferSizeBytes;
+    this.file = file;
+    this.blockId = blockId;
+    this.serializerManager = serializerManager;
+  }
+
+  private void initStreams() throws IOException {
     final InputStream bs =
-        new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
+        new NioBufferedFileInputStream(file, (int) buffSize);
     try {
       this.in = serializerManager.wrapStream(blockId, bs);
       this.din = new DataInputStream(this.in);
-      numRecords = numRecordsRemaining = din.readInt();
+      this.numRecordsRemaining = din.readInt();
     } catch (IOException e) {
       Closeables.close(bs, /* swallowIOException = */ true);
       throw e;
@@ -104,6 +124,12 @@ public void loadNext() throws IOException {
     if (taskContext != null) {
       taskContext.killTaskIfInterrupted();
     }
+    if (this.din == null) {
+      // Now is the time to open and hold the input stream of the spill file for
+      // loading records. Opening the stream any earlier would make it much more
+      // likely to hit the "too many open files" limit.
+      initStreams();
+    }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
     if (recordLength > arr.length) {
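
To see the pattern in isolation: the patch reads the record count eagerly in the constructor, closes the file, and reopens it only when loadNext() first needs data. Below is a minimal, self-contained sketch of that lazy-open pattern. The class and member names (LazySpillFileReader, initStream) and the plain java.io streams are illustrative assumptions, not part of the Spark codebase; the real patch additionally wraps the stream through SerializerManager and uses NioBufferedFileInputStream.

import java.io.BufferedInputStream;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;

// Illustrative only: a reader that learns the record count up front but holds
// no file descriptor until records are actually requested.
final class LazySpillFileReader implements Closeable {
  private final File file;
  private final int numRecords;   // header read eagerly, descriptor released
  private DataInputStream din;    // stays null until the first loadNext()

  LazySpillFileReader(File file) throws IOException {
    this.file = file;
    // Open just long enough to read the record-count header, then close,
    // so a large number of spill readers does not pin open files.
    try (DataInputStream header =
        new DataInputStream(new BufferedInputStream(new FileInputStream(file)))) {
      this.numRecords = header.readInt();
    }
  }

  private void initStream() throws IOException {
    din = new DataInputStream(new BufferedInputStream(new FileInputStream(file)));
    din.readInt(); // re-read and discard the header to reach the first record
  }

  /** Returns the next int record, opening the file lazily on first use. */
  int loadNext() throws IOException {
    if (din == null) {
      initStream();
    }
    return din.readInt();
  }

  int numRecords() {
    return numRecords;
  }

  @Override
  public void close() throws IOException {
    if (din != null) {
      din.close();
    }
  }
}

The patched initStreams() performs the same header re-read on reopen, using the value to refresh numRecordsRemaining rather than discarding it, so the stream is positioned at the first record either way.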