UnsafeShuffleWriter.java
@@ -40,6 +40,8 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.io.output.CountingOutputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
@@ -264,6 +266,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
final boolean fastMergeIsSupported = !compressionEnabled ||
CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
try {
if (spills.length == 0) {
new FileOutputStream(outputFile).close(); // Create an empty file
@@ -289,7 +292,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
// Compression is disabled or we are using an IO compression codec that supports
// decompression of concatenated compressed streams, so we can perform a fast spill merge
// that doesn't need to interpret the spilled bytes.
if (transferToEnabled) {
if (transferToEnabled && !encryptionEnabled) {
logger.debug("Using transferTo-based fast merge");
partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
} else {
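
The flags computed at the top of mergeSpills (the fast-merge config and codec check plus the new encryptionEnabled flag) combine with transferToEnabled to pick a merge strategy. The helper below is only an illustrative sketch of that selection, restating what the branch above does; it is not part of the patch.

// Illustrative only: summarizes which merge path the conditions in mergeSpills() select.
static String chooseMergeStrategy(
    boolean fastMergeEnabled,
    boolean fastMergeIsSupported,
    boolean transferToEnabled,
    boolean encryptionEnabled) {
  if (fastMergeEnabled && fastMergeIsSupported) {
    // The concatenated spill bytes can be read back without re-compressing each partition.
    if (transferToEnabled && !encryptionEnabled) {
      return "transferTo-based fast merge";   // raw NIO channel copies of file regions
    }
    return "fileStream-based fast merge";     // stream copy; re-encrypts when encryption is on
  }
  // The codec cannot decode concatenated streams (or fast merge is disabled):
  // each partition is decoded and re-encoded during the merge.
  return "slow fileStream-based merge";
}
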
@@ -320,9 +323,9 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
/**
* Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
* {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
* cases where the IO compression codec does not support concatenation of compressed data, or in
* cases where users have explicitly disabled use of {@code transferTo} in order to work around
* kernel bugs.
* cases where the IO compression codec does not support concatenation of compressed data, when
* encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in
* order to work around kernel bugs.
*
* @param spills the spills to merge.
* @param outputFile the file to write the merged data to.
@@ -337,42 +340,47 @@ private long[] mergeSpillsWithFileStream(
final int numPartitions = partitioner.numPartitions();
final long[] partitionLengths = new long[numPartitions];
final InputStream[] spillInputStreams = new FileInputStream[spills.length];
OutputStream mergedFileOutputStream = null;

// Use a counting output stream to avoid having to close the underlying file and ask
// the file system for its size after each partition is written.
final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(
new FileOutputStream(outputFile));

[Review comment] Could you add a comment about why we need to use CountingOutputStream + CloseShieldOutputStream? It took me a while to figure out the optimization you did.

boolean threwException = true;
try {
for (int i = 0; i < spills.length; i++) {
spillInputStreams[i] = new FileInputStream(spills[i].file);
}
for (int partition = 0; partition < numPartitions; partition++) {
final long initialFileLength = outputFile.length();
mergedFileOutputStream =
new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
final long initialFileLength = mergedFileOutputStream.getByteCount();
// Shield the underlying output stream from close() calls, so that we can close the higher
// level streams to make sure all data is really flushed and internal state is cleaned.
OutputStream partitionOutput = new CloseShieldOutputStream(
new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
if (compressionCodec != null) {
mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
}

for (int i = 0; i < spills.length; i++) {
final long partitionLengthInSpill = spills[i].partitionLengths[partition];
if (partitionLengthInSpill > 0) {
InputStream partitionInputStream = null;
boolean innerThrewException = true;
InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
partitionLengthInSpill, false);
try {
partitionInputStream =
new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
partitionInputStream = blockManager.serializerManager().wrapForEncryption(
partitionInputStream);
if (compressionCodec != null) {
partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
}
ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
innerThrewException = false;
ByteStreams.copy(partitionInputStream, partitionOutput);
} finally {
Closeables.close(partitionInputStream, innerThrewException);
partitionInputStream.close();
}
}
}
mergedFileOutputStream.flush();
mergedFileOutputStream.close();
partitionLengths[partition] = (outputFile.length() - initialFileLength);
partitionOutput.flush();
partitionOutput.close();
partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
}
threwException = false;
} finally {
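
The rewritten loop above leans on two commons-io wrappers: a single CountingOutputStream over the output file, so per-partition lengths come from getByteCount() instead of re-querying the file system, and a CloseShieldOutputStream, so the per-partition encryption/compression wrappers can be closed (flushing their buffers) without closing the shared file stream. A minimal standalone sketch of that pattern, with an illustrative class name and dummy payload:

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.io.output.CountingOutputStream;

public final class CountingMergeSketch {
  public static void main(String[] args) throws IOException {
    File out = File.createTempFile("merge", ".bin");
    // One CountingOutputStream for the whole output file: partition sizes come from byte
    // counts rather than from closing the file and asking the file system for its length.
    CountingOutputStream counted = new CountingOutputStream(new FileOutputStream(out));
    try {
      for (int partition = 0; partition < 3; partition++) {
        long before = counted.getByteCount();
        // CloseShieldOutputStream swallows close(), so closing the per-partition wrapper
        // chain (which flushes cipher/codec buffers) leaves `counted` open for the next
        // partition.
        OutputStream partitionOut = new CloseShieldOutputStream(counted);
        // ... wrap partitionOut for encryption and/or compression here, as the patch does ...
        partitionOut.write(("partition " + partition + "\n").getBytes("UTF-8"));
        partitionOut.close(); // closes only the wrappers
        System.out.println("partition " + partition + " wrote "
            + (counted.getByteCount() - before) + " bytes");
      }
    } finally {
      counted.close();
    }
  }
}
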
SerializerManager.scala
@@ -75,6 +75,8 @@ private[spark] class SerializerManager(
* loaded yet. */
private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)

def encryptionEnabled: Boolean = encryptionKey.isDefined

def canUseKryo(ct: ClassTag[_]): Boolean = {
primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
}
@@ -126,7 +128,7 @@ private[spark] class SerializerManager(
/**
* Wrap an input stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: InputStream): InputStream = {
def wrapForEncryption(s: InputStream): InputStream = {
encryptionKey
.map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
.getOrElse(s)
@@ -135,7 +137,7 @@ private[spark] class SerializerManager(
/**
* Wrap an output stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
def wrapForEncryption(s: OutputStream): OutputStream = {
encryptionKey
.map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
.getOrElse(s)
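
With both wrapForEncryption overloads now public, Java callers such as UnsafeShuffleWriter can wrap arbitrary streams. Below is a minimal round-trip sketch from the Java side, mirroring how the test suite later in this diff builds an encrypting SerializerManager; the class name, payload, and standalone main are illustrative.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import com.google.common.io.ByteStreams;
import org.apache.spark.SparkConf;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.SerializerManager;
import scala.Option;

public final class WrapForEncryptionSketch {
  public static void main(String[] args) throws Exception {
    SparkConf conf = new SparkConf();
    conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), true);
    // Same construction as the encrypted case in UnsafeShuffleWriterSuite below.
    SerializerManager manager = new SerializerManager(
        new KryoSerializer(conf), conf, Option.apply(CryptoStreamUtils.createKey(conf)));

    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    OutputStream encrypting = manager.wrapForEncryption((OutputStream) baos);
    encrypting.write("some shuffle bytes".getBytes("UTF-8"));
    encrypting.close();

    InputStream decrypting = manager.wrapForEncryption(
        (InputStream) new ByteArrayInputStream(baos.toByteArray()));
    // Round-trips back to the original plaintext.
    System.out.println(new String(ByteStreams.toByteArray(decrypting), "UTF-8"));
  }
}
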
BlockManager.scala
@@ -62,7 +62,7 @@ private[spark] class BlockManager(
executorId: String,
rpcEnv: RpcEnv,
val master: BlockManagerMaster,
serializerManager: SerializerManager,
val serializerManager: SerializerManager,
val conf: SparkConf,
memoryManager: MemoryManager,
mapOutputTracker: MapOutputTracker,
@@ -745,9 +745,8 @@
serializerInstance: SerializerInstance,
bufferSize: Int,
writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream,
new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
syncWrites, writeMetrics, blockId)
}

DiskBlockObjectWriter.scala
@@ -22,7 +22,7 @@ import java.nio.channels.FileChannel

import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{SerializationStream, SerializerInstance}
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.util.Utils

/**
@@ -37,9 +37,9 @@ import org.apache.spark.util.Utils
*/
private[spark] class DiskBlockObjectWriter(
val file: File,
serializerManager: SerializerManager,
serializerInstance: SerializerInstance,
bufferSize: Int,
wrapStream: OutputStream => OutputStream,
syncWrites: Boolean,
// These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
@@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter(
initialized = true
}

bs = wrapStream(mcs)
bs = serializerManager.wrapStream(blockId, mcs)
objOut = serializerInstance.serializeStream(bs)
streamOpen = true
this
UnsafeShuffleWriterSuite.java
@@ -26,11 +26,9 @@
import scala.Tuple2;
import scala.Tuple2$;
import scala.collection.Iterator;
import scala.runtime.AbstractFunction1;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Iterators;
import com.google.common.io.ByteStreams;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -53,6 +51,7 @@
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.*;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.storage.*;
@@ -77,7 +76,6 @@ public class UnsafeShuffleWriterSuite {
final LinkedList<File> spillFilesCreated = new LinkedList<>();
SparkConf conf;
final Serializer serializer = new KryoSerializer(new SparkConf());
final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf());
TaskMetrics taskMetrics;

@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -86,17 +84,6 @@ public class UnsafeShuffleWriterSuite {
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
@Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;

private final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
@Override
public OutputStream apply(OutputStream stream) {
if (conf.getBoolean("spark.shuffle.compress", true)) {
return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
} else {
return stream;
}
}
}

@After
public void tearDown() {
Utils.deleteRecursively(tempDir);
@@ -121,6 +108,11 @@ public void setUp() throws IOException {
memoryManager = new TestMemoryManager(conf);
taskMemoryManager = new TaskMemoryManager(memoryManager, 0);

// Some tests will override this manager because they change the configuration. This is a
// default for tests that don't need a specific one.
SerializerManager manager = new SerializerManager(serializer, conf);
when(blockManager.serializerManager()).thenReturn(manager);

when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
when(blockManager.getDiskWriter(
any(BlockId.class),
@@ -131,12 +123,11 @@
@Override
public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
Object[] args = invocationOnMock.getArguments();

return new DiskBlockObjectWriter(
(File) args[1],
blockManager.serializerManager(),
(SerializerInstance) args[2],
(Integer) args[3],
new WrapStream(),
false,
(ShuffleWriteMetrics) args[4],
(BlockId) args[0]
@@ -201,9 +192,10 @@ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
for (int i = 0; i < NUM_PARTITITONS; i++) {
final long partitionSize = partitionSizesInMergedFile[i];
if (partitionSize > 0) {
InputStream in = new FileInputStream(mergedOutputFile);
ByteStreams.skipFully(in, startOffset);
in = new LimitedInputStream(in, partitionSize);
FileInputStream fin = new FileInputStream(mergedOutputFile);
fin.getChannel().position(startOffset);
InputStream in = new LimitedInputStream(fin, partitionSize);
in = blockManager.serializerManager().wrapForEncryption(in);
if (conf.getBoolean("spark.shuffle.compress", true)) {
in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
}
@@ -294,14 +286,32 @@ public void writeWithoutSpilling() throws Exception {
}

private void testMergingSpills(
boolean transferToEnabled,
String compressionCodecName) throws IOException {
final boolean transferToEnabled,
String compressionCodecName,
boolean encrypt) throws Exception {
if (compressionCodecName != null) {
conf.set("spark.shuffle.compress", "true");
conf.set("spark.io.compression.codec", compressionCodecName);
} else {
conf.set("spark.shuffle.compress", "false");
}
conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt);

SerializerManager manager;
if (encrypt) {
manager = new SerializerManager(serializer, conf,
Option.apply(CryptoStreamUtils.createKey(conf)));
} else {
manager = new SerializerManager(serializer, conf);
}

when(blockManager.serializerManager()).thenReturn(manager);
testMergingSpills(transferToEnabled, encrypt);
}

private void testMergingSpills(
boolean transferToEnabled,
boolean encrypted) throws IOException {
final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
@@ -324,6 +334,7 @@ private void testMergingSpills(
for (long size: partitionSizesInMergedFile) {
sumOfPartitionSizes += size;
}

assertEquals(sumOfPartitionSizes, mergedOutputFile.length());

assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
@@ -338,42 +349,72 @@ private void testMergingSpills(

@Test
public void mergeSpillsWithTransferToAndLZF() throws Exception {
testMergingSpills(true, LZFCompressionCodec.class.getName());
testMergingSpills(true, LZFCompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithFileStreamAndLZF() throws Exception {
testMergingSpills(false, LZFCompressionCodec.class.getName());
testMergingSpills(false, LZFCompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithTransferToAndLZ4() throws Exception {
testMergingSpills(true, LZ4CompressionCodec.class.getName());
testMergingSpills(true, LZ4CompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
testMergingSpills(false, LZ4CompressionCodec.class.getName());
testMergingSpills(false, LZ4CompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithTransferToAndSnappy() throws Exception {
testMergingSpills(true, SnappyCompressionCodec.class.getName());
testMergingSpills(true, SnappyCompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
testMergingSpills(false, SnappyCompressionCodec.class.getName());
testMergingSpills(false, SnappyCompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
testMergingSpills(true, null);
testMergingSpills(true, null, false);
}

@Test
public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
testMergingSpills(false, null);
testMergingSpills(false, null, false);
}

@Test
public void mergeSpillsWithCompressionAndEncryption() throws Exception {
// Even with transferTo enabled, encryption should force a "file stream merge" internally;
// this test just makes sure that is the case.
testMergingSpills(true, LZ4CompressionCodec.class.getName(), true);
}

@Test
public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception {
testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
}

@Test
public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
[Review comment] We should also test testMergingSpills(false, null, true); and testMergingSpills(true, null, true).

conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false");
testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
}

@Test
public void mergeSpillsWithEncryptionAndNoCompression() throws Exception {
// Even with transferTo enabled, encryption should force a "file stream merge" internally;
// this test just makes sure that is the case.
testMergingSpills(true, null, true);
}

@Test
public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception {
testMergingSpills(false, null, true);
}

@Test
@@ -531,4 +572,5 @@ public void testPeakMemoryUsed() throws Exception {
writer.stop(false);
}
}

}