From 2e03ee6faec0984eeab0ffe2699ec7cb59bf0c43 Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin
Date: Thu, 17 Nov 2016 17:07:03 -0800
Subject: [PATCH 1/5] [SPARK-18546][core] Fix merging shuffle spills when using encryption.

The problem is that encrypted partition data from different spill files cannot simply be
concatenated: each spilled partition has its own initialization vector, while the final
merged file should contain a single initialization vector per merged partition; otherwise
iterating over each record becomes very hard. To fix that, UnsafeShuffleWriter now decrypts
the partitions when merging, so that the merged file contains a single initialization vector
at the start of each partition's data.

Because that cannot be done with the fast transferTo path, UnsafeShuffleWriter falls back to
file streams for merging when encryption is enabled. A hybrid approach may be possible,
reading from files into an intermediate direct buffer and encrypting the data from there,
but that is better left for a separate patch.

As part of the change, DiskBlockObjectWriter now takes a SerializerManager instead of a
"wrap stream" closure, which makes the code easier to test without having to mock
SerializerManager functionality.

Tested with newly added unit tests (UnsafeShuffleWriterSuite for the write side and
ExternalAppendOnlyMapSuite for integration), and by running some apps that failed without
the fix.
---
 .../shuffle/sort/UnsafeShuffleWriter.java | 43 +++++----
 .../spark/serializer/SerializerManager.scala | 10 +--
 .../apache/spark/storage/BlockManager.scala | 5 +-
 .../spark/storage/DiskBlockObjectWriter.scala | 6 +-
 .../sort/UnsafeShuffleWriterSuite.java | 89 +++++++++++++------
 .../map/AbstractBytesToBytesMapSuite.java | 9 +-
 .../sort/UnsafeExternalSorterSuite.java | 19 ++--
 .../BypassMergeSortShuffleWriterSuite.scala | 5 +-
 .../storage/DiskBlockObjectWriterSuite.scala | 54 ++++-------
 .../ExternalAppendOnlyMapSuite.scala | 11 ++-
 10 files changed, 133 insertions(+), 118 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index f235c434be7b1..9f21d777f2a24 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -40,6 +40,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; +import org.apache.commons.io.output.CloseShieldOutputStream; +import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -264,6 +266,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); final boolean fastMergeIsSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + final boolean encryptionEnabled = blockManager.serializerManager().encryptionKey().isDefined(); try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file @@ -289,7 +292,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // Compression is disabled or we are using an
IO compression codec that supports // decompression of concatenated compressed streams, so we can perform a fast spill merge // that doesn't need to interpret the spilled bytes. - if (transferToEnabled) { + if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); } else { @@ -337,7 +340,8 @@ private long[] mergeSpillsWithFileStream( final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new FileInputStream[spills.length]; - OutputStream mergedFileOutputStream = null; + final CountingOutputStream mergedFileOutputStream = new CountingOutputStream( + new FileOutputStream(outputFile)); boolean threwException = true; try { @@ -345,34 +349,29 @@ private long[] mergeSpillsWithFileStream( spillInputStreams[i] = new FileInputStream(spills[i].file); } for (int partition = 0; partition < numPartitions; partition++) { - final long initialFileLength = outputFile.length(); - mergedFileOutputStream = - new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + final long initialFileLength = mergedFileOutputStream.getByteCount(); + OutputStream partitionOutput = new CloseShieldOutputStream(mergedFileOutputStream); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { - mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); } - + partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput); for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = null; - boolean innerThrewException = true; - try { - partitionInputStream = - new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); - if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); - } - ByteStreams.copy(partitionInputStream, mergedFileOutputStream); - innerThrewException = false; - } finally { - Closeables.close(partitionInputStream, innerThrewException); + InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); } + ByteStreams.copy(partitionInputStream, partitionOutput); } } - mergedFileOutputStream.flush(); - mergedFileOutputStream.close(); - partitionLengths[partition] = (outputFile.length() - initialFileLength); + partitionOutput.flush(); + partitionOutput.close(); + partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); } threwException = false; } finally { diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index ef8432ec0834a..eb7f0f3c31908 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -36,7 +36,7 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea 
private[spark] class SerializerManager( defaultSerializer: Serializer, conf: SparkConf, - encryptionKey: Option[Array[Byte]]) { + val encryptionKey: Option[Array[Byte]]) { def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None) @@ -126,7 +126,7 @@ private[spark] class SerializerManager( /** * Wrap an input stream for encryption if shuffle encryption is enabled */ - private[this] def wrapForEncryption(s: InputStream): InputStream = { + def wrapForEncryption(s: InputStream): InputStream = { encryptionKey .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) } .getOrElse(s) @@ -135,7 +135,7 @@ private[spark] class SerializerManager( /** * Wrap an output stream for encryption if shuffle encryption is enabled */ - private[this] def wrapForEncryption(s: OutputStream): OutputStream = { + def wrapForEncryption(s: OutputStream): OutputStream = { encryptionKey .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) } .getOrElse(s) @@ -144,14 +144,14 @@ private[spark] class SerializerManager( /** * Wrap an output stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 982b83324e0fc..04521c9159eac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -62,7 +62,7 @@ private[spark] class BlockManager( executorId: String, rpcEnv: RpcEnv, val master: BlockManagerMaster, - serializerManager: SerializerManager, + val serializerManager: SerializerManager, val conf: SparkConf, memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, @@ -745,9 +745,8 @@ private[spark] class BlockManager( serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { - val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream, + new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize, syncWrites, writeMetrics, blockId) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a499827ae1598..3cb12fca7dccb 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -22,7 +22,7 @@ import java.nio.channels.FileChannel import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging -import org.apache.spark.serializer.{SerializationStream, SerializerInstance} +import org.apache.spark.serializer.{SerializationStream, 
SerializerInstance, SerializerManager} import org.apache.spark.util.Utils /** @@ -37,9 +37,9 @@ import org.apache.spark.util.Utils */ private[spark] class DiskBlockObjectWriter( val file: File, + serializerManager: SerializerManager, serializerInstance: SerializerInstance, bufferSize: Int, - wrapStream: OutputStream => OutputStream, syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. @@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter( initialized = true } - bs = wrapStream(mcs) + bs = serializerManager.wrapStream(blockId, mcs) objOut = serializerInstance.serializeStream(bs) streamOpen = true this diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a96cd82382e2c..cc88c187e65fc 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -19,6 +19,7 @@ import java.io.*; import java.nio.ByteBuffer; +import java.security.PrivilegedExceptionAction; import java.util.*; import scala.Option; @@ -40,9 +41,11 @@ import org.mockito.stubbing.Answer; import org.apache.spark.HashPartitioner; +import org.apache.spark.SecurityManager; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; +import org.apache.spark.deploy.SparkHadoopUtil; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.io.CompressionCodec$; @@ -53,6 +56,7 @@ import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.storage.*; @@ -77,7 +81,6 @@ public class UnsafeShuffleWriterSuite { final LinkedList spillFilesCreated = new LinkedList<>(); SparkConf conf; final Serializer serializer = new KryoSerializer(new SparkConf()); - final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf()); TaskMetrics taskMetrics; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @@ -86,17 +89,6 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; - private final class WrapStream extends AbstractFunction1 { - @Override - public OutputStream apply(OutputStream stream) { - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream); - } else { - return stream; - } - } - } - @After public void tearDown() { Utils.deleteRecursively(tempDir); @@ -121,6 +113,11 @@ public void setUp() throws IOException { memoryManager = new TestMemoryManager(conf); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + // Some tests will override this manager because they change the configuration. This is a + // default for tests that don't need a specific one. 
+ SerializerManager manager = new SerializerManager(serializer, conf); + when(blockManager.serializerManager()).thenReturn(manager); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( any(BlockId.class), @@ -131,12 +128,11 @@ public void setUp() throws IOException { @Override public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { Object[] args = invocationOnMock.getArguments(); - return new DiskBlockObjectWriter( (File) args[1], + blockManager.serializerManager(), (SerializerInstance) args[2], (Integer) args[3], - new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] @@ -201,9 +197,10 @@ private List> readRecordsFromFile() throws IOException { for (int i = 0; i < NUM_PARTITITONS; i++) { final long partitionSize = partitionSizesInMergedFile[i]; if (partitionSize > 0) { - InputStream in = new FileInputStream(mergedOutputFile); - ByteStreams.skipFully(in, startOffset); - in = new LimitedInputStream(in, partitionSize); + FileInputStream fin = new FileInputStream(mergedOutputFile); + fin.getChannel().position(startOffset); + InputStream in = new LimitedInputStream(fin, partitionSize); + in = blockManager.serializerManager().wrapForEncryption(in); if (conf.getBoolean("spark.shuffle.compress", true)) { in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); } @@ -294,14 +291,32 @@ public void writeWithoutSpilling() throws Exception { } private void testMergingSpills( - boolean transferToEnabled, - String compressionCodecName) throws IOException { + final boolean transferToEnabled, + String compressionCodecName, + boolean encrypt) throws Exception { if (compressionCodecName != null) { conf.set("spark.shuffle.compress", "true"); conf.set("spark.io.compression.codec", compressionCodecName); } else { conf.set("spark.shuffle.compress", "false"); } + conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt); + + SerializerManager manager; + if (encrypt) { + manager = new SerializerManager(serializer, conf, + Option.apply(CryptoStreamUtils.createKey(conf))); + } else { + manager = new SerializerManager(serializer, conf); + } + + when(blockManager.serializerManager()).thenReturn(manager); + testMergingSpills(transferToEnabled, encrypt); + } + + private void testMergingSpills( + boolean transferToEnabled, + boolean encrypted) throws IOException { final UnsafeShuffleWriter writer = createWriter(transferToEnabled); final ArrayList> dataToWrite = new ArrayList<>(); for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { @@ -324,6 +339,7 @@ private void testMergingSpills( for (long size: partitionSizesInMergedFile) { sumOfPartitionSizes += size; } + assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); @@ -338,42 +354,60 @@ private void testMergingSpills( @Test public void mergeSpillsWithTransferToAndLZF() throws Exception { - testMergingSpills(true, LZFCompressionCodec.class.getName()); + testMergingSpills(true, LZFCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithFileStreamAndLZF() throws Exception { - testMergingSpills(false, LZFCompressionCodec.class.getName()); + testMergingSpills(false, LZFCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithTransferToAndLZ4() throws Exception { - testMergingSpills(true, LZ4CompressionCodec.class.getName()); + testMergingSpills(true, LZ4CompressionCodec.class.getName(), 
false); } @Test public void mergeSpillsWithFileStreamAndLZ4() throws Exception { - testMergingSpills(false, LZ4CompressionCodec.class.getName()); + testMergingSpills(false, LZ4CompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithTransferToAndSnappy() throws Exception { - testMergingSpills(true, SnappyCompressionCodec.class.getName()); + testMergingSpills(true, SnappyCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithFileStreamAndSnappy() throws Exception { - testMergingSpills(false, SnappyCompressionCodec.class.getName()); + testMergingSpills(false, SnappyCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithTransferToAndNoCompression() throws Exception { - testMergingSpills(true, null); + testMergingSpills(true, null, false); } @Test public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { - testMergingSpills(false, null); + testMergingSpills(false, null, false); + } + + @Test + public void mergeSpillsWithCompressionAndEncryption() throws Exception { + // This should actually be translated to a "file stream merge" internally, just have the + // test to make sure that it's the case. + testMergingSpills(true, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception { + testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception { + conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false"); + testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); } @Test @@ -531,4 +565,5 @@ public void testPeakMemoryUsed() throws Exception { writer.stop(false); } } + } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 33709b454c4c9..45d058a4de39d 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -75,13 +75,6 @@ public abstract class AbstractBytesToBytesMapSuite { @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; - private static final class WrapStream extends AbstractFunction1 { - @Override - public OutputStream apply(OutputStream stream) { - return stream; - } - } - @Before public void setup() { memoryManager = @@ -120,9 +113,9 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th return new DiskBlockObjectWriter( (File) args[1], + serializerManager, (SerializerInstance) args[2], (Integer) args[3], - new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a9cf8ff520ed4..a9568763e6dff 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -57,13 +57,15 @@ public class UnsafeExternalSorterSuite { + private final SparkConf conf = new SparkConf(); + final LinkedList spillFilesCreated = new LinkedList<>(); final TestMemoryManager memoryManager = - new 
TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); + new TestMemoryManager(conf.clone().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); final SerializerManager serializerManager = new SerializerManager( - new JavaSerializer(new SparkConf()), - new SparkConf().set("spark.shuffle.spill.compress", "false")); + new JavaSerializer(conf), + conf.clone().set("spark.shuffle.spill.compress", "false")); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = PrefixComparators.LONG; // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so @@ -86,14 +88,7 @@ public int compare( protected boolean shouldUseRadixSort() { return false; } - private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m"); - - private static final class WrapStream extends AbstractFunction1 { - @Override - public OutputStream apply(OutputStream stream) { - return stream; - } - } + private final long pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "4m"); @Before public void setUp() { @@ -126,9 +121,9 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th return new DiskBlockObjectWriter( (File) args[1], + serializerManager, (SerializerInstance) args[2], (Integer) args[3], - new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 442941685f1ae..85ccb33471048 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} -import org.apache.spark.serializer.{JavaSerializer, SerializerInstance} +import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -90,11 +90,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte )).thenAnswer(new Answer[DiskBlockObjectWriter] { override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments + val manager = new SerializerManager(new JavaSerializer(conf), conf) new DiskBlockObjectWriter( args(1).asInstanceOf[File], + manager, args(2).asInstanceOf[SerializerInstance], args(3).asInstanceOf[Int], - wrapStream = identity, syncWrites = false, args(4).asInstanceOf[ShuffleWriteMetrics], blockId = args(0).asInstanceOf[BlockId] diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 684e978d11864..bfb3ac4c15bca 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics -import 
org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.util.Utils class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -42,11 +42,19 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("verify write metrics") { + private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = { val file = new File(tempDir, "somefile") + val conf = new SparkConf() + val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true, + writeMetrics) + (writer, file, writeMetrics) + } + + test("verify write metrics") { + val (writer, file, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -66,10 +74,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("verify write metrics on revert") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, _, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -89,10 +94,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("Reopening a closed block writer") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, _, _) = createWriter() writer.open() writer.close() @@ -102,10 +104,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("calling revertPartialWritesAndClose() on a partial write should truncate up to commit") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) val firstSegment = writer.commitAndGet() @@ -120,10 +119,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("calling revertPartialWritesAndClose() after commit() should have no effect") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) val firstSegment = writer.commitAndGet() @@ -136,10 +132,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, 
writeMetrics) + val (writer, file, writeMetrics) = createWriter() for (i <- 1 to 1000) { writer.write(i, i) } @@ -153,10 +146,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("commit() and close() should be idempotent") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() for (i <- 1 to 1000) { writer.write(i, i) } @@ -173,10 +163,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("revertPartialWritesAndClose() should be idempotent") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() for (i <- 1 to 1000) { writer.write(i, i) } @@ -191,10 +178,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("commit() and close() without ever opening or writing") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, _, _) = createWriter() val segment = writer.commitAndGet() writer.close() assert(segment.length === 0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 5141e36d9e38d..e30cc3cced1f6 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -17,9 +17,13 @@ package org.apache.spark.util.collection +import java.security.PrivilegedExceptionAction + import scala.collection.mutable.ArrayBuffer import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.MemoryTestingUtils @@ -230,14 +234,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { } } + test("spilling with compression and encryption") { + testSimpleSpilling(Some(CompressionCodec.DEFAULT_COMPRESSION_CODEC), encrypt = true) + } + /** * Test spilling through simple aggregations and cogroups. * If a compression codec is provided, use it. Otherwise, do not compress spills. */ - private def testSimpleSpilling(codec: Option[String] = None): Unit = { + private def testSimpleSpilling(codec: Option[String] = None, encrypt: Boolean = false): Unit = { val size = 1000 val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + conf.set(IO_ENCRYPTION_ENABLED, encrypt) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) assertSpilled(sc, "reduceByKey") { From e5027904c5a85bbdf4909a49e97a74cf6fb2fc5d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 29 Nov 2016 12:52:04 -0800 Subject: [PATCH 2/5] Review feedback. 
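
Note for reviewers: below is a condensed sketch of the per-partition stream layering that
patches 1 and 2 arrive at in mergeSpillsWithFileStream(). It is illustrative only; the class
and method names (MergeLayeringSketch, copyOnePartition) are invented for the example, and the
real method loops over all partitions and spill files and also wraps the output in
TimeTrackingOutputStream for write metrics.

    import java.io.FileInputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.OutputStream;

    import com.google.common.io.ByteStreams;
    import org.apache.commons.io.output.CloseShieldOutputStream;
    import org.apache.commons.io.output.CountingOutputStream;

    import org.apache.spark.io.CompressionCodec;
    import org.apache.spark.network.util.LimitedInputStream;
    import org.apache.spark.serializer.SerializerManager;

    class MergeLayeringSketch {

      /** Copies one partition of one spill into the merged file; returns bytes written. */
      static long copyOnePartition(
          SerializerManager serializerManager,
          CompressionCodec codec,            // null when shuffle compression is disabled
          CountingOutputStream mergedOut,    // wraps the FileOutputStream of the merged file
          FileInputStream spillIn,           // spill stream, positioned at this partition's data
          long partitionLengthInSpill) throws IOException {
        final long start = mergedOut.getByteCount();

        // Write side: shield the shared file stream from close(), then encrypt, then compress.
        OutputStream partitionOut = new CloseShieldOutputStream(mergedOut);
        partitionOut = serializerManager.wrapForEncryption(partitionOut);
        if (codec != null) {
          partitionOut = codec.compressedOutputStream(partitionOut);
        }

        // Read side: bound the spill stream to this partition, then decrypt, then decompress.
        InputStream partitionIn = new LimitedInputStream(spillIn, partitionLengthInSpill, false);
        try {
          partitionIn = serializerManager.wrapForEncryption(partitionIn);
          if (codec != null) {
            partitionIn = codec.compressedInputStream(partitionIn);
          }
          ByteStreams.copy(partitionIn, partitionOut);
        } finally {
          partitionIn.close();
        }

        // Closing the partition-level stream flushes the codec/cipher trailers without closing
        // the underlying merged file, so each merged partition ends up with a single IV.
        partitionOut.close();
        return mergedOut.getByteCount() - start;
      }
    }

The key point is that the partition-level output stream is closed while the shared merged-file
stream stays open, and partition lengths are taken from the CountingOutputStream instead of
re-reading the file length after every partition.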
--- .../shuffle/sort/UnsafeShuffleWriter.java | 19 ++++++++++++++----- .../spark/serializer/SerializerManager.scala | 4 ++-- .../sort/UnsafeShuffleWriterSuite.java | 12 ++++++++++++ .../ExternalAppendOnlyMapSuite.scala | 3 --- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9f21d777f2a24..2f5da7435ddee 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -340,6 +340,9 @@ private long[] mergeSpillsWithFileStream( final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new FileInputStream[spills.length]; + + // Use a counting output stream to avoid having to close the underlying file and ask + // the file system for its size after each partition is written. final CountingOutputStream mergedFileOutputStream = new CountingOutputStream( new FileOutputStream(outputFile)); @@ -350,6 +353,8 @@ private long[] mergeSpillsWithFileStream( } for (int partition = 0; partition < numPartitions; partition++) { final long initialFileLength = mergedFileOutputStream.getByteCount(); + // Shield the underlying output stream from close() calls, so that we can close the higher + // level streams to make sure all data is really flushed and internal state is cleaned. OutputStream partitionOutput = new CloseShieldOutputStream(mergedFileOutputStream); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { @@ -361,12 +366,16 @@ private long[] mergeSpillsWithFileStream( if (partitionLengthInSpill > 0) { InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); - partitionInputStream = blockManager.serializerManager().wrapForEncryption( - partitionInputStream); - if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + try { + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + ByteStreams.copy(partitionInputStream, partitionOutput); + } finally { + partitionInputStream.close(); } - ByteStreams.copy(partitionInputStream, partitionOutput); } } partitionOutput.flush(); diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index eb7f0f3c31908..8fdb82f1feb7d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -144,14 +144,14 @@ private[spark] class SerializerManager( /** * Wrap an output stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + 
private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index cc88c187e65fc..e53c5d0a12fbe 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -410,6 +410,18 @@ public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception { testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); } + @Test + public void mergeSpillsWithEncryptionAndNoCompression() throws Exception { + // This should actually be translated to a "file stream merge" internally, just have the + // test to make sure that it's the case. + testMergingSpills(true, null, true); + } + + @Test + public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception { + testMergingSpills(false, null, true); + } + @Test public void writeEnoughDataToTriggerSpill() throws Exception { memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES); diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index e30cc3cced1f6..7f0838268a111 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -17,12 +17,9 @@ package org.apache.spark.util.collection -import java.security.PrivilegedExceptionAction - import scala.collection.mutable.ArrayBuffer import org.apache.spark._ -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.MemoryTestingUtils From 8ac927623c5d7809208b766001f46ea2ad576af9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 29 Nov 2016 12:54:13 -0800 Subject: [PATCH 3/5] Add explicit "encryptionEnabled" method. 
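
Note: with this change, mergeSpills() picks its strategy from the fast-merge flags plus
serializerManager.encryptionEnabled(). The sketch below is a hypothetical helper written only
to illustrate that decision; the MergeStrategy enum and the choose() method are not part of
the patch, which calls mergeSpillsWithTransferTo / mergeSpillsWithFileStream directly.

    class MergeStrategySketch {
      /** Stands in for calls to mergeSpillsWithTransferTo / mergeSpillsWithFileStream. */
      enum MergeStrategy { TRANSFER_TO_FAST, FILE_STREAM_FAST, FILE_STREAM_SLOW }

      static MergeStrategy choose(
          boolean fastMergeEnabled,      // spark.shuffle.unsafe.fastMergeEnabled
          boolean fastMergeIsSupported,  // compression off, or codec supports concatenated streams
          boolean transferToEnabled,     // NIO transferTo merging allowed by configuration
          boolean encryptionEnabled) {   // serializerManager.encryptionEnabled()
        if (fastMergeEnabled && fastMergeIsSupported) {
          // Encrypted partition data cannot be concatenated byte-for-byte, so the transferTo
          // path is only taken when encryption is off; otherwise fall back to file streams.
          return (transferToEnabled && !encryptionEnabled)
              ? MergeStrategy.TRANSFER_TO_FAST
              : MergeStrategy.FILE_STREAM_FAST;
        }
        return MergeStrategy.FILE_STREAM_SLOW;
      }
    }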
--- .../org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java | 2 +- .../scala/org/apache/spark/serializer/SerializerManager.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 2f5da7435ddee..20de1430dff7a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -266,7 +266,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); final boolean fastMergeIsSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); - final boolean encryptionEnabled = blockManager.serializerManager().encryptionKey().isDefined(); + final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 8fdb82f1feb7d..185d7c18623b0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -36,7 +36,7 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea private[spark] class SerializerManager( defaultSerializer: Serializer, conf: SparkConf, - val encryptionKey: Option[Array[Byte]]) { + encryptionKey: Option[Array[Byte]]) { def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None) @@ -75,6 +75,8 @@ private[spark] class SerializerManager( * loaded yet. */ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) + def encryptionEnabled: Boolean = encryptionKey.isDefined + def canUseKryo(ct: ClassTag[_]): Boolean = { primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag } From 1025c6bb384968a7fc474d35a1bb18d82eb21938 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 29 Nov 2016 14:54:29 -0800 Subject: [PATCH 4/5] More feedback. --- .../org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java | 4 ++-- .../apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 5 ----- .../spark/unsafe/map/AbstractBytesToBytesMapSuite.java | 2 -- .../collection/unsafe/sort/UnsafeExternalSorterSuite.java | 2 -- 4 files changed, 2 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 20de1430dff7a..9a7d0c07ee3bb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -355,12 +355,12 @@ private long[] mergeSpillsWithFileStream( final long initialFileLength = mergedFileOutputStream.getByteCount(); // Shield the underlying output stream from close() calls, so that we can close the higher // level streams to make sure all data is really flushed and internal state is cleaned. 
- OutputStream partitionOutput = new CloseShieldOutputStream(mergedFileOutputStream); + OutputStream partitionOutput = new CloseShieldOutputStream( + new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); } - partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput); for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; if (partitionLengthInSpill > 0) { diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index e53c5d0a12fbe..088b68132d905 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -19,7 +19,6 @@ import java.io.*; import java.nio.ByteBuffer; -import java.security.PrivilegedExceptionAction; import java.util.*; import scala.Option; @@ -27,11 +26,9 @@ import scala.Tuple2; import scala.Tuple2$; import scala.collection.Iterator; -import scala.runtime.AbstractFunction1; import com.google.common.collect.HashMultiset; import com.google.common.collect.Iterators; -import com.google.common.io.ByteStreams; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -41,11 +38,9 @@ import org.mockito.stubbing.Answer; import org.apache.spark.HashPartitioner; -import org.apache.spark.SecurityManager; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; -import org.apache.spark.deploy.SparkHadoopUtil; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.io.CompressionCodec$; diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 45d058a4de39d..26568146bf4d7 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -19,13 +19,11 @@ import java.io.File; import java.io.IOException; -import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.*; import scala.Tuple2; import scala.Tuple2$; -import scala.runtime.AbstractFunction1; import org.junit.After; import org.junit.Assert; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a9568763e6dff..fbbe530a132e1 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -19,14 +19,12 @@ import java.io.File; import java.io.IOException; -import java.io.OutputStream; import java.util.Arrays; import java.util.LinkedList; import java.util.UUID; import scala.Tuple2; import scala.Tuple2$; -import scala.runtime.AbstractFunction1; import org.junit.After; import org.junit.Before; From 49737d9c574694abb8ea438bd2bb84ca65364259 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 30 Nov 2016 10:05:00 -0800 Subject: [PATCH 
5/5] Update comment. --- .../org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9a7d0c07ee3bb..8a1771848dee6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -323,9 +323,9 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti /** * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in - * cases where the IO compression codec does not support concatenation of compressed data, or in - * cases where users have explicitly disabled use of {@code transferTo} in order to work around - * kernel bugs. + * cases where the IO compression codec does not support concatenation of compressed data, when + * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in + * order to work around kernel bugs. * * @param spills the spills to merge. * @param outputFile the file to write the merged data to.
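
Note: for completeness, a test-style sketch of the two pieces the updated suites wire
together after this series: a SerializerManager that optionally carries an IO encryption key
(created via CryptoStreamUtils.createKey, as UnsafeShuffleWriterSuite does), and a
DiskBlockObjectWriter built from that manager instead of a "wrap stream" closure. This is a
condensed illustration, not code from the patch; the class name, method name, and buffer size
are arbitrary, and the real suites also set spark.io.encryption.enabled on the conf.

    import java.io.File;

    import scala.Option;

    import org.apache.spark.SparkConf;
    import org.apache.spark.executor.ShuffleWriteMetrics;
    import org.apache.spark.security.CryptoStreamUtils;
    import org.apache.spark.serializer.KryoSerializer;
    import org.apache.spark.serializer.Serializer;
    import org.apache.spark.serializer.SerializerManager;
    import org.apache.spark.storage.BlockId;
    import org.apache.spark.storage.DiskBlockObjectWriter;

    class EncryptedWriterSetupSketch {

      /** Builds a disk writer the way the updated suites do. */
      static DiskBlockObjectWriter newWriter(
          SparkConf conf,
          boolean encrypt,
          File file,
          BlockId blockId,
          ShuffleWriteMetrics writeMetrics) {
        Serializer serializer = new KryoSerializer(conf);
        SerializerManager manager = encrypt
            ? new SerializerManager(serializer, conf, Option.apply(CryptoStreamUtils.createKey(conf)))
            : new SerializerManager(serializer, conf);
        // Constructor order after this series:
        // (file, serializerManager, serializerInstance, bufferSize, syncWrites, writeMetrics, blockId)
        return new DiskBlockObjectWriter(
            file,
            manager,
            serializer.newInstance(),
            32 * 1024,   // buffer size
            false,       // syncWrites
            writeMetrics,
            blockId);
      }
    }

Because the writer consults the SerializerManager directly, tests exercise the same
encryption wrapping as production code instead of stubbing a wrap-stream function.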