diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
index 70c112b78911d..804119cd06fa6 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle.api;
import java.io.IOException;
+import java.util.Optional;
import org.apache.spark.annotation.Private;
@@ -39,17 +40,39 @@ public interface ShuffleExecutorComponents {
/**
* Called once per map task to create a writer that will be responsible for persisting all the
* partitioned bytes written by that map task.
- * @param shuffleId Unique identifier for the shuffle the map task is a part of
+ *
+ * @param shuffleId Unique identifier for the shuffle the map task is a part of
* @param mapId Within the shuffle, the identifier of the map task
* @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task
- * with the same (shuffleId, mapId) pair can be distinguished by the
- * different values of mapTaskAttemptId.
+ * with the same (shuffleId, mapId) pair can be distinguished by the
+ * different values of mapTaskAttemptId.
* @param numPartitions The number of partitions that will be written by the map task. Some of
-* these partitions may be empty.
+ * these partitions may be empty.
*/
ShuffleMapOutputWriter createMapOutputWriter(
int shuffleId,
int mapId,
long mapTaskAttemptId,
int numPartitions) throws IOException;
+
+ /**
+ * An optional extension for creating a map output writer that can optimize the transfer of a
+ * single partition file, as the entire result of a map task, to the backing store.
+ *
+ * Most implementations should return the default {@link Optional#empty()} to indicate that
+ * they do not support this optimization. This is primarily for backwards compatibility,
+ * preserving an optimization in the local disk shuffle storage implementation.
+ *
+ * @param shuffleId Unique identifier for the shuffle the map task is a part of
+ * @param mapId Within the shuffle, the identifier of the map task
+ * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task
+ * with the same (shuffleId, mapId) pair can be distinguished by the
+ * different values of mapTaskAttemptId.
+ */
+ default Optional<SingleSpillShuffleMapOutputWriter> createSingleFileMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ long mapTaskAttemptId) throws IOException {
+ return Optional.empty();
+ }
}
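For illustration only (not part of this patch): a hypothetical third-party plugin that overrides the new default method. Most implementations simply inherit the Optional.empty() default above; one that can ship a whole spill file cheaply returns a SingleSpillShuffleMapOutputWriter instead. The class name and the upload step are assumptions made for this sketch.

import java.io.IOException;
import java.util.Optional;

import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;

public abstract class RemoteStoreExecutorComponents implements ShuffleExecutorComponents {

  @Override
  public Optional<SingleSpillShuffleMapOutputWriter> createSingleFileMapOutputWriter(
      int shuffleId,
      int mapId,
      long mapTaskAttemptId) throws IOException {
    // Only claim the optimization when the backing store can ingest one whole file cheaply.
    // SingleSpillShuffleMapOutputWriter has a single abstract method, so a lambda suffices.
    return Optional.of((mapSpillFile, partitionLengths) -> {
      // Upload mapSpillFile and record partitionLengths for the reduce side (elided).
    });
  }
}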
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
new file mode 100644
index 0000000000000..cad8dcfda52bc
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * Optional extension for partition writing that is optimized for transferring a single
+ * file to the backing store.
+ */
+@Private
+public interface SingleSpillShuffleMapOutputWriter {
+
+ /**
+ * Transfer a file that contains the bytes of all the partitions written by this map task.
+ */
+ void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException;
+}
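A minimal hypothetical implementation of the interface above, for illustration: it copies the finished spill file to a caller-supplied location. The class name and target are assumptions; a real plugin would also need to persist partitionLengths so the reduce side can locate each partition's byte range.

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;

import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;

public class CopyingSingleSpillWriter implements SingleSpillShuffleMapOutputWriter {

  private final File target;

  public CopyingSingleSpillWriter(File target) {
    this.target = target;
  }

  @Override
  public void transferMapSpillFile(File mapSpillFile, long[] partitionLengths) throws IOException {
    // The spill already stores every partition's bytes back to back in partition order,
    // so a plain file copy is sufficient; partitionLengths would be persisted separately.
    Files.copy(mapSpillFile.toPath(), target.toPath(), StandardCopyOption.REPLACE_EXISTING);
  }
}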
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 9d05f03613ce9..f59bddc993639 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -17,9 +17,12 @@
package org.apache.spark.shuffle.sort;
+import java.nio.channels.Channels;
+import java.util.Optional;
import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
import java.util.Iterator;
import scala.Option;
@@ -31,7 +34,6 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
-import com.google.common.io.Files;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,8 +43,6 @@
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
-import org.apache.commons.io.output.CloseShieldOutputStream;
-import org.apache.commons.io.output.CountingOutputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
@@ -50,10 +50,13 @@
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.ShufflePartitionWriter;
+import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.Utils;
@@ -65,15 +68,14 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
@VisibleForTesting
- static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
private final BlockManager blockManager;
- private final IndexShuffleBlockResolver shuffleBlockResolver;
private final TaskMemoryManager memoryManager;
private final SerializerInstance serializer;
private final Partitioner partitioner;
private final ShuffleWriteMetricsReporter writeMetrics;
+ private final ShuffleExecutorComponents shuffleExecutorComponents;
private final int shuffleId;
private final int mapId;
private final TaskContext taskContext;
@@ -81,7 +83,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter {
private final boolean transferToEnabled;
private final int initialSortBufferSize;
private final int inputBufferSizeInBytes;
- private final int outputBufferSizeInBytes;
@Nullable private MapStatus mapStatus;
@Nullable private ShuffleExternalSorter sorter;
@@ -103,27 +104,15 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream
*/
private boolean stopping = false;
- private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream {
-
- CloseAndFlushShieldOutputStream(OutputStream outputStream) {
- super(outputStream);
- }
-
- @Override
- public void flush() {
- // do nothing
- }
- }
-
public UnsafeShuffleWriter(
BlockManager blockManager,
- IndexShuffleBlockResolver shuffleBlockResolver,
TaskMemoryManager memoryManager,
SerializedShuffleHandle handle,
int mapId,
TaskContext taskContext,
SparkConf sparkConf,
- ShuffleWriteMetricsReporter writeMetrics) throws IOException {
+ ShuffleWriteMetricsReporter writeMetrics,
+ ShuffleExecutorComponents shuffleExecutorComponents) {
final int numPartitions = handle.dependency().partitioner().numPartitions();
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
@@ -132,7 +121,6 @@ public UnsafeShuffleWriter(
" reduce partitions");
}
this.blockManager = blockManager;
- this.shuffleBlockResolver = shuffleBlockResolver;
this.memoryManager = memoryManager;
this.mapId = mapId;
final ShuffleDependency<K, V, V> dep = handle.dependency();
@@ -140,6 +128,7 @@ public UnsafeShuffleWriter(
this.serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = writeMetrics;
+ this.shuffleExecutorComponents = shuffleExecutorComponents;
this.taskContext = taskContext;
this.sparkConf = sparkConf;
this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
@@ -147,8 +136,6 @@ public UnsafeShuffleWriter(
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
this.inputBufferSizeInBytes =
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
- this.outputBufferSizeInBytes =
- (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024;
open();
}
@@ -231,22 +218,13 @@ void closeAndWriteOutput() throws IOException {
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
final long[] partitionLengths;
- final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
- final File tmp = Utils.tempFileWith(output);
try {
- try {
- partitionLengths = mergeSpills(spills, tmp);
- } finally {
- for (SpillInfo spill : spills) {
- if (spill.file.exists() && ! spill.file.delete()) {
- logger.error("Error while deleting spill file {}", spill.file.getPath());
- }
- }
- }
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
+ partitionLengths = mergeSpills(spills);
} finally {
- if (tmp.exists() && !tmp.delete()) {
- logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
+ for (SpillInfo spill : spills) {
+ if (spill.file.exists() && !spill.file.delete()) {
+ logger.error("Error while deleting spill file {}", spill.file.getPath());
+ }
}
}
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
@@ -281,137 +259,161 @@ void forceSorterToSpill() throws IOException {
*
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
+ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
+ long[] partitionLengths;
+ if (spills.length == 0) {
+ final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
+ .createMapOutputWriter(
+ shuffleId,
+ mapId,
+ taskContext.taskAttemptId(),
+ partitioner.numPartitions());
+ return mapWriter.commitAllPartitions();
+ } else if (spills.length == 1) {
+ Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter =
+ shuffleExecutorComponents.createSingleFileMapOutputWriter(
+ shuffleId, mapId, taskContext.taskAttemptId());
+ if (maybeSingleFileWriter.isPresent()) {
+ // Here, we don't need to perform any metrics updates because the bytes written to this
+ // output file would have already been counted as shuffle bytes written.
+ partitionLengths = spills[0].partitionLengths;
+ maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths);
+ } else {
+ partitionLengths = mergeSpillsUsingStandardWriter(spills);
+ }
+ } else {
+ partitionLengths = mergeSpillsUsingStandardWriter(spills);
+ }
+ return partitionLengths;
+ }
+
+ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException {
+ long[] partitionLengths;
final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
- (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE());
+ (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE());
final boolean fastMergeIsSupported = !compressionEnabled ||
- CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
+ CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
+ final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
+ .createMapOutputWriter(
+ shuffleId,
+ mapId,
+ taskContext.taskAttemptId(),
+ partitioner.numPartitions());
try {
- if (spills.length == 0) {
- new FileOutputStream(outputFile).close(); // Create an empty file
- return new long[partitioner.numPartitions()];
- } else if (spills.length == 1) {
- // Here, we don't need to perform any metrics updates because the bytes written to this
- // output file would have already been counted as shuffle bytes written.
- Files.move(spills[0].file, outputFile);
- return spills[0].partitionLengths;
- } else {
- final long[] partitionLengths;
- // There are multiple spills to merge, so none of these spill files' lengths were counted
- // towards our shuffle write count or shuffle write time. If we use the slow merge path,
- // then the final output file's size won't necessarily be equal to the sum of the spill
- // files' sizes. To guard against this case, we look at the output file's actual size when
- // computing shuffle bytes written.
- //
- // We allow the individual merge methods to report their own IO times since different merge
- // strategies use different IO techniques. We count IO during merge towards the shuffle
- // shuffle write time, which appears to be consistent with the "not bypassing merge-sort"
- // branch in ExternalSorter.
- if (fastMergeEnabled && fastMergeIsSupported) {
- // Compression is disabled or we are using an IO compression codec that supports
- // decompression of concatenated compressed streams, so we can perform a fast spill merge
- // that doesn't need to interpret the spilled bytes.
- if (transferToEnabled && !encryptionEnabled) {
- logger.debug("Using transferTo-based fast merge");
- partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
- } else {
- logger.debug("Using fileStream-based fast merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
- }
+ // There are multiple spills to merge, so none of these spill files' lengths were counted
+ // towards our shuffle write count or shuffle write time. If we use the slow merge path,
+ // then the final output file's size won't necessarily be equal to the sum of the spill
+ // files' sizes. To guard against this case, we look at the output file's actual size when
+ // computing shuffle bytes written.
+ //
+ // We allow the individual merge methods to report their own IO times since different merge
+ // strategies use different IO techniques. We count IO during merge towards the shuffle
+ // write time, which appears to be consistent with the "not bypassing merge-sort" branch in
+ // ExternalSorter.
+ if (fastMergeEnabled && fastMergeIsSupported) {
+ // Compression is disabled or we are using an IO compression codec that supports
+ // decompression of concatenated compressed streams, so we can perform a fast spill merge
+ // that doesn't need to interpret the spilled bytes.
+ if (transferToEnabled && !encryptionEnabled) {
+ logger.debug("Using transferTo-based fast merge");
+ mergeSpillsWithTransferTo(spills, mapWriter);
} else {
- logger.debug("Using slow merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+ logger.debug("Using fileStream-based fast merge");
+ mergeSpillsWithFileStream(spills, mapWriter, null);
}
- // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
- // in-memory records, we write out the in-memory records to a file but do not count that
- // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
- // to be counted as shuffle write, but this will lead to double-counting of the final
- // SpillInfo's bytes.
- writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
- writeMetrics.incBytesWritten(outputFile.length());
- return partitionLengths;
+ } else {
+ logger.debug("Using slow merge");
+ mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
}
- } catch (IOException e) {
- if (outputFile.exists() && !outputFile.delete()) {
- logger.error("Unable to delete output file {}", outputFile.getPath());
+ // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
+ // in-memory records, we write out the in-memory records to a file but do not count that
+ // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
+ // to be counted as shuffle write, but this will lead to double-counting of the final
+ // SpillInfo's bytes.
+ writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
+ partitionLengths = mapWriter.commitAllPartitions();
+ } catch (Exception e) {
+ try {
+ mapWriter.abort(e);
+ } catch (Exception e2) {
+ logger.warn("Failed to abort writing the map output.", e2);
+ e.addSuppressed(e2);
}
throw e;
}
+ return partitionLengths;
}
/**
* Merges spill files using Java FileStreams. This code path is typically slower than
* the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[],
- * File)}, and it's mostly used in cases where the IO compression codec does not support
- * concatenation of compressed data, when encryption is enabled, or when users have
- * explicitly disabled use of {@code transferTo} in order to work around kernel bugs.
+ * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec
+ * does not support concatenation of compressed data, when encryption is enabled, or when
+ * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs.
* This code path might also be faster in cases where the individual partition size in a spill
* is small and the UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small
* disk I/Os, which is inefficient. In those cases, using large buffers for the input and output
* files helps reduce the number of disk I/Os, making the file merge faster.
*
* @param spills the spills to merge.
- * @param outputFile the file to write the merged data to.
+ * @param mapWriter the map output writer to use for output.
* @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpillsWithFileStream(
+ private void mergeSpillsWithFileStream(
SpillInfo[] spills,
- File outputFile,
+ ShuffleMapOutputWriter mapWriter,
@Nullable CompressionCodec compressionCodec) throws IOException {
- assert (spills.length >= 2);
final int numPartitions = partitioner.numPartitions();
- final long[] partitionLengths = new long[numPartitions];
final InputStream[] spillInputStreams = new InputStream[spills.length];
- final OutputStream bos = new BufferedOutputStream(
- new FileOutputStream(outputFile),
- outputBufferSizeInBytes);
- // Use a counting output stream to avoid having to close the underlying file and ask
- // the file system for its size after each partition is written.
- final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos);
-
boolean threwException = true;
try {
for (int i = 0; i < spills.length; i++) {
spillInputStreams[i] = new NioBufferedFileInputStream(
- spills[i].file,
- inputBufferSizeInBytes);
+ spills[i].file,
+ inputBufferSizeInBytes);
}
for (int partition = 0; partition < numPartitions; partition++) {
- final long initialFileLength = mergedFileOutputStream.getByteCount();
- // Shield the underlying output stream from close() and flush() calls, so that we can close
- // the higher level streams to make sure all data is really flushed and internal state is
- // cleaned.
- OutputStream partitionOutput = new CloseAndFlushShieldOutputStream(
- new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
- partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
- if (compressionCodec != null) {
- partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
- }
- for (int i = 0; i < spills.length; i++) {
- final long partitionLengthInSpill = spills[i].partitionLengths[partition];
- if (partitionLengthInSpill > 0) {
- InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
- partitionLengthInSpill, false);
- try {
- partitionInputStream = blockManager.serializerManager().wrapForEncryption(
- partitionInputStream);
- if (compressionCodec != null) {
- partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+ boolean copyThrewException = true;
+ ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
+ OutputStream partitionOutput = writer.openStream();
+ try {
+ partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
+ if (compressionCodec != null) {
+ partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
+ }
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+
+ if (partitionLengthInSpill > 0) {
+ InputStream partitionInputStream = null;
+ boolean copySpillThrewException = true;
+ try {
+ partitionInputStream = new LimitedInputStream(spillInputStreams[i],
+ partitionLengthInSpill, false);
+ partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+ partitionInputStream);
+ if (compressionCodec != null) {
+ partitionInputStream = compressionCodec.compressedInputStream(
+ partitionInputStream);
+ }
+ ByteStreams.copy(partitionInputStream, partitionOutput);
+ copySpillThrewException = false;
+ } finally {
+ Closeables.close(partitionInputStream, copySpillThrewException);
}
- ByteStreams.copy(partitionInputStream, partitionOutput);
- } finally {
- partitionInputStream.close();
}
}
+ copyThrewException = false;
+ } finally {
+ Closeables.close(partitionOutput, copyThrewException);
}
- partitionOutput.flush();
- partitionOutput.close();
- partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
+ long numBytesWritten = writer.getNumBytesWritten();
+ writeMetrics.incBytesWritten(numBytesWritten);
}
threwException = false;
} finally {
@@ -420,9 +422,7 @@ private long[] mergeSpillsWithFileStream(
for (InputStream stream : spillInputStreams) {
Closeables.close(stream, threwException);
}
- Closeables.close(mergedFileOutputStream, threwException);
}
- return partitionLengths;
}
/**
@@ -430,54 +430,46 @@ private long[] mergeSpillsWithFileStream(
* This is only safe when the IO compression codec and serializer support concatenation of
* serialized streams.
*
+ * @param spills the spills to merge.
+ * @param mapWriter the map output writer to use for output.
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
- assert (spills.length >= 2);
+ private void mergeSpillsWithTransferTo(
+ SpillInfo[] spills,
+ ShuffleMapOutputWriter mapWriter) throws IOException {
final int numPartitions = partitioner.numPartitions();
- final long[] partitionLengths = new long[numPartitions];
final FileChannel[] spillInputChannels = new FileChannel[spills.length];
final long[] spillInputChannelPositions = new long[spills.length];
- FileChannel mergedFileOutputChannel = null;
boolean threwException = true;
try {
for (int i = 0; i < spills.length; i++) {
spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
}
- // This file needs to opened in append mode in order to work around a Linux kernel bug that
- // affects transferTo; see SPARK-3948 for more details.
- mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
-
- long bytesWrittenToMergedFile = 0;
for (int partition = 0; partition < numPartitions; partition++) {
- for (int i = 0; i < spills.length; i++) {
- final long partitionLengthInSpill = spills[i].partitionLengths[partition];
- final FileChannel spillInputChannel = spillInputChannels[i];
- final long writeStartTime = System.nanoTime();
- Utils.copyFileStreamNIO(
- spillInputChannel,
- mergedFileOutputChannel,
- spillInputChannelPositions[i],
- partitionLengthInSpill);
- spillInputChannelPositions[i] += partitionLengthInSpill;
- writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
- bytesWrittenToMergedFile += partitionLengthInSpill;
- partitionLengths[partition] += partitionLengthInSpill;
+ boolean copyThrewException = true;
+ ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
+ WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper()
+ .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer)));
+ try {
+ for (int i = 0; i < spills.length; i++) {
+ long partitionLengthInSpill = spills[i].partitionLengths[partition];
+ final FileChannel spillInputChannel = spillInputChannels[i];
+ final long writeStartTime = System.nanoTime();
+ Utils.copyFileStreamNIO(
+ spillInputChannel,
+ resolvedChannel.channel(),
+ spillInputChannelPositions[i],
+ partitionLengthInSpill);
+ copyThrewException = false;
+ spillInputChannelPositions[i] += partitionLengthInSpill;
+ writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
+ }
+ } finally {
+ Closeables.close(resolvedChannel, copyThrewException);
}
- }
- // Check the position after transferTo loop to see if it is in the right position and raise an
- // exception if it is incorrect. The position will not be increased to the expected length
- // after calling transferTo in kernel version 2.6.32. This issue is described at
- // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
- if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
- throw new IOException(
- "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
- "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
- " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
- "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
- "to disable this NIO feature."
- );
+ long numBytes = writer.getNumBytesWritten();
+ writeMetrics.incBytesWritten(numBytes);
}
threwException = false;
} finally {
@@ -487,9 +479,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
assert(spillInputChannelPositions[i] == spills[i].file.length());
Closeables.close(spillInputChannels[i], threwException);
}
- Closeables.close(mergedFileOutputChannel, threwException);
}
- return partitionLengths;
}
@Override
@@ -518,4 +508,30 @@ public Option<MapStatus> stop(boolean success) {
}
}
}
+
+ private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) {
+ try {
+ return writer.openStream();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper {
+ private final WritableByteChannel channel;
+
+ StreamFallbackChannelWrapper(OutputStream fallbackStream) {
+ this.channel = Channels.newChannel(fallbackStream);
+ }
+
+ @Override
+ public WritableByteChannel channel() {
+ return channel;
+ }
+
+ @Override
+ public void close() throws IOException {
+ channel.close();
+ }
+ }
}
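The StreamFallbackChannelWrapper above leans on a standard-library fact: any OutputStream can be adapted to a WritableByteChannel, which is what lets the transferTo-style copy loop keep a single code path even when a plugin only offers openStream(). A JDK-only sketch of that adaptation (the class name is ours; nothing here is Spark-specific):

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;

public class StreamToChannelExample {
  public static void main(String[] args) throws IOException {
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    // Channels.newChannel wraps the stream; writes to the channel land in the stream.
    try (WritableByteChannel channel = Channels.newChannel(sink)) {
      channel.write(ByteBuffer.wrap("partition bytes".getBytes(StandardCharsets.UTF_8)));
    }
    System.out.println(sink.size() + " bytes written through the adapted channel");
  }
}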
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
index 02eb710737285..47aa2e39fe29b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.sort.io;
+import java.util.Optional;
+
import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.SparkConf;
@@ -24,6 +26,7 @@
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.storage.BlockManager;
public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents {
@@ -68,4 +71,16 @@ public ShuffleMapOutputWriter createMapOutputWriter(
return new LocalDiskShuffleMapOutputWriter(
shuffleId, mapId, numPartitions, blockResolver, sparkConf);
}
+
+ @Override
+ public Optional<SingleSpillShuffleMapOutputWriter> createSingleFileMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ long mapTaskAttemptId) {
+ if (blockResolver == null) {
+ throw new IllegalStateException(
+ "Executor components must be initialized before getting writers.");
+ }
+ return Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver));
+ }
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
index 7fc19b1270a46..444cdc4270ecd 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
@@ -24,8 +24,8 @@
import java.io.OutputStream;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;
-
import java.util.Optional;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,6 +54,7 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
private final int bufferSize;
private int lastPartitionId = -1;
private long currChannelPosition;
+ private long bytesWrittenToMergedFile = 0L;
private final File outputFile;
private File outputTempFile;
@@ -97,6 +98,18 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I
@Override
public long[] commitAllPartitions() throws IOException {
+ // Check the position after the transferTo loop to see if it is in the right position and raise an
+ // exception if it is incorrect. The position will not be increased to the expected length
+ // after calling transferTo in kernel version 2.6.32. This issue is described at
+ // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+ if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) {
+ throw new IOException(
+ "Current position " + outputFileChannel.position() + " does not equal expected " +
+ "position " + bytesWrittenToMergedFile + " after transferTo. Please check your " +
+ " kernel version to see if it is 2.6.32, as there is a kernel bug which will lead " +
+ "to unexpected behavior when using transferTo. You can set " +
+ "spark.file.transferTo=false to disable this NIO feature.");
+ }
cleanUp();
File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
@@ -133,11 +146,10 @@ private void initStream() throws IOException {
}
private void initChannel() throws IOException {
- if (outputFileStream == null) {
- outputFileStream = new FileOutputStream(outputTempFile, true);
- }
+ // This file needs to be opened in append mode in order to work around a Linux kernel bug that
+ // affects transferTo; see SPARK-3948 for more details.
if (outputFileChannel == null) {
- outputFileChannel = outputFileStream.getChannel();
+ outputFileChannel = new FileOutputStream(outputTempFile, true).getChannel();
}
}
@@ -227,6 +239,7 @@ public void write(byte[] buf, int pos, int length) throws IOException {
public void close() {
isClosed = true;
partitionLengths[partitionId] = count;
+ bytesWrittenToMergedFile += count;
}
private void verifyNotClosed() {
@@ -257,6 +270,7 @@ public WritableByteChannel channel() {
@Override
public void close() throws IOException {
partitionLengths[partitionId] = getCount();
+ bytesWrittenToMergedFile += partitionLengths[partitionId];
}
}
}
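The position check that moved into commitAllPartitions can be demonstrated with nothing but the JDK: after a transferTo into a channel opened in append mode, the destination channel's position should match the byte count accumulated by the caller, which is exactly the invariant guarded against the SPARK-3948 / JDK-7052359 kernel bug. A standalone sketch (file names are placeholders):

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;

public class TransferToPositionCheck {
  public static void main(String[] args) throws IOException {
    Path src = Files.createTempFile("spill", ".bin");
    Path dst = Files.createTempFile("merged", ".bin");
    Files.write(src, new byte[1024]);

    long expected = 0L;
    // Open the destination in append mode, mirroring initChannel() above.
    try (FileChannel in = new FileInputStream(src.toFile()).getChannel();
         FileChannel out = new FileOutputStream(dst.toFile(), true).getChannel()) {
      expected += in.transferTo(0, in.size(), out);
      if (out.position() != expected) {
        throw new IOException("Position " + out.position() + " != expected " + expected +
            "; transferTo may be affected by the kernel bug described in SPARK-3948.");
      }
    }
    System.out.println("Copied " + expected + " bytes; position check passed.");
  }
}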
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
new file mode 100644
index 0000000000000..6b0a797a61b52
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
+import org.apache.spark.util.Utils;
+
+public class LocalDiskSingleSpillMapOutputWriter
+ implements SingleSpillShuffleMapOutputWriter {
+
+ private final int shuffleId;
+ private final int mapId;
+ private final IndexShuffleBlockResolver blockResolver;
+
+ public LocalDiskSingleSpillMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ IndexShuffleBlockResolver blockResolver) {
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ this.blockResolver = blockResolver;
+ }
+
+ @Override
+ public void transferMapSpillFile(
+ File mapSpillFile,
+ long[] partitionLengths) throws IOException {
+ // The map spill file already has the proper format, and it contains all of the partition data.
+ // So just transfer it directly to the destination without any merging.
+ File outputFile = blockResolver.getDataFile(shuffleId, mapId);
+ File tempFile = Utils.tempFileWith(outputFile);
+ Files.move(mapSpillFile.toPath(), tempFile.toPath());
+ blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile);
+ }
+}
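A rough usage sketch, mirroring the mock-based style of UnsafeShuffleWriterSuite below: with a mocked IndexShuffleBlockResolver the call only moves the spill into the resolver-derived temp location, while in real use writeIndexFileAndCommit would then rename it to the data file and write the index. The directory and file names here are invented for the example.

import java.io.File;
import java.nio.file.Files;

import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.sort.io.LocalDiskSingleSpillMapOutputWriter;

import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class LocalDiskSingleSpillSketch {
  public static void main(String[] args) throws Exception {
    File tmpDir = Files.createTempDirectory("single-spill-sketch").toFile();
    File dataFile = new File(tmpDir, "shuffle_0_0_0.data");

    IndexShuffleBlockResolver resolver = mock(IndexShuffleBlockResolver.class);
    when(resolver.getDataFile(anyInt(), anyInt())).thenReturn(dataFile);

    // A fake spill holding two partitions of 2 and 3 bytes, back to back.
    File spill = new File(tmpDir, "spill0");
    Files.write(spill.toPath(), new byte[] {1, 2, 3, 4, 5});

    new LocalDiskSingleSpillMapOutputWriter(0, 0, resolver)
        .transferMapSpillFile(spill, new long[] {2L, 3L});
    System.out.println("Spill moved: " + !spill.exists());
  }
}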
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index b898413ac8d76..158a4b7cfa55a 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1026,7 +1026,7 @@ package object config {
.booleanConf
.createWithDefault(false)
- private[spark] val SHUFFLE_UNDAFE_FAST_MERGE_ENABLE =
+ private[spark] val SHUFFLE_UNSAFE_FAST_MERGE_ENABLE =
ConfigBuilder("spark.shuffle.unsafe.fastMergeEnabled")
.doc("Whether to perform a fast spill merge.")
.booleanConf
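Only the internal constant is renamed; the user-facing key stays the same, so existing jobs keep toggling the fast merge path as before. A minimal standalone illustration:

import org.apache.spark.SparkConf;

public class FastMergeConfigExample {
  public static void main(String[] args) {
    // Same property key as before the constant rename; "false" forces the slow merge path.
    SparkConf conf = new SparkConf()
        .set("spark.shuffle.unsafe.fastMergeEnabled", "false");
    System.out.println(conf.get("spark.shuffle.unsafe.fastMergeEnabled"));
  }
}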
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 2a99c93b32af4..cbdc2c886dd9f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -140,13 +140,13 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
- shuffleBlockResolver,
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
context,
env.conf,
- metrics)
+ metrics,
+ shuffleExecutorComponents)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 6b83a984f037c..1022111897a49 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -19,8 +19,10 @@
import java.io.*;
import java.nio.ByteBuffer;
+import java.nio.file.Files;
import java.util.*;
+import org.mockito.stubbing.Answer;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
@@ -53,6 +55,7 @@
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.*;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -65,6 +68,7 @@
public class UnsafeShuffleWriterSuite {
+ static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
static final int NUM_PARTITITONS = 4;
TestMemoryManager memoryManager;
TaskMemoryManager taskMemoryManager;
@@ -132,14 +136,28 @@ public void setUp() throws IOException {
});
when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
- doAnswer(invocationOnMock -> {
+
+ Answer<?> renameTempAnswer = invocationOnMock -> {
partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
File tmp = (File) invocationOnMock.getArguments()[3];
- mergedOutputFile.delete();
- tmp.renameTo(mergedOutputFile);
+ if (!mergedOutputFile.delete()) {
+ throw new RuntimeException("Failed to delete old merged output file.");
+ }
+ if (tmp != null) {
+ Files.move(tmp.toPath(), mergedOutputFile.toPath());
+ } else if (!mergedOutputFile.createNewFile()) {
+ throw new RuntimeException("Failed to create empty merged output file.");
+ }
return null;
- }).when(shuffleBlockResolver)
- .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
+ };
+
+ doAnswer(renameTempAnswer)
+ .when(shuffleBlockResolver)
+ .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
+
+ doAnswer(renameTempAnswer)
+ .when(shuffleBlockResolver)
+ .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), eq(null));
when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
@@ -151,21 +169,20 @@ public void setUp() throws IOException {
when(taskContext.taskMetrics()).thenReturn(taskMetrics);
when(shuffleDep.serializer()).thenReturn(serializer);
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+ when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
}
- private UnsafeShuffleWriter<Object, Object> createWriter(
- boolean transferToEnabled) throws IOException {
+ private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
- return new UnsafeShuffleWriter<>(
+ return new UnsafeShuffleWriter<>(
blockManager,
- shuffleBlockResolver,
- taskMemoryManager,
+ taskMemoryManager,
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf,
- taskContext.taskMetrics().shuffleWriteMetrics()
- );
+ taskContext.taskMetrics().shuffleWriteMetrics(),
+ new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver));
}
private void assertSpillFilesWereCleanedUp() {
@@ -391,7 +408,7 @@ public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Except
@Test
public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
- conf.set(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE(), false);
+ conf.set(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE(), false);
testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
}
@@ -444,10 +461,10 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() thro
}
private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
- memoryManager.limit(UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
+ memoryManager.limit(DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
- for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
+ for (int i = 0; i < DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
dataToWrite.add(new Tuple2<>(i, i));
}
writer.write(dataToWrite.iterator());
@@ -516,16 +533,15 @@ public void testPeakMemoryUsed() throws Exception {
final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
taskMemoryManager = spy(taskMemoryManager);
when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
- final UnsafeShuffleWriter<Object, Object> writer =
- new UnsafeShuffleWriter<>(
+ final UnsafeShuffleWriter<Object, Object> writer = new UnsafeShuffleWriter<>(
blockManager,
- shuffleBlockResolver,
taskMemoryManager,
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf,
- taskContext.taskMetrics().shuffleWriteMetrics());
+ taskContext.taskMetrics().shuffleWriteMetrics(),
+ new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver));
// Peak memory should be monotonically increasing. More specifically, every time
// we allocate a new page it should increase by exactly the size of the page.