From 273487019e2c21dc354270d5cf87260424a3cfb0 Mon Sep 17 00:00:00 2001
From: kellyzly
Date: Tue, 5 May 2015 15:57:17 +0800
Subject: [PATCH 1/9] [SPARK-5682] Add spark encrypted shuffle in spark

Fix check style issues
Clean code and fix code style issues
Fix import issues
Clean code and fix code style issues
Fix failed unit tests and code style issues
Add more unit tests for JCE codec
Update patch addressing review comments
Update configuration.md file and clean up code for crypto input/output stream
Remove some unused methods
Fix the confusing configuration for key size and buffer size and some other minor issues
Add more unit tests for CryptoCodec
Fix code style issues
Do some code cleanup
Update patch by using Chimera library
Update patch since some changes in Chimera
Fix empty space issue
Simplify the usage of shuffle encryption related configurations and use the Chimera configuration directly
Log error when enabling shuffle encryption in non-YARN mode
Add shuffle encryption writer test
Fix some code style issues
Remove unnecessary unit tests
Add more Scala doc
Throw exception when enabling shuffle file encryption in non-YARN mode
Bump the Chimera version to 0.9.0
Remove code unsupported in Hadoop 2.3
Clean up test code
Fix failed test case
Fix code style issue
Fix code style issues
Update code addressing comments
Mock the SparkEnv instead of instantiating a SparkContext in YarnShuffleEncryptionSuite
Remove initialCredentials method in YarnShuffleEncryptionSuite
Initialize the credentials only once in YarnShuffleEncryptionSuite
Use renaming when importing Java collection classes
Update patch addressing new nits
Add Chimera to dep files
Fix import violations
Fix dep issues
Fix compile error after rebasing
Bump up Chimera version to 0.9.1
Update patch since changes in TaskContext
Bump up version to 0.9.2
Update Spark Hadoop deps files
Add default value for transformation and remove "test" word from unit test descriptions
Address a minor comment
Remove Chimera-related content from the documentation and add a way to customize the Chimera library in Spark
A fix for external sorter
Use Commons-crypto library
---
 core/pom.xml                                       |   4 +
 .../unsafe/sort/UnsafeSorterSpillReader.java       |   3 +-
 .../scala/org/apache/spark/SparkContext.scala      |   5 +
 .../org/apache/spark/crypto/CryptoConf.scala       |  71 ++++
 .../spark/crypto/CryptoStreamUtils.scala           | 104 +++++
 .../spark/serializer/SerializerManager.scala       |  23 +-
 .../shuffle/BlockStoreShuffleReader.scala          |   5 +-
 .../apache/spark/storage/BlockManager.scala        |   3 +-
 .../spark/storage/DiskBlockObjectWriter.scala      |   5 +-
 .../spark/storage/memory/MemoryStore.scala         |   4 +-
 .../collection/ExternalAppendOnlyMap.scala         |   3 +-
 .../util/collection/ExternalSorter.scala           |   6 +-
 .../sort/UnsafeShuffleWriterSuite.java             |   8 +
 .../map/AbstractBytesToBytesMapSuite.java          |   8 +
 .../sort/UnsafeExternalSorterSuite.java            |   8 +
 .../spark/crypto/ShuffleEncryptionSuite.scala      | 109 ++++++
 .../BypassMergeSortShuffleWriterSuite.scala        |   1 +
 .../storage/DiskBlockObjectWriterSuite.scala       |  27 +-
 dev/deps/spark-deps-hadoop-2.2                     |   2 +
 dev/deps/spark-deps-hadoop-2.3                     |   2 +
 dev/deps/spark-deps-hadoop-2.4                     |   2 +
 dev/deps/spark-deps-hadoop-2.6                     |   2 +
 dev/deps/spark-deps-hadoop-2.7                     |   2 +
 docs/configuration.md                              |  33 ++
 pom.xml                                            |   6 +
 .../org/apache/spark/deploy/yarn/Client.scala      |   5 +
 .../yarn/YarnShuffleEncryptionSuite.scala          | 355 ++++++++++++++++++
 27 files changed, 787 insertions(+), 19 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/crypto/CryptoConf.scala
 create mode 100644 core/src/main/scala/org/apache/spark/crypto/CryptoStreamUtils.scala
 create mode 100644 core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala
 create mode 100644 yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleEncryptionSuite.scala

diff --git a/core/pom.xml b/core/pom.xml
index ab6c3ce80527..29c4902805db 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -332,6 +332,10 @@
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-tags_${scala.binary.version}</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-crypto</artifactId>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index d048cf7aeb5f..9d09d41a4be2 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -72,7 +72,8 @@ public UnsafeSorterSpillReader(
     final BufferedInputStream bs =
         new BufferedInputStream(new FileInputStream(file), (int) bufferSizeBytes);
     try {
-      this.in = serializerManager.wrapForCompression(blockId, bs);
+      final InputStream eis = serializerManager.wrapForEncryption(bs);
+      this.in = serializerManager.wrapForCompression(blockId, eis);
       this.din = new DataInputStream(this.in);
       numRecords = numRecordsRemaining = din.readInt();
     } catch (IOException e) {
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 2eaeab1d807b..9bac15df99f8 100644
---
a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -46,6 +46,7 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast +import org.apache.spark.crypto.CryptoConf import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} @@ -413,6 +414,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") + if (CryptoConf.isShuffleEncryptionEnabled(_conf) && !SparkHadoopUtil.get.isYarnMode()) { + throw new SparkException("Shuffle file encryption is only supported in Yarn mode, please " + + "disable it by setting spark.shuffle.encryption.enabled to false") + } // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. diff --git a/core/src/main/scala/org/apache/spark/crypto/CryptoConf.scala b/core/src/main/scala/org/apache/spark/crypto/CryptoConf.scala new file mode 100644 index 000000000000..3a9dc22a41ee --- /dev/null +++ b/core/src/main/scala/org/apache/spark/crypto/CryptoConf.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.crypto + +import javax.crypto.KeyGenerator + +import org.apache.hadoop.io.Text +import org.apache.hadoop.security.Credentials + +import org.apache.spark.SparkConf + +/** + * CryptoConf is a class for Crypto configuration + */ +private[spark] object CryptoConf { + /** + * Constants and variables for spark shuffle file encryption + */ + val SPARK_SHUFFLE_TOKEN = new Text("SPARK_SHUFFLE_TOKEN") + val SPARK_SHUFFLE_ENCRYPTION_ENABLED = "spark.shuffle.encryption.enabled" + val SPARK_SHUFFLE_ENCRYPTION_KEYGEN_ALGORITHM = "spark.shuffle.encryption.keygen.algorithm" + val DEFAULT_SPARK_SHUFFLE_ENCRYPTION_KEYGEN_ALGORITHM = "HmacSHA1" + val SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS = "spark.shuffle.encryption.keySizeBits" + val DEFAULT_SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS = 128 + + /** + * Check whether shuffle file encryption is enabled. It is disabled by default. + */ + def isShuffleEncryptionEnabled(sparkConf: SparkConf): Boolean = { + if (sparkConf != null) { + sparkConf.getBoolean(SPARK_SHUFFLE_ENCRYPTION_ENABLED, false) + } else { + false + } + } + + /** + * Setup the cryptographic key used by file shuffle encryption in credentials. The key is + * generated using [[KeyGenerator]]. The algorithm and key length is specified by the + * [[SparkConf]]. 
+ */ + def initSparkShuffleCredentials(conf: SparkConf, credentials: Credentials): Unit = { + if (credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) == null) { + val keyLen = conf.getInt(SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS, + DEFAULT_SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS) + require(keyLen == 128 || keyLen == 192 || keyLen == 256) + val shuffleKeyGenAlgorithm = conf.get(SPARK_SHUFFLE_ENCRYPTION_KEYGEN_ALGORITHM, + DEFAULT_SPARK_SHUFFLE_ENCRYPTION_KEYGEN_ALGORITHM) + val keyGen = KeyGenerator.getInstance(shuffleKeyGenAlgorithm) + keyGen.init(keyLen) + + val shuffleKey = keyGen.generateKey() + credentials.addSecretKey(SPARK_SHUFFLE_TOKEN, shuffleKey.getEncoded) + } + } +} + diff --git a/core/src/main/scala/org/apache/spark/crypto/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/crypto/CryptoStreamUtils.scala new file mode 100644 index 000000000000..0a79485b8b04 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/crypto/CryptoStreamUtils.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.crypto + +import java.io.{InputStream, OutputStream} +import java.util.Properties +import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} + +import org.apache.commons.crypto.random._ +import org.apache.commons.crypto.stream._ + +import org.apache.spark.SparkConf +import org.apache.spark.crypto.CryptoConf._ +import org.apache.spark.deploy.SparkHadoopUtil + +/** + * A util class for manipulating file shuffle encryption and decryption streams. + */ +private[spark] object CryptoStreamUtils { + // The initialization vector length in bytes. + val IV_LENGTH_IN_BYTES = 16 + // The prefix of Crypto related configurations in Spark configuration. + val SPARK_COMMONS_CRYPTO_CONF_PREFIX = "spark.commons.crypto." + // The prefix for the configurations passing to Commons-crypto library. + val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto." + + + /** + * Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption. + */ + def createCryptoOutputStream( + os: OutputStream, + sparkConf: SparkConf): OutputStream = { + val properties = toCryptoConf(sparkConf, SPARK_COMMONS_CRYPTO_CONF_PREFIX, + COMMONS_CRYPTO_CONF_PREFIX) + val iv: Array[Byte] = createInitializationVector(properties) + os.write(iv) + val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() + val key = credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) + val transformationStr = sparkConf.get( + "spark.shuffle.crypto.cipher.transformation", "AES/CTR/NoPadding") + new CryptoOutputStream(transformationStr, properties, os, + new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + } + + /** + * Helper method to wrap [[InputStream]] with [[CryptoInputStream]] for decryption. 
+ */ + def createCryptoInputStream( + is: InputStream, + sparkConf: SparkConf): InputStream = { + val properties = toCryptoConf(sparkConf, SPARK_COMMONS_CRYPTO_CONF_PREFIX, + COMMONS_CRYPTO_CONF_PREFIX) + val iv = new Array[Byte](IV_LENGTH_IN_BYTES) + is.read(iv, 0, iv.length) + val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() + val key = credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) + val transformationStr = sparkConf.get( + "spark.shuffle.crypto.cipher.transformation", "AES/CTR/NoPadding") + new CryptoInputStream(transformationStr, properties, is, + new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + } + + /** + * Get Commons-crypto configurations from Spark configurations identified + * by prefix. + */ + def toCryptoConf( + conf: SparkConf, + sparkPrefix: String, + cryptoPrefix: String): Properties = { + val props = new Properties() + conf.getAll.foreach { case (k, v) => + if (k.startsWith(sparkPrefix)) { + props.put(COMMONS_CRYPTO_CONF_PREFIX + k.substring( + SPARK_COMMONS_CRYPTO_CONF_PREFIX.length()), v) + } + } + props + } + + /** + * This method to generate an IV (Initialization Vector) using secure random. + */ + private[this] def createInitializationVector(properties: Properties): Array[Byte] = { + val iv = new Array[Byte](IV_LENGTH_IN_BYTES) + CryptoRandomFactory.getCryptoRandom(properties).nextBytes(iv) + iv + } +} diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 07caadbe4043..be89e0cb3bbe 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -23,6 +23,7 @@ import java.nio.ByteBuffer import scala.reflect.ClassTag import org.apache.spark.SparkConf +import org.apache.spark.crypto.{CryptoConf, CryptoStreamUtils} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage._ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -61,6 +62,9 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar // Whether to compress shuffle output temporarily spilled to disk private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) + // Whether to encrypt shuffle file encryption + private[this] val enableShuffleFileEncryption = CryptoConf.isShuffleEncryptionEnabled(conf) + /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. 
The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -102,6 +106,20 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar } } + /** + * Wrap an input stream for encryption if shuffle encryption is enabled + */ + def wrapForEncryption(s: InputStream): InputStream = { + if (enableShuffleFileEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s + } + + /** + * Wrap an output stream for encryption if shuffle encryption is enabled + */ + def wrapForEncryption(s: OutputStream): OutputStream = { + if (enableShuffleFileEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s + } + /** * Wrap an output stream for compression if block compression is enabled for its block type */ @@ -123,7 +141,8 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar values: Iterator[T]): Unit = { val byteStream = new BufferedOutputStream(outputStream) val ser = getSerializer(implicitly[ClassTag[T]]).newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + ser.serializeStream(wrapForEncryption(wrapForCompression(blockId, byteStream))).writeAll( + values).close() } /** Serializes into a chunked byte buffer. */ @@ -153,7 +172,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar val stream = new BufferedInputStream(inputStream) getSerializer(implicitly[ClassTag[T]]) .newInstance() - .deserializeStream(wrapForCompression(blockId, stream)) + .deserializeStream(wrapForCompression(blockId, wrapForEncryption(stream))) .asIterator.asInstanceOf[Iterator[T]] } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 5794f542b756..008599bc63e9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -51,9 +51,10 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)) - // Wrap the streams for compression based on configuration + // Wrap the streams for compression and encryption based on configuration val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => - serializerManager.wrapForCompression(blockId, inputStream) + val eis = serializerManager.wrapForEncryption(inputStream) + serializerManager.wrapForCompression(blockId, eis) } val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fe8465279860..39caa9cf8ed7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -723,8 +723,9 @@ private[spark] class BlockManager( writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { val compressStream: OutputStream => OutputStream = serializerManager.wrapForCompression(blockId, _) + val encryptStream: OutputStream => OutputStream = serializerManager.wrapForEncryption(_) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream, + new DiskBlockObjectWriter(file, serializerInstance, 
bufferSize, encryptStream, compressStream, syncWrites, writeMetrics, blockId) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index e5b1bf2f4b43..a621fddbcb7c 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -40,6 +40,7 @@ private[spark] class DiskBlockObjectWriter( serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, + encryptStream: OutputStream => OutputStream, syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. @@ -115,7 +116,9 @@ private[spark] class DiskBlockObjectWriter( initialize() initialized = true } - bs = compressStream(mcs) + + bs = encryptStream(mcs) + bs = compressStream(bs) objOut = serializerInstance.serializeStream(bs) streamOpen = true this diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 586339a58d23..7cab54e07db2 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -330,7 +330,9 @@ private[spark] class MemoryStore( redirectableStream.setOutputStream(bbos) val serializationStream: SerializationStream = { val ser = serializerManager.getSerializer(classTag).newInstance() - ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) + ser.serializeStream( + serializerManager.wrapForEncryption(serializerManager.wrapForCompression(blockId, + redirectableStream))) } // Request enough memory to begin unrolling diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8c8860bb37a4..ca3eac47241c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -486,7 +486,8 @@ class ExternalAppendOnlyMap[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = serializerManager.wrapForCompression(blockId, bufferedStream) + val encryptedStream = serializerManager.wrapForEncryption(bufferedStream) + val compressedStream = serializerManager.wrapForCompression(blockId, encryptedStream) ser.deserializeStream(compressedStream) } else { // No more batches left diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 7c98e8cabb22..10154e33204c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams import org.apache.spark._ +import org.apache.spark.crypto.{CryptoConf, CryptoStreamUtils} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager @@ -522,7 +523,10 @@ 
private[spark] class ExternalSorter[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = serializerManager.wrapForCompression(spill.blockId, bufferedStream) + + val encryptedStream = SparkEnv.get.serializerManager.wrapForEncryption(bufferedStream) + val compressedStream = SparkEnv.get.serializerManager.wrapForCompression(spill.blockId, + encryptedStream) serInstance.deserializeStream(compressedStream) } else { // No more batches left diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index daeb4675ea5f..e68086a02a20 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -97,6 +97,13 @@ public OutputStream apply(OutputStream stream) { } } + private static final class EncryptStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + @After public void tearDown() { Utils.deleteRecursively(tempDir); @@ -137,6 +144,7 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (SerializerInstance) args[2], (Integer) args[3], new CompressStream(), + new EncryptStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index fc127f07c8d6..43576adf5348 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -82,6 +82,13 @@ public OutputStream apply(OutputStream stream) { } } + private static final class EncryptStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + @Before public void setup() { memoryManager = @@ -123,6 +130,7 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (SerializerInstance) args[2], (Integer) args[3], new CompressStream(), + new EncryptStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 3ea99233fe17..e50d2c1d58f4 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -95,6 +95,13 @@ public OutputStream apply(OutputStream stream) { } } + private static final class EncryptStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + @Before public void setUp() { MockitoAnnotations.initMocks(this); @@ -129,6 +136,7 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (SerializerInstance) args[2], (Integer) args[3], new CompressStream(), + new EncryptStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala 
b/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala new file mode 100644 index 000000000000..bf16428088a7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.crypto + +import java.security.PrivilegedExceptionAction + +import org.apache.hadoop.security.{Credentials, UserGroupInformation} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.crypto.CryptoConf._ +import org.apache.spark.crypto.CryptoStreamUtils.{COMMONS_CRYPTO_CONF_PREFIX, SPARK_COMMONS_CRYPTO_CONF_PREFIX} + + +private[spark] class ShuffleEncryptionSuite extends SparkFunSuite { + val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup")) + + test("Crypto configuration conversion") { + val sparkKey1 = s"${SPARK_COMMONS_CRYPTO_CONF_PREFIX}a.b.c" + val sparkVal1 = "val1" + val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c" + + val sparkKey2 = SPARK_COMMONS_CRYPTO_CONF_PREFIX.stripSuffix(".") + "A.b.c" + val sparkVal2 = "val2" + val cryptoKey2 = s"${COMMONS_CRYPTO_CONF_PREFIX}A.b.c" + val conf = new SparkConf() + conf.set(sparkKey1, sparkVal1) + conf.set(sparkKey2, sparkVal2) + val props = CryptoStreamUtils.toCryptoConf(conf, SPARK_COMMONS_CRYPTO_CONF_PREFIX, + COMMONS_CRYPTO_CONF_PREFIX) + assert(props.getProperty(cryptoKey1) === sparkVal1) + assert(!props.containsKey(cryptoKey2)) + } + + test("Shuffle encryption is disabled by default") { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials() + val conf = new SparkConf() + initCredentials(conf, credentials) + assert(credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) === null) + } + }) + } + + test("Shuffle encryption key length should be 128 by default") { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials() + val conf = new SparkConf() + conf.set(SPARK_SHUFFLE_ENCRYPTION_ENABLED, true.toString) + initCredentials(conf, credentials) + var key = credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) + assert(key !== null) + val actual = key.length * (java.lang.Byte.SIZE) + assert(actual === DEFAULT_SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS) + } + }) + } + + test("Initial credentials with key length in 256") { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials() + val conf = new SparkConf() + conf.set(SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS, 256.toString) + conf.set(SPARK_SHUFFLE_ENCRYPTION_ENABLED, true.toString) + initCredentials(conf, 
credentials) + var key = credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) + assert(key !== null) + val actual = key.length * (java.lang.Byte.SIZE) + assert(actual === 256) + } + }) + } + + test("Initial credentials with invalid key length") { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials() + val conf = new SparkConf() + conf.set(SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS, 328.toString) + conf.set(SPARK_SHUFFLE_ENCRYPTION_ENABLED, true.toString) + val thrown = intercept[IllegalArgumentException] { + initCredentials(conf, credentials) + } + } + }) + } + + private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = { + if (CryptoConf.isShuffleEncryptionEnabled(conf)) { + CryptoConf.initSparkShuffleCredentials(conf, credentials) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 5132384a5ed7..735ffba24365 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -95,6 +95,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte args(2).asInstanceOf[SerializerInstance], args(3).asInstanceOf[Int], compressStream = identity, + encryptStream = identity, syncWrites = false, args(4).asInstanceOf[ShuffleWriteMetrics], blockId = args(0).asInstanceOf[BlockId] diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 684e978d1186..86fa803a8a69 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -46,7 +46,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, + writeMetrics) writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -69,7 +70,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, + writeMetrics) writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -92,7 +94,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, + writeMetrics) writer.open() writer.close() @@ -105,7 +108,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with 
BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, + writeMetrics) writer.write(Long.box(20), Long.box(30)) val firstSegment = writer.commitAndGet() @@ -123,7 +127,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, + writeMetrics) writer.write(Long.box(20), Long.box(30)) val firstSegment = writer.commitAndGet() @@ -139,7 +144,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, + writeMetrics) for (i <- 1 to 1000) { writer.write(i, i) } @@ -156,7 +162,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, + writeMetrics) for (i <- 1 to 1000) { writer.write(i, i) } @@ -176,7 +183,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, + writeMetrics) for (i <- 1 to 1000) { writer.write(i, i) } @@ -194,7 +202,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, + writeMetrics) val segment = writer.commitAndGet() writer.close() assert(segment.length === 0) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 326271a7e2b2..8c70c95a3b84 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -27,6 +27,7 @@ commons-collections-3.2.2.jar commons-compiler-2.7.6.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar +commons-crypto-1.0.0.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar @@ -99,6 +100,7 @@ jersey-server-2.22.2.jar jets3t-0.7.1.jar jetty-util-6.1.26.jar jline-2.12.1.jar +jna-4.2.2.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 
1ff6ecb7342b..3cbb5df511d5 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -30,6 +30,7 @@ commons-collections-3.2.2.jar commons-compiler-2.7.6.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar +commons-crypto-1.0.0.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar @@ -104,6 +105,7 @@ jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar +jna-4.2.2.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 68333849cf4c..3cfa1ea3c395 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -30,6 +30,7 @@ commons-collections-3.2.2.jar commons-compiler-2.7.6.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar +commons-crypto-1.0.0.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar @@ -104,6 +105,7 @@ jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar +jna-4.2.2.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 787d06c3512d..1b984905c0c5 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -34,6 +34,7 @@ commons-collections-3.2.2.jar commons-compiler-2.7.6.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar +commons-crypto-1.0.0.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar @@ -112,6 +113,7 @@ jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar +jna-4.2.2.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 386495bf1bbb..36c1b8de755f 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -34,6 +34,7 @@ commons-collections-3.2.2.jar commons-compiler-2.7.6.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar +commons-crypto-1.0.0.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar @@ -112,6 +113,7 @@ jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar +jna-4.2.2.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/docs/configuration.md b/docs/configuration.md index 2f801961050e..c2e32f36fb04 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -559,6 +559,39 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec. + + spark.shuffle.encryption.enabled + false + + Enable shuffle file encryption. + + + + spark.shuffle.encryption.keySizeBits + 128 + + Shuffle file encryption key size in bits. The valid number includes 128, 192 and 256. + + + + spark.shuffle.encryption.keygen.algorithm + HmacSHA1 + + The algorithm to generate the key used by shuffle file encryption. The supported algorithms are + described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm + Name Documentation. + + + + spark.shuffle.crypto.cipher.transformation + AES/CTR/NoPadding + + Cipher transformation for shuffle file encryption. The cipher transformation name is + identical to the transformations described in the Cipher section of the Java Cryptography + Architecture Standard Algorithm Name Documentation. Currently only "AES/CTR/NoPadding" + algorithm is supported. 
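Editor's illustration (not part of the patch): the settings documented above might be enabled from an application roughly as follows; the application name and property values shown here are only examples.

    import org.apache.spark.{SparkConf, SparkContext}

    // Hypothetical usage of the properties documented above. In this patch,
    // shuffle file encryption is only supported when Spark runs on YARN.
    val conf = new SparkConf()
      .setAppName("shuffle-encryption-example")
      .set("spark.shuffle.encryption.enabled", "true")
      .set("spark.shuffle.encryption.keySizeBits", "256")
      .set("spark.shuffle.encryption.keygen.algorithm", "HmacSHA1")
      .set("spark.shuffle.crypto.cipher.transformation", "AES/CTR/NoPadding")
    val sc = new SparkContext(conf)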
+ + #### Spark UI diff --git a/pom.xml b/pom.xml index 989658216e5f..5eff64db5256 100644 --- a/pom.xml +++ b/pom.xml @@ -182,6 +182,7 @@ 2.52.0 2.8 1.8 + 1.0.0 ${java.home} @@ -1839,6 +1840,11 @@ jline ${jline.version} + + org.apache.commons + commons-crypto + ${commons-crypto.version} + diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7fbbe91de94e..76a18889412c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -48,6 +48,7 @@ import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException} +import org.apache.spark.crypto.CryptoConf import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.yarn.config._ import org.apache.spark.deploy.yarn.security.ConfigurableCredentialManager @@ -1003,6 +1004,10 @@ private[spark] class Client( val securityManager = new SecurityManager(sparkConf) amContainer.setApplicationACLs( YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava) + + if (CryptoConf.isShuffleEncryptionEnabled(sparkConf)) { + CryptoConf.initSparkShuffleCredentials(sparkConf, credentials) + } setupSecurityToken(amContainer) amContainer diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleEncryptionSuite.scala new file mode 100644 index 000000000000..1fe12b95bbe4 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleEncryptionSuite.scala @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.yarn + +import java.io._ +import java.nio.ByteBuffer +import java.security.PrivilegedExceptionAction +import java.util.{ArrayList => JArrayList, LinkedList => JLinkedList, UUID} + +import scala.runtime.AbstractFunction1 + +import com.google.common.collect.HashMultiset +import com.google.common.io.ByteStreams +import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.junit.Assert.assertEquals +import org.mockito.Mock +import org.mockito.MockitoAnnotations +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Matchers.{eq => meq, _} +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers} + +import org.apache.spark._ +import org.apache.spark.crypto.{CryptoConf, CryptoStreamUtils} +import org.apache.spark.crypto.CryptoConf._ +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.network.buffer.NioManagedBuffer +import org.apache.spark.network.util.LimitedInputStream +import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializerInstance, +SerializerManager} +import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, + IndexShuffleBlockResolver, RecordingManagedBuffer} +import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, UnsafeShuffleWriter} +import org.apache.spark.storage._ +import org.apache.spark.util.Utils + +private[spark] class YarnShuffleEncryptionSuite extends SparkFunSuite with Matchers with + BeforeAndAfterAll with BeforeAndAfterEach { + @Mock(answer = RETURNS_SMART_NULLS) private[this] var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private[this] var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private[this] var diskBlockManager: DiskBlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private[this] var serializerManager: SerializerManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private[this] var taskContext: TaskContext = _ + @Mock( + answer = RETURNS_SMART_NULLS) private[this] var shuffleDep: ShuffleDependency[Int, Int, Int] = _ + + private[this] val NUM_MAPS = 1 + private[this] val NUM_PARTITITONS = 4 + private[this] val REDUCE_ID = 1 + private[this] val SHUFFLE_ID = 0 + private[this] val conf = new SparkConf() + private[this] val memoryManager = new TestMemoryManager(conf) + private[this] val hashPartitioner = new HashPartitioner(NUM_PARTITITONS) + private[this] val serializer = new KryoSerializer(conf) + private[this] val spillFilesCreated = new JLinkedList[File]() + private[this] val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + private[this] val taskMetrics = new TaskMetrics() + + private[this] var tempDir: File = _ + private[this] var mergedOutputFile: File = _ + private[this] var partitionSizesInMergedFile: Array[Long] = _ + private[this] val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup")) + + // Create a mocked shuffle handle to pass into HashShuffleReader. 
+ private[this] val shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.serializer).thenReturn(serializer) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(SHUFFLE_ID, NUM_MAPS, dependency) + } + + + // Make a mocked MapOutputTracker for the shuffle reader to use to determine what + // shuffle data to read. + private[this] val mapOutputTracker = mock(classOf[MapOutputTracker]) + private[this] val sparkEnv = mock(classOf[SparkEnv]) + + override def beforeAll(): Unit = { + when(sparkEnv.conf).thenReturn(conf) + SparkEnv.set(sparkEnv) + + System.setProperty("SPARK_YARN_MODE", "true") + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + conf.set(SPARK_SHUFFLE_ENCRYPTION_ENABLED, true.toString) + val creds = new Credentials() + CryptoConf.initSparkShuffleCredentials(conf, creds) + SparkHadoopUtil.get.addCurrentUserCredentials(creds) + } + }) + } + + override def afterAll(): Unit = { + SparkEnv.set(null) + } + + override def beforeEach(): Unit = { + super.beforeEach() + MockitoAnnotations.initMocks(this) + tempDir = Utils.createTempDir() + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) + } + + override def afterEach(): Unit = { + super.afterEach() + conf.set("spark.shuffle.compress", false.toString) + conf.set("spark.shuffle.spill.compress", false.toString) + Utils.deleteRecursively(tempDir) + val leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (leakedMemory != 0) { + fail("Test leaked " + leakedMemory + " bytes of managed memory") + } + } + + test("yarn shuffle encryption read and write") { + ugi.doAs(new PrivilegedExceptionAction[Unit] { + override def run(): Unit = { + conf.set("spark.shuffle.compress", false.toString) + conf.set("spark.shuffle.spill.compress", false.toString) + testYarnShuffleEncryptionWriteRead() + } + }) + } + + test("yarn shuffle encryption read and write with shuffle compression enabled") { + ugi.doAs(new PrivilegedExceptionAction[Unit] { + override def run(): Unit = { + conf.set("spark.shuffle.compress", true.toString) + conf.set("spark.shuffle.spill.compress", true.toString) + testYarnShuffleEncryptionWriteRead() + } + }) + } + + private[this] def testYarnShuffleEncryptionWriteRead(): Unit = { + val dataToWrite = new JArrayList[Product2[Int, Int]]() + for (i <- 0 to NUM_PARTITITONS) { + dataToWrite.add((i, i)) + } + val shuffleWriter = createWriter() + shuffleWriter.write(dataToWrite.iterator()) + shuffleWriter.stop(true) + + val shuffleReader = createReader() + val iter = shuffleReader.read() + val recordsList = new JArrayList[(Int, Int)]() + while (iter.hasNext) { + recordsList.add(iter.next().asInstanceOf[(Int, Int)]) + } + + assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(recordsList)) + } + + private[this] def createWriter(): UnsafeShuffleWriter[Int, Int] = { + initialMocksForWriter() + new UnsafeShuffleWriter[Int, Int]( + blockManager, + blockResolver, + taskMemoryManager, + new SerializedShuffleHandle[Int, Int](SHUFFLE_ID, NUM_MAPS, shuffleDep), + 0, // map id + taskContext, + conf + ) + } + + private[this] def createReader(): BlockStoreShuffleReader[Int, Int] = { + initialMocksForReader() + + new BlockStoreShuffleReader( + shuffleHandle, + REDUCE_ID, + REDUCE_ID + 1, + TaskContext.empty(), + serializerManager, + blockManager, + mapOutputTracker) + } + + private[this] def initialMocksForWriter(): Unit = { + 
when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(blockManager.conf).thenReturn(conf) + when(blockManager.getDiskWriter(any(classOf[BlockId]), any(classOf[File]), + any(classOf[SerializerInstance]), anyInt, any(classOf[ShuffleWriteMetrics]))).thenAnswer( + new Answer[DiskBlockObjectWriter]() { + override def answer(invocationOnMock: InvocationOnMock): DiskBlockObjectWriter = { + val args = invocationOnMock.getArguments + new DiskBlockObjectWriter(args(1).asInstanceOf[File], + args(2).asInstanceOf[SerializerInstance], + args(3).asInstanceOf[Integer], new CompressStream(), new EncryptStream(), false, + args(4).asInstanceOf[ShuffleWriteMetrics], args(0).asInstanceOf[BlockId]) + } + }) + + when(blockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile) + doAnswer(new Answer[Unit]() { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] + val tmp = invocationOnMock.getArguments()(3) + mergedOutputFile.delete() + tmp.asInstanceOf[File].renameTo(mergedOutputFile) + } + }).when(blockResolver).writeIndexFileAndCommit(anyInt(), anyInt(), any(classOf[Array[Long]]), + any(classOf[File])) + + when(diskBlockManager.createTempShuffleBlock()).thenAnswer( + new Answer[(TempShuffleBlockId, File)]() { + override def answer(invocationOnMock: InvocationOnMock): (TempShuffleBlockId, File) = { + val blockId = new TempShuffleBlockId(UUID.randomUUID()) + val file = File.createTempFile("spillFile", ".spill", tempDir) + spillFilesCreated.add(file) + (blockId, file) + } + }) + + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(shuffleDep.serializer).thenReturn(serializer) + when(shuffleDep.partitioner).thenReturn(hashPartitioner) + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + } + + private[this] def initialMocksForReader(): Unit = { + // Setup the mocked BlockManager to return RecordingManagedBuffers. 
+ val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) + when(blockManager.blockManagerId).thenReturn(localBlockManagerId) + + // Create a return function to use for the mocked wrapForCompression method to initial a + // compressed input stream if spark.shuffle.compress is enabled + val compressionFunction = new Answer[InputStream] { + override def answer(invocation: InvocationOnMock): InputStream = { + if (conf.getBoolean("spark.shuffle.compress", false)) { + CompressionCodec.createCodec(conf).compressedInputStream( + invocation.getArguments()(1).asInstanceOf[InputStream]) + } else { + invocation.getArguments()(1).asInstanceOf[InputStream] + } + } + } + // Create a return function to use for the mocked wrapForEncryption method to initial a + // encrypted input stream if spark.shuffle.encryption.enabled is enabled + val encryptionFunction = new Answer[InputStream] { + override def answer(invocation: InvocationOnMock): InputStream = { + if (CryptoConf.isShuffleEncryptionEnabled(conf)) { + CryptoStreamUtils.createCryptoInputStream( + invocation.getArguments()(0).asInstanceOf[InputStream], conf) + } else { + invocation.getArguments()(0).asInstanceOf[InputStream] + } + } + } + var startOffset = 0L + for (mapId <- 0 until NUM_PARTITITONS) { + val partitionSize: Long = partitionSizesInMergedFile(mapId) + if (partitionSize > 0) { + val bytes = new Array[Byte](partitionSize.toInt) + var in: InputStream = new FileInputStream(mergedOutputFile) + ByteStreams.skipFully(in, startOffset) + in = new LimitedInputStream(in, partitionSize) + try { + in.read(bytes) + } finally { + in.close() + } + // Create a ManagedBuffer with the shuffle data. + val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(bytes)) + val managedBuffer = new RecordingManagedBuffer(nioBuffer) + startOffset += partitionSizesInMergedFile(mapId) + // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to + // fetch shuffle data. + val shuffleBlockId = ShuffleBlockId(SHUFFLE_ID, mapId, REDUCE_ID) + when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(serializerManager.wrapForCompression(meq(shuffleBlockId), + isA(classOf[InputStream]))).thenAnswer(compressionFunction) + when(serializerManager.wrapForEncryption(isA(classOf[InputStream]))).thenAnswer( + encryptionFunction) + } + } + + // Test a scenario where all data is local, to avoid creating a bunch of additional mocks + // for the code to read data over the network. 
+ val shuffleBlockIdsAndSizes = (0 until NUM_PARTITITONS).map { mapId => + val shuffleBlockId = ShuffleBlockId(SHUFFLE_ID, mapId, REDUCE_ID) + (shuffleBlockId, partitionSizesInMergedFile(mapId)) + } + val mapSizesByExecutorId = Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + when(mapOutputTracker.getMapSizesByExecutorId(SHUFFLE_ID, REDUCE_ID, REDUCE_ID + 1)).thenReturn + { + mapSizesByExecutorId + } + } + + @throws(classOf[IOException]) + private def readRecordsFromFile: JArrayList[(Any, Any)] = { + val recordsList: JArrayList[(Any, Any)] = new JArrayList[(Any, Any)] + var startOffset = 0L + for (mapId <- 0 until NUM_PARTITITONS) { + val partitionSize: Long = partitionSizesInMergedFile(mapId) + if (partitionSize > 0) { + var in: InputStream = new FileInputStream(mergedOutputFile) + ByteStreams.skipFully(in, startOffset) + in = new LimitedInputStream(in, partitionSize) + val recordsStream: DeserializationStream = serializer.newInstance.deserializeStream(in) + val records: Iterator[(Any, Any)] = recordsStream.asKeyValueIterator + while (records.hasNext) { + val record: (Any, Any) = records.next + assertEquals(mapId, hashPartitioner.getPartition(record._1)) + recordsList.add(record) + } + recordsStream.close + startOffset += partitionSize + } + } + recordsList + } + + private[this] final class CompressStream extends AbstractFunction1[OutputStream, OutputStream] { + override def apply(stream: OutputStream): OutputStream = { + if (conf.getBoolean("spark.shuffle.compress", false)) { + CompressionCodec.createCodec(conf).compressedOutputStream(stream) + } else { + stream + } + } + } + + private[this] final class EncryptStream extends AbstractFunction1[OutputStream, OutputStream] { + override def apply(stream: OutputStream): OutputStream = { + if (CryptoConf.isShuffleEncryptionEnabled(conf)) { + CryptoStreamUtils.createCryptoOutputStream(stream, conf) + } else { + stream + } + } + } +} From 66e29b2fca0cfab020b4ed33d1b8309c5a334e66 Mon Sep 17 00:00:00 2001 From: Ferdinand Xu Date: Tue, 23 Aug 2016 03:22:23 +0800 Subject: [PATCH 2/9] Update patch using new Config APIs and simplify SerializerManager interface --- .../unsafe/sort/UnsafeSorterSpillReader.java | 3 +- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../org/apache/spark/crypto/CryptoConf.scala | 71 --------------- .../spark/internal/config/package.scala | 20 +++++ .../CryptoStreamUtils.scala | 28 +++--- .../spark/serializer/SerializerManager.scala | 55 +++++++++--- .../shuffle/BlockStoreShuffleReader.scala | 3 +- .../apache/spark/storage/BlockManager.scala | 6 +- .../spark/storage/DiskBlockObjectWriter.scala | 6 +- .../spark/storage/memory/MemoryStore.scala | 4 +- .../collection/ExternalAppendOnlyMap.scala | 5 +- .../util/collection/ExternalSorter.scala | 8 +- .../sort/UnsafeShuffleWriterSuite.java | 12 +-- .../map/AbstractBytesToBytesMapSuite.java | 12 +-- .../sort/UnsafeExternalSorterSuite.java | 12 +-- .../spark/crypto/ShuffleEncryptionSuite.scala | 29 +++--- .../BypassMergeSortShuffleWriterSuite.scala | 3 +- .../storage/DiskBlockObjectWriterSuite.scala | 27 ++---- dev/deps/spark-deps-hadoop-2.2 | 1 - dev/deps/spark-deps-hadoop-2.3 | 1 - dev/deps/spark-deps-hadoop-2.4 | 1 - dev/deps/spark-deps-hadoop-2.6 | 1 - dev/deps/spark-deps-hadoop-2.7 | 1 - docs/configuration.md | 22 ++--- pom.xml | 6 ++ .../org/apache/spark/deploy/yarn/Client.scala | 6 +- ...uite.scala => YarnIOEncryptionSuite.scala} | 89 ++++++++----------- 27 files changed, 181 insertions(+), 259 deletions(-) delete mode 100644 
core/src/main/scala/org/apache/spark/crypto/CryptoConf.scala rename core/src/main/scala/org/apache/spark/{crypto => security}/CryptoStreamUtils.scala (83%) rename yarn/src/test/scala/org/apache/spark/deploy/yarn/{YarnShuffleEncryptionSuite.scala => YarnIOEncryptionSuite.scala} (82%) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 9d09d41a4be2..2875b0d69def 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -72,8 +72,7 @@ public UnsafeSorterSpillReader( final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file), (int) bufferSizeBytes); try { - final InputStream eis = serializerManager.wrapForEncryption(bs); - this.in = serializerManager.wrapForCompression(blockId, eis); + this.in = serializerManager.wrapStream(blockId, bs); this.din = new DataInputStream(this.in); numRecords = numRecordsRemaining = din.readInt(); } catch (IOException e) { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9bac15df99f8..8580c94c7304 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -46,11 +46,11 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast -import org.apache.spark.crypto.CryptoConf import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ @@ -414,9 +414,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") - if (CryptoConf.isShuffleEncryptionEnabled(_conf) && !SparkHadoopUtil.get.isYarnMode()) { - throw new SparkException("Shuffle file encryption is only supported in Yarn mode, please " + - "disable it by setting spark.shuffle.encryption.enabled to false") + if (_conf.get(SPARK_IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) { + throw new SparkException("IO encryption is only supported in Yarn mode, please disable it " + + "by setting spark.io.encryption.enabled to false") } // "_jobProgressListener" should be set up before creating SparkEnv because when creating diff --git a/core/src/main/scala/org/apache/spark/crypto/CryptoConf.scala b/core/src/main/scala/org/apache/spark/crypto/CryptoConf.scala deleted file mode 100644 index 3a9dc22a41ee..000000000000 --- a/core/src/main/scala/org/apache/spark/crypto/CryptoConf.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.crypto - -import javax.crypto.KeyGenerator - -import org.apache.hadoop.io.Text -import org.apache.hadoop.security.Credentials - -import org.apache.spark.SparkConf - -/** - * CryptoConf is a class for Crypto configuration - */ -private[spark] object CryptoConf { - /** - * Constants and variables for spark shuffle file encryption - */ - val SPARK_SHUFFLE_TOKEN = new Text("SPARK_SHUFFLE_TOKEN") - val SPARK_SHUFFLE_ENCRYPTION_ENABLED = "spark.shuffle.encryption.enabled" - val SPARK_SHUFFLE_ENCRYPTION_KEYGEN_ALGORITHM = "spark.shuffle.encryption.keygen.algorithm" - val DEFAULT_SPARK_SHUFFLE_ENCRYPTION_KEYGEN_ALGORITHM = "HmacSHA1" - val SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS = "spark.shuffle.encryption.keySizeBits" - val DEFAULT_SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS = 128 - - /** - * Check whether shuffle file encryption is enabled. It is disabled by default. - */ - def isShuffleEncryptionEnabled(sparkConf: SparkConf): Boolean = { - if (sparkConf != null) { - sparkConf.getBoolean(SPARK_SHUFFLE_ENCRYPTION_ENABLED, false) - } else { - false - } - } - - /** - * Setup the cryptographic key used by file shuffle encryption in credentials. The key is - * generated using [[KeyGenerator]]. The algorithm and key length is specified by the - * [[SparkConf]]. 
- */ - def initSparkShuffleCredentials(conf: SparkConf, credentials: Credentials): Unit = { - if (credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) == null) { - val keyLen = conf.getInt(SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS, - DEFAULT_SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS) - require(keyLen == 128 || keyLen == 192 || keyLen == 256) - val shuffleKeyGenAlgorithm = conf.get(SPARK_SHUFFLE_ENCRYPTION_KEYGEN_ALGORITHM, - DEFAULT_SPARK_SHUFFLE_ENCRYPTION_KEYGEN_ALGORITHM) - val keyGen = KeyGenerator.getInstance(shuffleKeyGenAlgorithm) - keyGen.init(keyLen) - - val shuffleKey = keyGen.generateKey() - credentials.addSecretKey(SPARK_SHUFFLE_TOKEN, shuffleKey.getEncoded) - } - } -} - diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 47174e4efee8..9a0224e09da2 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -119,4 +119,24 @@ package object config { private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks") .intConf .createWithDefault(100000) + + private[spark] val SPARK_IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val SPARK_IO_ENCRYPTION_KEYGEN_ALGORITHM = ConfigBuilder( + "spark.io.encryption.keygen.algorithm") + .stringConf + .createWithDefault("HmacSHA1") + + private[spark] val SPARK_IO_ENCRYPTION_KEY_SIZE_BITS = ConfigBuilder( + "spark.io.encryption.keySizeBits") + .intConf + .checkValues(Set(128, 192, 256)) + .createWithDefault(128) + + private[spark] val SPARK_IO_CRYPTO_CIPHER_TRANSFORMATION = ConfigBuilder( + "spark.io.crypto.cipher.transformation") + .stringConf + .createWithDefaultString("AES/CTR/NoPadding") } diff --git a/core/src/main/scala/org/apache/spark/crypto/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/crypto/CryptoStreamUtils.scala rename to core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 0a79485b8b04..5ba020c068bc 100644 --- a/core/src/main/scala/org/apache/spark/crypto/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.crypto +package org.apache.spark.security import java.io.{InputStream, OutputStream} import java.util.Properties @@ -22,15 +22,21 @@ import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} import org.apache.commons.crypto.random._ import org.apache.commons.crypto.stream._ +import org.apache.hadoop.io.Text import org.apache.spark.SparkConf -import org.apache.spark.crypto.CryptoConf._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.config._ /** - * A util class for manipulating file shuffle encryption and decryption streams. + * A util class for manipulating IO encryption and decryption streams. */ private[spark] object CryptoStreamUtils { + /** + * Constants and variables for spark IO encryption + */ + val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN") + // The initialization vector length in bytes. val IV_LENGTH_IN_BYTES = 16 // The prefix of Crypto related configurations in Spark configuration. 
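[Note, not part of the patch] A minimal, self-contained sketch of the IV-prefix layout used by createCryptoOutputStream and createCryptoInputStream above: a 16-byte IV is written at the head of the raw stream on write and read back before wrapping the remainder with a Commons Crypto stream. The AES key here is generated locally with KeyGenerator purely for illustration (the patch itself fetches it from the Hadoop credentials), so this only shows the stream layout, assuming commons-crypto is on the classpath.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import java.util.Properties
    import javax.crypto.KeyGenerator
    import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

    import org.apache.commons.crypto.random.CryptoRandomFactory
    import org.apache.commons.crypto.stream.{CryptoInputStream, CryptoOutputStream}

    object IvPrefixedStreamSketch {
      private val transformation = "AES/CTR/NoPadding"

      def main(args: Array[String]): Unit = {
        val props = new Properties()
        // Stand-in for the key that the patch stores in the Hadoop credentials.
        val keyGen = KeyGenerator.getInstance("AES")
        keyGen.init(128)
        val key = new SecretKeySpec(keyGen.generateKey().getEncoded, "AES")

        // Encrypt: write a random IV first, then the CTR ciphertext after it.
        val iv = new Array[Byte](16)
        CryptoRandomFactory.getCryptoRandom(props).nextBytes(iv)
        val sink = new ByteArrayOutputStream()
        sink.write(iv)
        val out = new CryptoOutputStream(transformation, props, sink, key, new IvParameterSpec(iv))
        val payload = "some shuffle bytes".getBytes("UTF-8")
        out.write(payload)
        out.close()

        // Decrypt: read the IV back from the head of the stream, then wrap the rest.
        val source = new ByteArrayInputStream(sink.toByteArray)
        val ivIn = new Array[Byte](16)
        source.read(ivIn, 0, ivIn.length)
        val in = new CryptoInputStream(transformation, props, source, key, new IvParameterSpec(ivIn))
        val plain = new Array[Byte](payload.length)
        var off = 0
        while (off < plain.length) {
          val n = in.read(plain, off, plain.length - off)
          require(n > 0, "unexpected end of stream")
          off += n
        }
        in.close()
        assert(new String(plain, "UTF-8") == "some shuffle bytes")
      }
    }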
@@ -38,7 +44,6 @@ private[spark] object CryptoStreamUtils { // The prefix for the configurations passing to Commons-crypto library. val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto." - /** * Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption. */ @@ -47,12 +52,11 @@ private[spark] object CryptoStreamUtils { sparkConf: SparkConf): OutputStream = { val properties = toCryptoConf(sparkConf, SPARK_COMMONS_CRYPTO_CONF_PREFIX, COMMONS_CRYPTO_CONF_PREFIX) - val iv: Array[Byte] = createInitializationVector(properties) + val iv = createInitializationVector(properties) os.write(iv) val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - val key = credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) - val transformationStr = sparkConf.get( - "spark.shuffle.crypto.cipher.transformation", "AES/CTR/NoPadding") + val key = credentials.getSecretKey(SPARK_IO_TOKEN) + val transformationStr = sparkConf.get(SPARK_IO_CRYPTO_CIPHER_TRANSFORMATION) new CryptoOutputStream(transformationStr, properties, os, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) } @@ -68,16 +72,14 @@ private[spark] object CryptoStreamUtils { val iv = new Array[Byte](IV_LENGTH_IN_BYTES) is.read(iv, 0, iv.length) val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - val key = credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) - val transformationStr = sparkConf.get( - "spark.shuffle.crypto.cipher.transformation", "AES/CTR/NoPadding") + val key = credentials.getSecretKey(SPARK_IO_TOKEN) + val transformationStr = sparkConf.get(SPARK_IO_CRYPTO_CIPHER_TRANSFORMATION) new CryptoInputStream(transformationStr, properties, is, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) } /** - * Get Commons-crypto configurations from Spark configurations identified - * by prefix. + * Get Commons-crypto configurations from Spark configurations identified by prefix. */ def toCryptoConf( conf: SparkConf, diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index be89e0cb3bbe..590a19d87ffb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -19,18 +19,22 @@ package org.apache.spark.serializer import java.io.{BufferedInputStream, BufferedOutputStream, InputStream, OutputStream} import java.nio.ByteBuffer +import javax.crypto.KeyGenerator +import org.apache.hadoop.security.Credentials import scala.reflect.ClassTag import org.apache.spark.SparkConf -import org.apache.spark.crypto.{CryptoConf, CryptoStreamUtils} +import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec +import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.security.CryptoStreamUtils._ import org.apache.spark.storage._ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} /** - * Component which configures serialization and compression for various Spark components, including - * automatic selection of which [[Serializer]] to use for shuffles. + * Component which configures serialization, compression and encryption for various Spark + * components, including automatic selection of which [[Serializer]] to use for shuffles. 
*/ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) { @@ -62,8 +66,8 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar // Whether to compress shuffle output temporarily spilled to disk private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - // Whether to encrypt shuffle file encryption - private[this] val enableShuffleFileEncryption = CryptoConf.isShuffleEncryptionEnabled(conf) + // Whether to encrypt IO encryption + private[this] val enableIOEncryption = conf.get(SPARK_IO_ENCRYPTION_ENABLED) /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark @@ -106,18 +110,32 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar } } + /** + * Wrap an input stream for encryption and compression + */ + def wrapStream(blockId: BlockId, s: InputStream): InputStream = { + wrapForCompression(blockId, wrapForEncryption(s)) + } + + /** + * Wrap an output stream for encryption and compression + */ + def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = { + wrapForEncryption(wrapForCompression(blockId, s)) + } + /** * Wrap an input stream for encryption if shuffle encryption is enabled */ def wrapForEncryption(s: InputStream): InputStream = { - if (enableShuffleFileEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s + if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s } /** * Wrap an output stream for encryption if shuffle encryption is enabled */ def wrapForEncryption(s: OutputStream): OutputStream = { - if (enableShuffleFileEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s + if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s } /** @@ -141,8 +159,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar values: Iterator[T]): Unit = { val byteStream = new BufferedOutputStream(outputStream) val ser = getSerializer(implicitly[ClassTag[T]]).newInstance() - ser.serializeStream(wrapForEncryption(wrapForCompression(blockId, byteStream))).writeAll( - values).close() + ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() } /** Serializes into a chunked byte buffer. */ @@ -172,7 +189,25 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar val stream = new BufferedInputStream(inputStream) getSerializer(implicitly[ClassTag[T]]) .newInstance() - .deserializeStream(wrapForCompression(blockId, wrapForEncryption(stream))) + .deserializeStream(wrapStream(blockId, stream)) .asIterator.asInstanceOf[Iterator[T]] } } + +private[spark] object SerializerManager { + /** + * Setup the cryptographic key used by IO encryption in credentials. The key is generated using + * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]]. 
+ */ + def initShuffleEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = { + if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) { + val keyLen = conf.get(SPARK_IO_ENCRYPTION_KEY_SIZE_BITS) + val IOKeyGenAlgorithm = conf.get(SPARK_IO_ENCRYPTION_KEYGEN_ALGORITHM) + val keyGen = KeyGenerator.getInstance(IOKeyGenAlgorithm) + keyGen.init(keyLen) + + val IOKey = keyGen.generateKey() + credentials.addSecretKey(SPARK_IO_TOKEN, IOKey.getEncoded) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 008599bc63e9..b9d83495d29b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -53,8 +53,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Wrap the streams for compression and encryption based on configuration val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => - val eis = serializerManager.wrapForEncryption(inputStream) - serializerManager.wrapForCompression(blockId, eis) + serializerManager.wrapStream(blockId, inputStream) } val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 39caa9cf8ed7..c72f28e00cdb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -721,11 +721,9 @@ private[spark] class BlockManager( serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { - val compressStream: OutputStream => OutputStream = - serializerManager.wrapForCompression(blockId, _) - val encryptStream: OutputStream => OutputStream = serializerManager.wrapForEncryption(_) + val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(file, serializerInstance, bufferSize, encryptStream, compressStream, + new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream, syncWrites, writeMetrics, blockId) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a621fddbcb7c..a499827ae159 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -39,8 +39,7 @@ private[spark] class DiskBlockObjectWriter( val file: File, serializerInstance: SerializerInstance, bufferSize: Int, - compressStream: OutputStream => OutputStream, - encryptStream: OutputStream => OutputStream, + wrapStream: OutputStream => OutputStream, syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. 
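[Note, not part of the patch] With the constructor change above, DiskBlockObjectWriter receives a single OutputStream => OutputStream instead of separate compressStream and encryptStream arguments; a disabled stage is simply identity, which is also what the test suites later in this patch pass. A small sketch of composing such a wrapper, with GZIPOutputStream standing in for a Spark compression codec:

    import java.io.{ByteArrayOutputStream, OutputStream}
    import java.util.zip.GZIPOutputStream

    object WrapStreamSketch {
      def main(args: Array[String]): Unit = {
        // Two optional stages; a disabled stage is simply `identity`.
        val compressFn: OutputStream => OutputStream = new GZIPOutputStream(_) // stand-in codec
        val encryptFn: OutputStream => OutputStream = identity                 // encryption off here

        // One composed wrapper, the single-function shape DiskBlockObjectWriter now takes.
        val wrapStream: OutputStream => OutputStream = compressFn.andThen(encryptFn)

        val sink = new ByteArrayOutputStream()
        val wrapped = wrapStream(sink)
        wrapped.write("hello".getBytes("UTF-8"))
        wrapped.close()
        println(s"wrote ${sink.size()} wrapped bytes to the underlying sink")
      }
    }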
@@ -117,8 +116,7 @@ private[spark] class DiskBlockObjectWriter( initialized = true } - bs = encryptStream(mcs) - bs = compressStream(bs) + bs = wrapStream(mcs) objOut = serializerInstance.serializeStream(bs) streamOpen = true this diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 7cab54e07db2..d220ab51d115 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -330,9 +330,7 @@ private[spark] class MemoryStore( redirectableStream.setOutputStream(bbos) val serializationStream: SerializationStream = { val ser = serializerManager.getSerializer(classTag).newInstance() - ser.serializeStream( - serializerManager.wrapForEncryption(serializerManager.wrapForCompression(blockId, - redirectableStream))) + ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream)) } // Request enough memory to begin unrolling diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index ca3eac47241c..09435281194b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -486,9 +486,8 @@ class ExternalAppendOnlyMap[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val encryptedStream = serializerManager.wrapForEncryption(bufferedStream) - val compressedStream = serializerManager.wrapForCompression(blockId, encryptedStream) - ser.deserializeStream(compressedStream) + val wrappedStream = serializerManager.wrapStream(blockId, bufferedStream) + ser.deserializeStream(wrappedStream) } else { // No more batches left cleanup() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 10154e33204c..2e51e6056cad 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,10 +26,8 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams import org.apache.spark._ -import org.apache.spark.crypto.{CryptoConf, CryptoStreamUtils} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging -import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer._ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} @@ -524,10 +522,8 @@ private[spark] class ExternalSorter[K, V, C]( val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val encryptedStream = SparkEnv.get.serializerManager.wrapForEncryption(bufferedStream) - val compressedStream = SparkEnv.get.serializerManager.wrapForCompression(spill.blockId, - encryptedStream) - serInstance.deserializeStream(compressedStream) + val wrappedStream = SparkEnv.get.serializerManager.wrapStream(spill.blockId, bufferedStream) + serInstance.deserializeStream(wrappedStream) } else { // No more batches left cleanup() diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java 
b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index e68086a02a20..a96cd82382e2 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -86,7 +86,7 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; - private final class CompressStream extends AbstractFunction1 { + private final class WrapStream extends AbstractFunction1 { @Override public OutputStream apply(OutputStream stream) { if (conf.getBoolean("spark.shuffle.compress", true)) { @@ -97,13 +97,6 @@ public OutputStream apply(OutputStream stream) { } } - private static final class EncryptStream extends AbstractFunction1 { - @Override - public OutputStream apply(OutputStream stream) { - return stream; - } - } - @After public void tearDown() { Utils.deleteRecursively(tempDir); @@ -143,8 +136,7 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (File) args[1], (SerializerInstance) args[2], (Integer) args[3], - new CompressStream(), - new EncryptStream(), + new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 43576adf5348..33709b454c4c 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -75,14 +75,7 @@ public abstract class AbstractBytesToBytesMapSuite { @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; - private static final class CompressStream extends AbstractFunction1 { - @Override - public OutputStream apply(OutputStream stream) { - return stream; - } - } - - private static final class EncryptStream extends AbstractFunction1 { + private static final class WrapStream extends AbstractFunction1 { @Override public OutputStream apply(OutputStream stream) { return stream; @@ -129,8 +122,7 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (File) args[1], (SerializerInstance) args[2], (Integer) args[3], - new CompressStream(), - new EncryptStream(), + new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index e50d2c1d58f4..a9cf8ff520ed 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -88,14 +88,7 @@ public int compare( private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m"); - private static final class CompressStream extends AbstractFunction1 { - @Override - public OutputStream apply(OutputStream stream) { - return stream; - } - } - - private static final class EncryptStream extends AbstractFunction1 { + private static final class WrapStream extends AbstractFunction1 { @Override public OutputStream apply(OutputStream stream) { return stream; @@ -135,8 +128,7 @@ public 
DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (File) args[1], (SerializerInstance) args[2], (Integer) args[3], - new CompressStream(), - new EncryptStream(), + new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] diff --git a/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala b/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala index bf16428088a7..7f602fd62caa 100644 --- a/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.crypto import java.security.PrivilegedExceptionAction import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.crypto.CryptoConf._ -import org.apache.spark.crypto.CryptoStreamUtils.{COMMONS_CRYPTO_CONF_PREFIX, SPARK_COMMONS_CRYPTO_CONF_PREFIX} - +import org.apache.spark.internal.config._ +import CryptoStreamUtils._ +import org.apache.spark.serializer.SerializerManager private[spark] class ShuffleEncryptionSuite extends SparkFunSuite { val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup")) @@ -51,7 +52,7 @@ private[spark] class ShuffleEncryptionSuite extends SparkFunSuite { val credentials = UserGroupInformation.getCurrentUser.getCredentials() val conf = new SparkConf() initCredentials(conf, credentials) - assert(credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) === null) + assert(credentials.getSecretKey(SPARK_IO_TOKEN) === null) } }) } @@ -61,12 +62,12 @@ private[spark] class ShuffleEncryptionSuite extends SparkFunSuite { override def run(): Unit = { val credentials = UserGroupInformation.getCurrentUser.getCredentials() val conf = new SparkConf() - conf.set(SPARK_SHUFFLE_ENCRYPTION_ENABLED, true.toString) + conf.set(SPARK_IO_ENCRYPTION_ENABLED, true) initCredentials(conf, credentials) - var key = credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) + var key = credentials.getSecretKey(SPARK_IO_TOKEN) assert(key !== null) val actual = key.length * (java.lang.Byte.SIZE) - assert(actual === DEFAULT_SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS) + assert(actual === 128) } }) } @@ -76,10 +77,10 @@ private[spark] class ShuffleEncryptionSuite extends SparkFunSuite { override def run(): Unit = { val credentials = UserGroupInformation.getCurrentUser.getCredentials() val conf = new SparkConf() - conf.set(SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS, 256.toString) - conf.set(SPARK_SHUFFLE_ENCRYPTION_ENABLED, true.toString) + conf.set(SPARK_IO_ENCRYPTION_KEY_SIZE_BITS, 256) + conf.set(SPARK_IO_ENCRYPTION_ENABLED, true) initCredentials(conf, credentials) - var key = credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) + var key = credentials.getSecretKey(SPARK_IO_TOKEN) assert(key !== null) val actual = key.length * (java.lang.Byte.SIZE) assert(actual === 256) @@ -92,8 +93,8 @@ private[spark] class ShuffleEncryptionSuite extends SparkFunSuite { override def run(): Unit = { val credentials = UserGroupInformation.getCurrentUser.getCredentials() val conf = new SparkConf() - conf.set(SPARK_SHUFFLE_ENCRYPTION_KEY_SIZE_BITS, 328.toString) - conf.set(SPARK_SHUFFLE_ENCRYPTION_ENABLED, true.toString) + conf.set(SPARK_IO_ENCRYPTION_KEY_SIZE_BITS, 328) + conf.set(SPARK_IO_ENCRYPTION_ENABLED, true) val thrown = intercept[IllegalArgumentException] { initCredentials(conf, credentials) } @@ -102,8 +103,8 @@ private[spark] 
class ShuffleEncryptionSuite extends SparkFunSuite { } private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = { - if (CryptoConf.isShuffleEncryptionEnabled(conf)) { - CryptoConf.initSparkShuffleCredentials(conf, credentials) + if (conf.get(SPARK_IO_ENCRYPTION_ENABLED)) { + SerializerManager.initShuffleEncryptionKey(conf, credentials) } } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 735ffba24365..ed9428820ff6 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -94,8 +94,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte args(1).asInstanceOf[File], args(2).asInstanceOf[SerializerInstance], args(3).asInstanceOf[Int], - compressStream = identity, - encryptStream = identity, + wrapStream = identity, syncWrites = false, args(4).asInstanceOf[ShuffleWriteMetrics], blockId = args(0).asInstanceOf[BlockId] diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 86fa803a8a69..684e978d1186 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -46,8 +46,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, - writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -70,8 +69,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, - writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -94,8 +92,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, - writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.open() writer.close() @@ -108,8 +105,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, - writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20), Long.box(30)) val firstSegment = writer.commitAndGet() @@ -127,8 
+123,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, - writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20), Long.box(30)) val firstSegment = writer.commitAndGet() @@ -144,8 +139,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, - writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) for (i <- 1 to 1000) { writer.write(i, i) } @@ -162,8 +156,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, - writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) for (i <- 1 to 1000) { writer.write(i, i) } @@ -183,8 +176,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, - writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) for (i <- 1 to 1000) { writer.write(i, i) } @@ -202,8 +194,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, os => os, true, - writeMetrics) + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) val segment = writer.commitAndGet() writer.close() assert(segment.length === 0) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 8c70c95a3b84..eaed0889ac36 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -100,7 +100,6 @@ jersey-server-2.22.2.jar jets3t-0.7.1.jar jetty-util-6.1.26.jar jline-2.12.1.jar -jna-4.2.2.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 3cbb5df511d5..d68a7f462ba7 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -105,7 +105,6 @@ jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar -jna-4.2.2.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 3cfa1ea3c395..346f19767d36 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -105,7 +105,6 @@ jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar -jna-4.2.2.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 
b/dev/deps/spark-deps-hadoop-2.6 index 1b984905c0c5..6f4695f345a4 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -113,7 +113,6 @@ jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar -jna-4.2.2.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 36c1b8de755f..7a86a8bd8884 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -113,7 +113,6 @@ jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar -jna-4.2.2.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/docs/configuration.md b/docs/configuration.md index c2e32f36fb04..a5594dd0d0bf 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -560,36 +560,36 @@ Apart from these, the following properties are also available, and may be useful - spark.shuffle.encryption.enabled + spark.io.encryption.enabled false - Enable shuffle file encryption. + Enable IO encryption. - spark.shuffle.encryption.keySizeBits + spark.io.encryption.keySizeBits 128 - Shuffle file encryption key size in bits. The valid number includes 128, 192 and 256. + IO encryption key size in bits. The valid number includes 128, 192 and 256. - spark.shuffle.encryption.keygen.algorithm + spark.io.encryption.keygen.algorithm HmacSHA1 - The algorithm to generate the key used by shuffle file encryption. The supported algorithms are + The algorithm to generate the key used by IO encryption. The supported algorithms are described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm Name Documentation. - spark.shuffle.crypto.cipher.transformation + spark.io.crypto.cipher.transformation AES/CTR/NoPadding - Cipher transformation for shuffle file encryption. The cipher transformation name is - identical to the transformations described in the Cipher section of the Java Cryptography - Architecture Standard Algorithm Name Documentation. Currently only "AES/CTR/NoPadding" - algorithm is supported. + Cipher transformation for IO encryption. The cipher transformation name is identical to the + transformations described in the Cipher section of the Java Cryptography Architecture + Standard Algorithm Name Documentation. Currently only "AES/CTR/NoPadding" algorithm is + supported. 
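[Note, not part of the patch] As a usage sketch of the properties documented above: enabling IO encryption for an application looks roughly like the snippet below; the same keys can also be set in spark-defaults.conf. Per this patch the feature is only supported in YARN mode, and AES/CTR/NoPadding is the only supported cipher transformation.

    import org.apache.spark.SparkConf

    // Values taken from the configuration table above.
    val conf = new SparkConf()
      .setAppName("io-encryption-example")
      .set("spark.io.encryption.enabled", "true")
      .set("spark.io.encryption.keySizeBits", "256")             // valid: 128, 192, 256
      .set("spark.io.encryption.keygen.algorithm", "HmacSHA1")   // default
      .set("spark.io.crypto.cipher.transformation", "AES/CTR/NoPadding")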
diff --git a/pom.xml b/pom.xml index 5eff64db5256..f90636a87491 100644 --- a/pom.xml +++ b/pom.xml @@ -1844,6 +1844,12 @@ org.apache.commons commons-crypto ${commons-crypto.version} + + + net.java.dev.jna + jna + + diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 76a18889412c..580aa7017812 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -48,13 +48,13 @@ import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException} -import org.apache.spark.crypto.CryptoConf import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.yarn.config._ import org.apache.spark.deploy.yarn.security.ConfigurableCredentialManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} +import org.apache.spark.serializer.SerializerManager import org.apache.spark.util.Utils private[spark] class Client( @@ -1005,8 +1005,8 @@ private[spark] class Client( amContainer.setApplicationACLs( YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava) - if (CryptoConf.isShuffleEncryptionEnabled(sparkConf)) { - CryptoConf.initSparkShuffleCredentials(sparkConf, credentials) + if (sparkConf.get(SPARK_IO_ENCRYPTION_ENABLED)) { + SerializerManager.initShuffleEncryptionKey(sparkConf, credentials) } setupSecurityToken(amContainer) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala similarity index 82% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleEncryptionSuite.scala rename to yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala index 1fe12b95bbe4..db9eaeb4676d 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleEncryptionSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction import java.util.{ArrayList => JArrayList, LinkedList => JLinkedList, UUID} + import scala.runtime.AbstractFunction1 import com.google.common.collect.HashMultiset @@ -37,24 +38,22 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers} import org.apache.spark._ -import org.apache.spark.crypto.{CryptoConf, CryptoStreamUtils} -import org.apache.spark.crypto.CryptoConf._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.network.buffer.NioManagedBuffer import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializerInstance, -SerializerManager} -import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, - IndexShuffleBlockResolver, RecordingManagedBuffer} +import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.serializer._ +import org.apache.spark.shuffle._ import 
org.apache.spark.shuffle.sort.{SerializedShuffleHandle, UnsafeShuffleWriter} import org.apache.spark.storage._ import org.apache.spark.util.Utils -private[spark] class YarnShuffleEncryptionSuite extends SparkFunSuite with Matchers with - BeforeAndAfterAll with BeforeAndAfterEach { +private[spark] class YarnIOEncryptionSuite extends SparkFunSuite with Matchers with + BeforeAndAfterAll with BeforeAndAfterEach { @Mock(answer = RETURNS_SMART_NULLS) private[this] var blockManager: BlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private[this] var blockResolver: IndexShuffleBlockResolver = _ @Mock(answer = RETURNS_SMART_NULLS) private[this] var diskBlockManager: DiskBlockManager = _ @@ -102,9 +101,9 @@ private[spark] class YarnShuffleEncryptionSuite extends SparkFunSuite with Match System.setProperty("SPARK_YARN_MODE", "true") ugi.doAs(new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { - conf.set(SPARK_SHUFFLE_ENCRYPTION_ENABLED, true.toString) + conf.set(SPARK_IO_ENCRYPTION_ENABLED, true) val creds = new Credentials() - CryptoConf.initSparkShuffleCredentials(conf, creds) + SerializerManager.initShuffleEncryptionKey(conf, creds) SparkHadoopUtil.get.addCurrentUserCredentials(creds) } }) @@ -127,32 +126,30 @@ private[spark] class YarnShuffleEncryptionSuite extends SparkFunSuite with Match conf.set("spark.shuffle.spill.compress", false.toString) Utils.deleteRecursively(tempDir) val leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() - if (leakedMemory != 0) { - fail("Test leaked " + leakedMemory + " bytes of managed memory") - } + assert (leakedMemory === 0) } - test("yarn shuffle encryption read and write") { + test("Yarn IO encryption read and write") { ugi.doAs(new PrivilegedExceptionAction[Unit] { override def run(): Unit = { conf.set("spark.shuffle.compress", false.toString) conf.set("spark.shuffle.spill.compress", false.toString) - testYarnShuffleEncryptionWriteRead() + testYarnIOEncryptionWriteRead() } }) } - test("yarn shuffle encryption read and write with shuffle compression enabled") { + test("Yarn IO encryption read and write with shuffle compression enabled") { ugi.doAs(new PrivilegedExceptionAction[Unit] { override def run(): Unit = { conf.set("spark.shuffle.compress", true.toString) conf.set("spark.shuffle.spill.compress", true.toString) - testYarnShuffleEncryptionWriteRead() + testYarnIOEncryptionWriteRead() } }) } - private[this] def testYarnShuffleEncryptionWriteRead(): Unit = { + private[this] def testYarnIOEncryptionWriteRead(): Unit = { val dataToWrite = new JArrayList[Product2[Int, Int]]() for (i <- 0 to NUM_PARTITITONS) { dataToWrite.add((i, i)) @@ -207,7 +204,7 @@ private[spark] class YarnShuffleEncryptionSuite extends SparkFunSuite with Match val args = invocationOnMock.getArguments new DiskBlockObjectWriter(args(1).asInstanceOf[File], args(2).asInstanceOf[SerializerInstance], - args(3).asInstanceOf[Integer], new CompressStream(), new EncryptStream(), false, + args(3).asInstanceOf[Integer], new WrapStream(), false, args(4).asInstanceOf[ShuffleWriteMetrics], args(0).asInstanceOf[BlockId]) } }) @@ -244,27 +241,20 @@ private[spark] class YarnShuffleEncryptionSuite extends SparkFunSuite with Match val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) when(blockManager.blockManagerId).thenReturn(localBlockManagerId) - // Create a return function to use for the mocked wrapForCompression method to initial a - // compressed input stream if spark.shuffle.compress is enabled - val compressionFunction = new Answer[InputStream] { + // 
Create a return function to use for the mocked wrapStream method to initial an + // encrypted and compressed input stream if encryption and compression enabled + val wrapFunction = new Answer[InputStream] { override def answer(invocation: InvocationOnMock): InputStream = { - if (conf.getBoolean("spark.shuffle.compress", false)) { - CompressionCodec.createCodec(conf).compressedInputStream( - invocation.getArguments()(1).asInstanceOf[InputStream]) + val encryptedStream = if (conf.get(SPARK_IO_ENCRYPTION_ENABLED)) { + CryptoStreamUtils.createCryptoInputStream( + invocation.getArguments()(1).asInstanceOf[InputStream], conf) } else { invocation.getArguments()(1).asInstanceOf[InputStream] } - } - } - // Create a return function to use for the mocked wrapForEncryption method to initial a - // encrypted input stream if spark.shuffle.encryption.enabled is enabled - val encryptionFunction = new Answer[InputStream] { - override def answer(invocation: InvocationOnMock): InputStream = { - if (CryptoConf.isShuffleEncryptionEnabled(conf)) { - CryptoStreamUtils.createCryptoInputStream( - invocation.getArguments()(0).asInstanceOf[InputStream], conf) + if (conf.getBoolean("spark.shuffle.compress", false)) { + CompressionCodec.createCodec(conf).compressedInputStream(encryptedStream) } else { - invocation.getArguments()(0).asInstanceOf[InputStream] + encryptedStream } } } @@ -289,10 +279,8 @@ private[spark] class YarnShuffleEncryptionSuite extends SparkFunSuite with Match // fetch shuffle data. val shuffleBlockId = ShuffleBlockId(SHUFFLE_ID, mapId, REDUCE_ID) when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) - when(serializerManager.wrapForCompression(meq(shuffleBlockId), - isA(classOf[InputStream]))).thenAnswer(compressionFunction) - when(serializerManager.wrapForEncryption(isA(classOf[InputStream]))).thenAnswer( - encryptionFunction) + when(serializerManager.wrapStream(meq(shuffleBlockId), + isA(classOf[InputStream]))).thenAnswer(wrapFunction) } } @@ -303,10 +291,8 @@ private[spark] class YarnShuffleEncryptionSuite extends SparkFunSuite with Match (shuffleBlockId, partitionSizesInMergedFile(mapId)) } val mapSizesByExecutorId = Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) - when(mapOutputTracker.getMapSizesByExecutorId(SHUFFLE_ID, REDUCE_ID, REDUCE_ID + 1)).thenReturn - { - mapSizesByExecutorId - } + when(mapOutputTracker.getMapSizesByExecutorId(SHUFFLE_ID, REDUCE_ID, REDUCE_ID + 1)) + .thenReturn(mapSizesByExecutorId) } @throws(classOf[IOException]) @@ -333,22 +319,17 @@ private[spark] class YarnShuffleEncryptionSuite extends SparkFunSuite with Match recordsList } - private[this] final class CompressStream extends AbstractFunction1[OutputStream, OutputStream] { + private[this] final class WrapStream extends AbstractFunction1[OutputStream, OutputStream] { override def apply(stream: OutputStream): OutputStream = { - if (conf.getBoolean("spark.shuffle.compress", false)) { - CompressionCodec.createCodec(conf).compressedOutputStream(stream) + val encryptedStream = if (conf.get(SPARK_IO_ENCRYPTION_ENABLED)) { + CryptoStreamUtils.createCryptoOutputStream(stream, conf) } else { stream } - } - } - - private[this] final class EncryptStream extends AbstractFunction1[OutputStream, OutputStream] { - override def apply(stream: OutputStream): OutputStream = { - if (CryptoConf.isShuffleEncryptionEnabled(conf)) { - CryptoStreamUtils.createCryptoOutputStream(stream, conf) + if (conf.getBoolean("spark.shuffle.compress", false)) { + 
CompressionCodec.createCodec(conf).compressedOutputStream(encryptedStream) } else { - stream + encryptedStream } } } From 2156514a72786f51cd171d56cec8b7db51294479 Mon Sep 17 00:00:00 2001 From: Ferdinand Xu Date: Tue, 23 Aug 2016 04:00:23 +0800 Subject: [PATCH 3/9] Fix code style issues --- .../org/apache/spark/crypto/ShuffleEncryptionSuite.scala | 4 ++-- .../org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala b/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala index 7f602fd62caa..eb7faff26e8e 100644 --- a/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.crypto import java.security.PrivilegedExceptionAction import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config._ -import CryptoStreamUtils._ +import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.security.CryptoStreamUtils._ import org.apache.spark.serializer.SerializerManager private[spark] class ShuffleEncryptionSuite extends SparkFunSuite { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala index db9eaeb4676d..c5d9cc6a30d3 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala @@ -21,7 +21,6 @@ import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction import java.util.{ArrayList => JArrayList, LinkedList => JLinkedList, UUID} - import scala.runtime.AbstractFunction1 import com.google.common.collect.HashMultiset From 61702dcde1d6fe90fe576cc90de6806bae0aa065 Mon Sep 17 00:00:00 2001 From: Ferdinand Xu Date: Wed, 24 Aug 2016 03:06:22 +0800 Subject: [PATCH 4/9] Update patch addressing further comments --- .../org/apache/spark/SecurityManager.scala | 20 ++++++++++++++++ .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/internal/config/package.scala | 12 +++++----- .../spark/security/CryptoStreamUtils.scala | 12 +++++----- .../spark/serializer/SerializerManager.scala | 23 +------------------ .../CryptoStreamUtilsSuite.scala} | 16 ++++++------- .../org/apache/spark/deploy/yarn/Client.scala | 3 +-- .../deploy/yarn/YarnIOEncryptionSuite.scala | 2 +- 8 files changed, 43 insertions(+), 47 deletions(-) rename core/src/test/scala/org/apache/spark/{crypto/ShuffleEncryptionSuite.scala => security/CryptoStreamUtilsSuite.scala} (87%) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index a6550b6ca8c9..c65d7196bf72 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -21,15 +21,19 @@ import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.security.{KeyStore, SecureRandom} import java.security.cert.X509Certificate +import javax.crypto.KeyGenerator import javax.net.ssl._ import com.google.common.hash.HashCodes import com.google.common.io.Files import org.apache.hadoop.io.Text +import 
org.apache.hadoop.security.Credentials import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.network.sasl.SecretKeyHolder +import org.apache.spark.security.CryptoStreamUtils._ import org.apache.spark.util.Utils /** @@ -554,4 +558,20 @@ private[spark] object SecurityManager { // key used to store the spark secret in the Hadoop UGI val SECRET_LOOKUP_KEY = "sparkCookie" + + /** + * Setup the cryptographic key used by IO encryption in credentials. The key is generated using + * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]]. + */ + def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = { + if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) { + val keyLen = conf.get(SPARK_IO_ENCRYPTION_KEY_SIZE_BITS) + val IOKeyGenAlgorithm = conf.get(SPARK_IO_ENCRYPTION_KEYGEN_ALGORITHM) + val keyGen = KeyGenerator.getInstance(IOKeyGenAlgorithm) + keyGen.init(keyLen) + + val IOKey = keyGen.generateKey() + credentials.addSecretKey(SPARK_IO_TOKEN, IOKey.getEncoded) + } + } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8580c94c7304..9c5560dab5a9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -415,7 +415,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") if (_conf.get(SPARK_IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) { - throw new SparkException("IO encryption is only supported in Yarn mode, please disable it " + + throw new SparkException("IO encryption is only supported in YARN mode, please disable it " + "by setting spark.io.encryption.enabled to false") } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9a0224e09da2..14f863b10b9a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -124,19 +124,19 @@ package object config { .booleanConf .createWithDefault(false) - private[spark] val SPARK_IO_ENCRYPTION_KEYGEN_ALGORITHM = ConfigBuilder( - "spark.io.encryption.keygen.algorithm") + private[spark] val SPARK_IO_ENCRYPTION_KEYGEN_ALGORITHM = + ConfigBuilder("spark.io.encryption.keygen.algorithm") .stringConf .createWithDefault("HmacSHA1") - private[spark] val SPARK_IO_ENCRYPTION_KEY_SIZE_BITS = ConfigBuilder( - "spark.io.encryption.keySizeBits") + private[spark] val SPARK_IO_ENCRYPTION_KEY_SIZE_BITS = + ConfigBuilder("spark.io.encryption.keySizeBits") .intConf .checkValues(Set(128, 192, 256)) .createWithDefault(128) - private[spark] val SPARK_IO_CRYPTO_CIPHER_TRANSFORMATION = ConfigBuilder( - "spark.io.crypto.cipher.transformation") + private[spark] val SPARK_IO_CRYPTO_CIPHER_TRANSFORMATION = + ConfigBuilder("spark.io.crypto.cipher.transformation") .stringConf .createWithDefaultString("AES/CTR/NoPadding") } diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 5ba020c068bc..f62dd2aebdf9 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ 
b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -39,9 +39,9 @@ private[spark] object CryptoStreamUtils { // The initialization vector length in bytes. val IV_LENGTH_IN_BYTES = 16 - // The prefix of Crypto related configurations in Spark configuration. - val SPARK_COMMONS_CRYPTO_CONF_PREFIX = "spark.commons.crypto." - // The prefix for the configurations passing to Commons-crypto library. + // The prefix of IO encryption related configurations in Spark configuration. + val SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX = "spark.io.encryption.commons.config." + // The prefix for the configurations passing to Apache Commons Crypto library. val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto." /** @@ -50,7 +50,7 @@ private[spark] object CryptoStreamUtils { def createCryptoOutputStream( os: OutputStream, sparkConf: SparkConf): OutputStream = { - val properties = toCryptoConf(sparkConf, SPARK_COMMONS_CRYPTO_CONF_PREFIX, + val properties = toCryptoConf(sparkConf, SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX, COMMONS_CRYPTO_CONF_PREFIX) val iv = createInitializationVector(properties) os.write(iv) @@ -67,7 +67,7 @@ private[spark] object CryptoStreamUtils { def createCryptoInputStream( is: InputStream, sparkConf: SparkConf): InputStream = { - val properties = toCryptoConf(sparkConf, SPARK_COMMONS_CRYPTO_CONF_PREFIX, + val properties = toCryptoConf(sparkConf, SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX, COMMONS_CRYPTO_CONF_PREFIX) val iv = new Array[Byte](IV_LENGTH_IN_BYTES) is.read(iv, 0, iv.length) @@ -89,7 +89,7 @@ private[spark] object CryptoStreamUtils { conf.getAll.foreach { case (k, v) => if (k.startsWith(sparkPrefix)) { props.put(COMMONS_CRYPTO_CONF_PREFIX + k.substring( - SPARK_COMMONS_CRYPTO_CONF_PREFIX.length()), v) + SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.length()), v) } } props diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 590a19d87ffb..e488456950b5 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -19,16 +19,13 @@ package org.apache.spark.serializer import java.io.{BufferedInputStream, BufferedOutputStream, InputStream, OutputStream} import java.nio.ByteBuffer -import javax.crypto.KeyGenerator -import org.apache.hadoop.security.Credentials import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.security.CryptoStreamUtils -import org.apache.spark.security.CryptoStreamUtils._ import org.apache.spark.storage._ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -66,7 +63,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar // Whether to compress shuffle output temporarily spilled to disk private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - // Whether to encrypt IO encryption + // Whether to enable IO encryption private[this] val enableIOEncryption = conf.get(SPARK_IO_ENCRYPTION_ENABLED) /* The compression codec to use. 
Note that the "lazy" val is necessary because we want to delay @@ -193,21 +190,3 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar .asIterator.asInstanceOf[Iterator[T]] } } - -private[spark] object SerializerManager { - /** - * Setup the cryptographic key used by IO encryption in credentials. The key is generated using - * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]]. - */ - def initShuffleEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = { - if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) { - val keyLen = conf.get(SPARK_IO_ENCRYPTION_KEY_SIZE_BITS) - val IOKeyGenAlgorithm = conf.get(SPARK_IO_ENCRYPTION_KEYGEN_ALGORITHM) - val keyGen = KeyGenerator.getInstance(IOKeyGenAlgorithm) - keyGen.init(keyLen) - - val IOKey = keyGen.generateKey() - credentials.addSecretKey(SPARK_IO_TOKEN, IOKey.getEncoded) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala similarity index 87% rename from core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala rename to core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index eb7faff26e8e..5bb60a0c304b 100644 --- a/core/src/test/scala/org/apache/spark/crypto/ShuffleEncryptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -14,33 +14,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.crypto +package org.apache.spark.security import java.security.PrivilegedExceptionAction import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.internal.config._ -import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.security.CryptoStreamUtils._ -import org.apache.spark.serializer.SerializerManager -private[spark] class ShuffleEncryptionSuite extends SparkFunSuite { +private[spark] class CryptoStreamUtilsSuite extends SparkFunSuite { val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup")) test("Crypto configuration conversion") { - val sparkKey1 = s"${SPARK_COMMONS_CRYPTO_CONF_PREFIX}a.b.c" + val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c" val sparkVal1 = "val1" val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c" - val sparkKey2 = SPARK_COMMONS_CRYPTO_CONF_PREFIX.stripSuffix(".") + "A.b.c" + val sparkKey2 = SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.stripSuffix(".") + "A.b.c" val sparkVal2 = "val2" val cryptoKey2 = s"${COMMONS_CRYPTO_CONF_PREFIX}A.b.c" val conf = new SparkConf() conf.set(sparkKey1, sparkVal1) conf.set(sparkKey2, sparkVal2) - val props = CryptoStreamUtils.toCryptoConf(conf, SPARK_COMMONS_CRYPTO_CONF_PREFIX, + val props = CryptoStreamUtils.toCryptoConf(conf, SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX, COMMONS_CRYPTO_CONF_PREFIX) assert(props.getProperty(cryptoKey1) === sparkVal1) assert(!props.containsKey(cryptoKey2)) @@ -104,7 +102,7 @@ private[spark] class ShuffleEncryptionSuite extends SparkFunSuite { private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = { if (conf.get(SPARK_IO_ENCRYPTION_ENABLED)) { - SerializerManager.initShuffleEncryptionKey(conf, credentials) + SecurityManager.initIOEncryptionKey(conf, credentials) } } } diff 
--git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 580aa7017812..77c45a55c010 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -54,7 +54,6 @@ import org.apache.spark.deploy.yarn.security.ConfigurableCredentialManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} -import org.apache.spark.serializer.SerializerManager import org.apache.spark.util.Utils private[spark] class Client( @@ -1006,7 +1005,7 @@ private[spark] class Client( YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava) if (sparkConf.get(SPARK_IO_ENCRYPTION_ENABLED)) { - SerializerManager.initShuffleEncryptionKey(sparkConf, credentials) + SecurityManager.initIOEncryptionKey(sparkConf, credentials) } setupSecurityToken(amContainer) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala index c5d9cc6a30d3..60cad70d1c15 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala @@ -102,7 +102,7 @@ private[spark] class YarnIOEncryptionSuite extends SparkFunSuite with Matchers w override def run(): Unit = { conf.set(SPARK_IO_ENCRYPTION_ENABLED, true) val creds = new Credentials() - SerializerManager.initShuffleEncryptionKey(conf, creds) + SecurityManager.initIOEncryptionKey(conf, creds) SparkHadoopUtil.get.addCurrentUserCredentials(creds) } }) From beb45266872cd52f2a64496056989237477305b6 Mon Sep 17 00:00:00 2001 From: Ferdinand Xu Date: Thu, 25 Aug 2016 04:09:02 +0800 Subject: [PATCH 5/9] Rename test --- .../org/apache/spark/SecurityManager.scala | 10 ++++---- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/internal/config/package.scala | 15 ++++++----- .../spark/security/CryptoStreamUtils.scala | 4 +-- .../spark/serializer/SerializerManager.scala | 2 +- .../security/CryptoStreamUtilsSuite.scala | 14 +++++------ .../org/apache/spark/deploy/yarn/Client.scala | 2 +- .../IOEncryptionSuite.scala} | 25 ++++++++----------- 8 files changed, 35 insertions(+), 39 deletions(-) rename yarn/src/test/scala/org/apache/spark/{deploy/yarn/YarnIOEncryptionSuite.scala => security/IOEncryptionSuite.scala} (96%) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index c65d7196bf72..199365ad925a 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -565,13 +565,13 @@ private[spark] object SecurityManager { */ def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = { if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) { - val keyLen = conf.get(SPARK_IO_ENCRYPTION_KEY_SIZE_BITS) - val IOKeyGenAlgorithm = conf.get(SPARK_IO_ENCRYPTION_KEYGEN_ALGORITHM) - val keyGen = KeyGenerator.getInstance(IOKeyGenAlgorithm) + val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS) + val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM) + val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm) keyGen.init(keyLen) - val IOKey = keyGen.generateKey() - credentials.addSecretKey(SPARK_IO_TOKEN, IOKey.getEncoded) + val 
ioKey = keyGen.generateKey() + credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded) } } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9c5560dab5a9..e2889a6f80b0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -414,7 +414,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") - if (_conf.get(SPARK_IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) { + if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) { throw new SparkException("IO encryption is only supported in YARN mode, please disable it " + "by setting spark.io.encryption.enabled to false") } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 14f863b10b9a..f9f12c2feb3f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -120,22 +120,21 @@ package object config { .intConf .createWithDefault(100000) - private[spark] val SPARK_IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled") + private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled") .booleanConf .createWithDefault(false) - private[spark] val SPARK_IO_ENCRYPTION_KEYGEN_ALGORITHM = + private[spark] val IO_ENCRYPTION_KEYGEN_ALGORITHM = ConfigBuilder("spark.io.encryption.keygen.algorithm") .stringConf .createWithDefault("HmacSHA1") - private[spark] val SPARK_IO_ENCRYPTION_KEY_SIZE_BITS = - ConfigBuilder("spark.io.encryption.keySizeBits") - .intConf - .checkValues(Set(128, 192, 256)) - .createWithDefault(128) + private[spark] val IO_ENCRYPTION_KEY_SIZE_BITS = ConfigBuilder("spark.io.encryption.keySizeBits") + .intConf + .checkValues(Set(128, 192, 256)) + .createWithDefault(128) - private[spark] val SPARK_IO_CRYPTO_CIPHER_TRANSFORMATION = + private[spark] val IO_CRYPTO_CIPHER_TRANSFORMATION = ConfigBuilder("spark.io.crypto.cipher.transformation") .stringConf .createWithDefaultString("AES/CTR/NoPadding") diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index f62dd2aebdf9..9c516ae6bcef 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -56,7 +56,7 @@ private[spark] object CryptoStreamUtils { os.write(iv) val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() val key = credentials.getSecretKey(SPARK_IO_TOKEN) - val transformationStr = sparkConf.get(SPARK_IO_CRYPTO_CIPHER_TRANSFORMATION) + val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) new CryptoOutputStream(transformationStr, properties, os, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) } @@ -73,7 +73,7 @@ private[spark] object CryptoStreamUtils { is.read(iv, 0, iv.length) val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() val key = credentials.getSecretKey(SPARK_IO_TOKEN) - val transformationStr = sparkConf.get(SPARK_IO_CRYPTO_CIPHER_TRANSFORMATION) + val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) new CryptoInputStream(transformationStr, properties, is, new 
SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index e488456950b5..f10448fad610 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -64,7 +64,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) // Whether to enable IO encryption - private[this] val enableIOEncryption = conf.get(SPARK_IO_ENCRYPTION_ENABLED) + private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED) /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 5bb60a0c304b..4cc74e39bdfc 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.internal.config._ import org.apache.spark.security.CryptoStreamUtils._ -private[spark] class CryptoStreamUtilsSuite extends SparkFunSuite { +class CryptoStreamUtilsSuite extends SparkFunSuite { val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup")) test("Crypto configuration conversion") { @@ -60,7 +60,7 @@ private[spark] class CryptoStreamUtilsSuite extends SparkFunSuite { override def run(): Unit = { val credentials = UserGroupInformation.getCurrentUser.getCredentials() val conf = new SparkConf() - conf.set(SPARK_IO_ENCRYPTION_ENABLED, true) + conf.set(IO_ENCRYPTION_ENABLED, true) initCredentials(conf, credentials) var key = credentials.getSecretKey(SPARK_IO_TOKEN) assert(key !== null) @@ -75,8 +75,8 @@ private[spark] class CryptoStreamUtilsSuite extends SparkFunSuite { override def run(): Unit = { val credentials = UserGroupInformation.getCurrentUser.getCredentials() val conf = new SparkConf() - conf.set(SPARK_IO_ENCRYPTION_KEY_SIZE_BITS, 256) - conf.set(SPARK_IO_ENCRYPTION_ENABLED, true) + conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 256) + conf.set(IO_ENCRYPTION_ENABLED, true) initCredentials(conf, credentials) var key = credentials.getSecretKey(SPARK_IO_TOKEN) assert(key !== null) @@ -91,8 +91,8 @@ private[spark] class CryptoStreamUtilsSuite extends SparkFunSuite { override def run(): Unit = { val credentials = UserGroupInformation.getCurrentUser.getCredentials() val conf = new SparkConf() - conf.set(SPARK_IO_ENCRYPTION_KEY_SIZE_BITS, 328) - conf.set(SPARK_IO_ENCRYPTION_ENABLED, true) + conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 328) + conf.set(IO_ENCRYPTION_ENABLED, true) val thrown = intercept[IllegalArgumentException] { initCredentials(conf, credentials) } @@ -101,7 +101,7 @@ private[spark] class CryptoStreamUtilsSuite extends SparkFunSuite { } private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = { - if (conf.get(SPARK_IO_ENCRYPTION_ENABLED)) { + if (conf.get(IO_ENCRYPTION_ENABLED)) { SecurityManager.initIOEncryptionKey(conf, credentials) } } diff --git 
a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 77c45a55c010..2398f0aea316 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1004,7 +1004,7 @@ private[spark] class Client( amContainer.setApplicationACLs( YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava) - if (sparkConf.get(SPARK_IO_ENCRYPTION_ENABLED)) { + if (sparkConf.get(IO_ENCRYPTION_ENABLED)) { SecurityManager.initIOEncryptionKey(sparkConf, credentials) } setupSecurityToken(amContainer) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/security/IOEncryptionSuite.scala similarity index 96% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala rename to yarn/src/test/scala/org/apache/spark/security/IOEncryptionSuite.scala index 60cad70d1c15..98de93aeda05 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnIOEncryptionSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/security/IOEncryptionSuite.scala @@ -14,27 +14,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.yarn +package org.apache.spark.security import java.io._ import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction import java.util.{ArrayList => JArrayList, LinkedList => JLinkedList, UUID} -import scala.runtime.AbstractFunction1 - import com.google.common.collect.HashMultiset import com.google.common.io.ByteStreams import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.junit.Assert.assertEquals -import org.mockito.Mock -import org.mockito.MockitoAnnotations -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer +import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.Matchers.{eq => meq, _} import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers} +import scala.runtime.AbstractFunction1 import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -44,15 +42,14 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.network.buffer.NioManagedBuffer import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer._ import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, UnsafeShuffleWriter} import org.apache.spark.storage._ import org.apache.spark.util.Utils -private[spark] class YarnIOEncryptionSuite extends SparkFunSuite with Matchers with - BeforeAndAfterAll with BeforeAndAfterEach { +class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll + with BeforeAndAfterEach { @Mock(answer = RETURNS_SMART_NULLS) private[this] var blockManager: BlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private[this] var blockResolver: IndexShuffleBlockResolver = _ @Mock(answer = RETURNS_SMART_NULLS) private[this] var diskBlockManager: DiskBlockManager = _ @@ -87,7 +84,6 @@ private[spark] class YarnIOEncryptionSuite extends SparkFunSuite with Matchers w new 
BaseShuffleHandle(SHUFFLE_ID, NUM_MAPS, dependency) } - // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. private[this] val mapOutputTracker = mock(classOf[MapOutputTracker]) @@ -100,7 +96,7 @@ private[spark] class YarnIOEncryptionSuite extends SparkFunSuite with Matchers w System.setProperty("SPARK_YARN_MODE", "true") ugi.doAs(new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { - conf.set(SPARK_IO_ENCRYPTION_ENABLED, true) + conf.set(IO_ENCRYPTION_ENABLED, true) val creds = new Credentials() SecurityManager.initIOEncryptionKey(conf, creds) SparkHadoopUtil.get.addCurrentUserCredentials(creds) @@ -110,6 +106,7 @@ private[spark] class YarnIOEncryptionSuite extends SparkFunSuite with Matchers w override def afterAll(): Unit = { SparkEnv.set(null) + System.clearProperty("SPARK_YARN_MODE") } override def beforeEach(): Unit = { @@ -244,7 +241,7 @@ private[spark] class YarnIOEncryptionSuite extends SparkFunSuite with Matchers w // encrypted and compressed input stream if encryption and compression enabled val wrapFunction = new Answer[InputStream] { override def answer(invocation: InvocationOnMock): InputStream = { - val encryptedStream = if (conf.get(SPARK_IO_ENCRYPTION_ENABLED)) { + val encryptedStream = if (conf.get(IO_ENCRYPTION_ENABLED)) { CryptoStreamUtils.createCryptoInputStream( invocation.getArguments()(1).asInstanceOf[InputStream], conf) } else { @@ -320,7 +317,7 @@ private[spark] class YarnIOEncryptionSuite extends SparkFunSuite with Matchers w private[this] final class WrapStream extends AbstractFunction1[OutputStream, OutputStream] { override def apply(stream: OutputStream): OutputStream = { - val encryptedStream = if (conf.get(SPARK_IO_ENCRYPTION_ENABLED)) { + val encryptedStream = if (conf.get(IO_ENCRYPTION_ENABLED)) { CryptoStreamUtils.createCryptoOutputStream(stream, conf) } else { stream From 9f958a4847af46de18befaede4d08093fe11416f Mon Sep 17 00:00:00 2001 From: Ferdinand Xu Date: Thu, 25 Aug 2016 08:07:21 +0800 Subject: [PATCH 6/9] Avoid external use for encryption and compression stream --- .../apache/spark/serializer/SerializerManager.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index f10448fad610..86a4464fd86f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -124,28 +124,28 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar /** * Wrap an input stream for encryption if shuffle encryption is enabled */ - def wrapForEncryption(s: InputStream): InputStream = { + private[this] def wrapForEncryption(s: InputStream): InputStream = { if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s } /** * Wrap an output stream for encryption if shuffle encryption is enabled */ - def wrapForEncryption(s: OutputStream): OutputStream = { + private[this] def wrapForEncryption(s: OutputStream): OutputStream = { if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s } /** * Wrap an output stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if 
(shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } @@ -172,7 +172,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) val byteStream = new BufferedOutputStream(bbos) val ser = getSerializer(classTag).newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() bbos.toChunkedByteBuffer } From a9a05c5168eede0db26135c0f8f330b451c840ad Mon Sep 17 00:00:00 2001 From: Ferdinand Xu Date: Fri, 26 Aug 2016 03:45:10 +0800 Subject: [PATCH 7/9] Fix some further comments --- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/security/CryptoStreamUtils.scala | 23 +++++++++++-------- .../spark/serializer/SerializerManager.scala | 2 +- .../util/collection/ExternalSorter.scala | 2 +- .../security/CryptoStreamUtilsSuite.scala | 3 +-- docs/configuration.md | 2 +- .../yarn}/IOEncryptionSuite.scala | 3 ++- 7 files changed, 20 insertions(+), 17 deletions(-) rename yarn/src/test/scala/org/apache/spark/{security => deploy/yarn}/IOEncryptionSuite.scala (99%) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e2889a6f80b0..2f6212475b19 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -416,7 +416,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) { throw new SparkException("IO encryption is only supported in YARN mode, please disable it " + - "by setting spark.io.encryption.enabled to false") + s"by setting ${IO_ENCRYPTION_ENABLED.key} to false") } // "_jobProgressListener" should be set up before creating SparkEnv because when creating diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 9c516ae6bcef..8f15f50bee81 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -26,12 +26,13 @@ import org.apache.hadoop.io.Text import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ /** * A util class for manipulating IO encryption and decryption streams. 
*/ -private[spark] object CryptoStreamUtils { +private[spark] object CryptoStreamUtils extends Logging { /** * Constants and variables for spark IO encryption */ @@ -50,8 +51,7 @@ private[spark] object CryptoStreamUtils { def createCryptoOutputStream( os: OutputStream, sparkConf: SparkConf): OutputStream = { - val properties = toCryptoConf(sparkConf, SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX, - COMMONS_CRYPTO_CONF_PREFIX) + val properties = toCryptoConf(sparkConf) val iv = createInitializationVector(properties) os.write(iv) val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() @@ -67,8 +67,7 @@ private[spark] object CryptoStreamUtils { def createCryptoInputStream( is: InputStream, sparkConf: SparkConf): InputStream = { - val properties = toCryptoConf(sparkConf, SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX, - COMMONS_CRYPTO_CONF_PREFIX) + val properties = toCryptoConf(sparkConf) val iv = new Array[Byte](IV_LENGTH_IN_BYTES) is.read(iv, 0, iv.length) val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() @@ -81,13 +80,10 @@ private[spark] object CryptoStreamUtils { /** * Get Commons-crypto configurations from Spark configurations identified by prefix. */ - def toCryptoConf( - conf: SparkConf, - sparkPrefix: String, - cryptoPrefix: String): Properties = { + def toCryptoConf(conf: SparkConf): Properties = { val props = new Properties() conf.getAll.foreach { case (k, v) => - if (k.startsWith(sparkPrefix)) { + if (k.startsWith(SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX)) { props.put(COMMONS_CRYPTO_CONF_PREFIX + k.substring( SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.length()), v) } @@ -100,7 +96,14 @@ private[spark] object CryptoStreamUtils { */ private[this] def createInitializationVector(properties: Properties): Array[Byte] = { val iv = new Array[Byte](IV_LENGTH_IN_BYTES) + val initialIVStart = System.currentTimeMillis() CryptoRandomFactory.getCryptoRandom(properties).nextBytes(iv) + val initialIVFinish = System.currentTimeMillis() + val initialIVTime = initialIVFinish - initialIVStart + if (initialIVTime > 2000) { + logWarning(s"It costs ${initialIVTime} milliseconds to create the Initialization Vector " + + s"used by CryptoStream") + } iv } } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 86a4464fd86f..7b1ec6fcbbbf 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -118,7 +118,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar * Wrap an output stream for encryption and compression */ def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = { - wrapForEncryption(wrapForCompression(blockId, s)) + wrapForCompression(blockId, wrapForEncryption(s)) } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 2e51e6056cad..3579918fac45 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -522,7 +522,7 @@ private[spark] class ExternalSorter[K, V, C]( val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val wrappedStream = SparkEnv.get.serializerManager.wrapStream(spill.blockId, bufferedStream) + val wrappedStream = 
serializerManager.wrapStream(spill.blockId, bufferedStream) serInstance.deserializeStream(wrappedStream) } else { // No more batches left diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 4cc74e39bdfc..81eb907ac7ba 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -38,8 +38,7 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { val conf = new SparkConf() conf.set(sparkKey1, sparkVal1) conf.set(sparkKey2, sparkVal2) - val props = CryptoStreamUtils.toCryptoConf(conf, SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX, - COMMONS_CRYPTO_CONF_PREFIX) + val props = CryptoStreamUtils.toCryptoConf(conf) assert(props.getProperty(cryptoKey1) === sparkVal1) assert(!props.containsKey(cryptoKey2)) } diff --git a/docs/configuration.md b/docs/configuration.md index a5594dd0d0bf..f73dea19fae3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -563,7 +563,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.encryption.enabled false - Enable IO encryption. + Enable IO encryption. It only supports YARN mode. diff --git a/yarn/src/test/scala/org/apache/spark/security/IOEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala similarity index 99% rename from yarn/src/test/scala/org/apache/spark/security/IOEncryptionSuite.scala rename to yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala index 98de93aeda05..0e1d23360219 100644 --- a/yarn/src/test/scala/org/apache/spark/security/IOEncryptionSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.security +package org.apache.spark.deploy.yarn import java.io._ import java.nio.ByteBuffer @@ -42,6 +42,7 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.network.buffer.NioManagedBuffer import org.apache.spark.network.util.LimitedInputStream +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer._ import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, UnsafeShuffleWriter} From a811bf4ecf26c384ec14532708372e7f9d5fda87 Mon Sep 17 00:00:00 2001 From: Ferdinand Xu Date: Mon, 29 Aug 2016 08:48:44 +0800 Subject: [PATCH 8/9] Update test cases --- .../spark/internal/config/package.scala | 1 + docs/configuration.md | 16 +- .../spark/deploy/yarn/IOEncryptionSuite.scala | 275 ++---------------- 3 files changed, 30 insertions(+), 262 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index f9f12c2feb3f..ebce07c1e3b3 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -136,6 +136,7 @@ package object config { private[spark] val IO_CRYPTO_CIPHER_TRANSFORMATION = ConfigBuilder("spark.io.crypto.cipher.transformation") + .internal() .stringConf .createWithDefaultString("AES/CTR/NoPadding") } diff --git a/docs/configuration.md b/docs/configuration.md index f73dea19fae3..d0c76aaad0b3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -563,35 +563,25 @@ Apart from these, the following properties are also available, and may be useful spark.io.encryption.enabled false - Enable IO encryption. It only supports YARN mode. + Enable IO encryption. Only supported in YARN mode. spark.io.encryption.keySizeBits 128 - IO encryption key size in bits. The valid number includes 128, 192 and 256. + IO encryption key size in bits. Supported values are 128, 192 and 256. spark.io.encryption.keygen.algorithm HmacSHA1 - The algorithm to generate the key used by IO encryption. The supported algorithms are + The algorithm to use when generating the IO encryption key. The supported algorithms are described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm Name Documentation. - - spark.io.crypto.cipher.transformation - AES/CTR/NoPadding - - Cipher transformation for IO encryption. The cipher transformation name is identical to the - transformations described in the Cipher section of the Java Cryptography Architecture - Standard Algorithm Name Documentation. Currently only "AES/CTR/NoPadding" algorithm is - supported. 
- - #### Spark UI diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala index 0e1d23360219..d8bf2e3d281c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala @@ -17,83 +17,28 @@ package org.apache.spark.deploy.yarn import java.io._ -import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.security.PrivilegedExceptionAction -import java.util.{ArrayList => JArrayList, LinkedList => JLinkedList, UUID} +import java.util.UUID -import com.google.common.collect.HashMultiset -import com.google.common.io.ByteStreams import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.junit.Assert.assertEquals -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.Matchers.{eq => meq, _} -import org.mockito.Mockito._ -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer +import org.mockito.MockitoAnnotations import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers} -import scala.runtime.AbstractFunction1 import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.internal.config._ -import org.apache.spark.io.CompressionCodec -import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} -import org.apache.spark.network.buffer.NioManagedBuffer -import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer._ -import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, UnsafeShuffleWriter} import org.apache.spark.storage._ -import org.apache.spark.util.Utils class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll with BeforeAndAfterEach { - @Mock(answer = RETURNS_SMART_NULLS) private[this] var blockManager: BlockManager = _ - @Mock(answer = RETURNS_SMART_NULLS) private[this] var blockResolver: IndexShuffleBlockResolver = _ - @Mock(answer = RETURNS_SMART_NULLS) private[this] var diskBlockManager: DiskBlockManager = _ - @Mock(answer = RETURNS_SMART_NULLS) private[this] var serializerManager: SerializerManager = _ - @Mock(answer = RETURNS_SMART_NULLS) private[this] var taskContext: TaskContext = _ - @Mock( - answer = RETURNS_SMART_NULLS) private[this] var shuffleDep: ShuffleDependency[Int, Int, Int] = _ - - private[this] val NUM_MAPS = 1 - private[this] val NUM_PARTITITONS = 4 - private[this] val REDUCE_ID = 1 - private[this] val SHUFFLE_ID = 0 + private[this] val blockId = new TempShuffleBlockId(UUID.randomUUID()) private[this] val conf = new SparkConf() - private[this] val memoryManager = new TestMemoryManager(conf) - private[this] val hashPartitioner = new HashPartitioner(NUM_PARTITITONS) - private[this] val serializer = new KryoSerializer(conf) - private[this] val spillFilesCreated = new JLinkedList[File]() - private[this] val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) - private[this] val taskMetrics = new TaskMetrics() - - private[this] var tempDir: File = _ - private[this] var mergedOutputFile: File = _ - private[this] var partitionSizesInMergedFile: Array[Long] = _ private[this] val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup")) - - // Create 
a mocked shuffle handle to pass into HashShuffleReader. - private[this] val shuffleHandle = { - val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) - when(dependency.serializer).thenReturn(serializer) - when(dependency.aggregator).thenReturn(None) - when(dependency.keyOrdering).thenReturn(None) - new BaseShuffleHandle(SHUFFLE_ID, NUM_MAPS, dependency) - } - - // Make a mocked MapOutputTracker for the shuffle reader to use to determine what - // shuffle data to read. - private[this] val mapOutputTracker = mock(classOf[MapOutputTracker]) - private[this] val sparkEnv = mock(classOf[SparkEnv]) + private[this] val serializer = new KryoSerializer(conf) override def beforeAll(): Unit = { - when(sparkEnv.conf).thenReturn(conf) - SparkEnv.set(sparkEnv) - System.setProperty("SPARK_YARN_MODE", "true") ugi.doAs(new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { @@ -113,22 +58,18 @@ class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterA override def beforeEach(): Unit = { super.beforeEach() MockitoAnnotations.initMocks(this) - tempDir = Utils.createTempDir() - mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) } override def afterEach(): Unit = { super.afterEach() conf.set("spark.shuffle.compress", false.toString) conf.set("spark.shuffle.spill.compress", false.toString) - Utils.deleteRecursively(tempDir) - val leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() - assert (leakedMemory === 0) } - test("Yarn IO encryption read and write") { + test("IO encryption read and write") { ugi.doAs(new PrivilegedExceptionAction[Unit] { override def run(): Unit = { + conf.set(IO_ENCRYPTION_ENABLED, true) conf.set("spark.shuffle.compress", false.toString) conf.set("spark.shuffle.spill.compress", false.toString) testYarnIOEncryptionWriteRead() @@ -136,9 +77,10 @@ class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterA }) } - test("Yarn IO encryption read and write with shuffle compression enabled") { + test("IO encryption read and write with shuffle compression enabled") { ugi.doAs(new PrivilegedExceptionAction[Unit] { override def run(): Unit = { + conf.set(IO_ENCRYPTION_ENABLED, true) conf.set("spark.shuffle.compress", true.toString) conf.set("spark.shuffle.spill.compress", true.toString) testYarnIOEncryptionWriteRead() @@ -147,187 +89,22 @@ class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterA } private[this] def testYarnIOEncryptionWriteRead(): Unit = { - val dataToWrite = new JArrayList[Product2[Int, Int]]() - for (i <- 0 to NUM_PARTITITONS) { - dataToWrite.add((i, i)) - } - val shuffleWriter = createWriter() - shuffleWriter.write(dataToWrite.iterator()) - shuffleWriter.stop(true) - - val shuffleReader = createReader() - val iter = shuffleReader.read() - val recordsList = new JArrayList[(Int, Int)]() - while (iter.hasNext) { - recordsList.add(iter.next().asInstanceOf[(Int, Int)]) - } - - assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(recordsList)) - } - - private[this] def createWriter(): UnsafeShuffleWriter[Int, Int] = { - initialMocksForWriter() - new UnsafeShuffleWriter[Int, Int]( - blockManager, - blockResolver, - taskMemoryManager, - new SerializedShuffleHandle[Int, Int](SHUFFLE_ID, NUM_MAPS, shuffleDep), - 0, // map id - taskContext, - conf - ) - } - - private[this] def createReader(): BlockStoreShuffleReader[Int, Int] = { - initialMocksForReader() - - new BlockStoreShuffleReader( - shuffleHandle, - REDUCE_ID, - REDUCE_ID + 1, - TaskContext.empty(), - 
serializerManager, - blockManager, - mapOutputTracker) - } - - private[this] def initialMocksForWriter(): Unit = { - when(blockManager.diskBlockManager).thenReturn(diskBlockManager) - when(blockManager.conf).thenReturn(conf) - when(blockManager.getDiskWriter(any(classOf[BlockId]), any(classOf[File]), - any(classOf[SerializerInstance]), anyInt, any(classOf[ShuffleWriteMetrics]))).thenAnswer( - new Answer[DiskBlockObjectWriter]() { - override def answer(invocationOnMock: InvocationOnMock): DiskBlockObjectWriter = { - val args = invocationOnMock.getArguments - new DiskBlockObjectWriter(args(1).asInstanceOf[File], - args(2).asInstanceOf[SerializerInstance], - args(3).asInstanceOf[Integer], new WrapStream(), false, - args(4).asInstanceOf[ShuffleWriteMetrics], args(0).asInstanceOf[BlockId]) - } - }) - - when(blockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile) - doAnswer(new Answer[Unit]() { - override def answer(invocationOnMock: InvocationOnMock): Unit = { - partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] - val tmp = invocationOnMock.getArguments()(3) - mergedOutputFile.delete() - tmp.asInstanceOf[File].renameTo(mergedOutputFile) - } - }).when(blockResolver).writeIndexFileAndCommit(anyInt(), anyInt(), any(classOf[Array[Long]]), - any(classOf[File])) - - when(diskBlockManager.createTempShuffleBlock()).thenAnswer( - new Answer[(TempShuffleBlockId, File)]() { - override def answer(invocationOnMock: InvocationOnMock): (TempShuffleBlockId, File) = { - val blockId = new TempShuffleBlockId(UUID.randomUUID()) - val file = File.createTempFile("spillFile", ".spill", tempDir) - spillFilesCreated.add(file) - (blockId, file) - } - }) - - when(taskContext.taskMetrics()).thenReturn(taskMetrics) - when(shuffleDep.serializer).thenReturn(serializer) - when(shuffleDep.partitioner).thenReturn(hashPartitioner) - when(taskContext.taskMetrics()).thenReturn(taskMetrics) - } - - private[this] def initialMocksForReader(): Unit = { - // Setup the mocked BlockManager to return RecordingManagedBuffers. - val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) - when(blockManager.blockManagerId).thenReturn(localBlockManagerId) - - // Create a return function to use for the mocked wrapStream method to initial an - // encrypted and compressed input stream if encryption and compression enabled - val wrapFunction = new Answer[InputStream] { - override def answer(invocation: InvocationOnMock): InputStream = { - val encryptedStream = if (conf.get(IO_ENCRYPTION_ENABLED)) { - CryptoStreamUtils.createCryptoInputStream( - invocation.getArguments()(1).asInstanceOf[InputStream], conf) - } else { - invocation.getArguments()(1).asInstanceOf[InputStream] - } - if (conf.getBoolean("spark.shuffle.compress", false)) { - CompressionCodec.createCodec(conf).compressedInputStream(encryptedStream) - } else { - encryptedStream - } - } - } - var startOffset = 0L - for (mapId <- 0 until NUM_PARTITITONS) { - val partitionSize: Long = partitionSizesInMergedFile(mapId) - if (partitionSize > 0) { - val bytes = new Array[Byte](partitionSize.toInt) - var in: InputStream = new FileInputStream(mergedOutputFile) - ByteStreams.skipFully(in, startOffset) - in = new LimitedInputStream(in, partitionSize) - try { - in.read(bytes) - } finally { - in.close() - } - // Create a ManagedBuffer with the shuffle data. 
- val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(bytes)) - val managedBuffer = new RecordingManagedBuffer(nioBuffer) - startOffset += partitionSizesInMergedFile(mapId) - // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to - // fetch shuffle data. - val shuffleBlockId = ShuffleBlockId(SHUFFLE_ID, mapId, REDUCE_ID) - when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) - when(serializerManager.wrapStream(meq(shuffleBlockId), - isA(classOf[InputStream]))).thenAnswer(wrapFunction) - } - } - - // Test a scenario where all data is local, to avoid creating a bunch of additional mocks - // for the code to read data over the network. - val shuffleBlockIdsAndSizes = (0 until NUM_PARTITITONS).map { mapId => - val shuffleBlockId = ShuffleBlockId(SHUFFLE_ID, mapId, REDUCE_ID) - (shuffleBlockId, partitionSizesInMergedFile(mapId)) - } - val mapSizesByExecutorId = Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) - when(mapOutputTracker.getMapSizesByExecutorId(SHUFFLE_ID, REDUCE_ID, REDUCE_ID + 1)) - .thenReturn(mapSizesByExecutorId) - } - - @throws(classOf[IOException]) - private def readRecordsFromFile: JArrayList[(Any, Any)] = { - val recordsList: JArrayList[(Any, Any)] = new JArrayList[(Any, Any)] - var startOffset = 0L - for (mapId <- 0 until NUM_PARTITITONS) { - val partitionSize: Long = partitionSizesInMergedFile(mapId) - if (partitionSize > 0) { - var in: InputStream = new FileInputStream(mergedOutputFile) - ByteStreams.skipFully(in, startOffset) - in = new LimitedInputStream(in, partitionSize) - val recordsStream: DeserializationStream = serializer.newInstance.deserializeStream(in) - val records: Iterator[(Any, Any)] = recordsStream.asKeyValueIterator - while (records.hasNext) { - val record: (Any, Any) = records.next - assertEquals(mapId, hashPartitioner.getPartition(record._1)) - recordsList.add(record) - } - recordsStream.close - startOffset += partitionSize - } - } - recordsList - } - - private[this] final class WrapStream extends AbstractFunction1[OutputStream, OutputStream] { - override def apply(stream: OutputStream): OutputStream = { - val encryptedStream = if (conf.get(IO_ENCRYPTION_ENABLED)) { - CryptoStreamUtils.createCryptoOutputStream(stream, conf) - } else { - stream - } - if (conf.getBoolean("spark.shuffle.compress", false)) { - CompressionCodec.createCodec(conf).compressedOutputStream(encryptedStream) - } else { - encryptedStream - } - } + val plainStr = "hello world" + val outputStream = new ByteArrayOutputStream() + val serializerManager = new SerializerManager(serializer, conf) + val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream) + wrappedOutputStream.write(plainStr.getBytes(StandardCharsets.UTF_8)) + wrappedOutputStream.close() + + val encryptedBytes = outputStream.toByteArray + val encryptedStr = new String(encryptedBytes) + assert (plainStr !== encryptedStr) + + val inputStream = new ByteArrayInputStream(encryptedBytes) + val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream) + val decryptedBytes = new Array[Byte](1024) + val len = wrappedInputStream.read(decryptedBytes) + val decryptedStr = new String(decryptedBytes, 0, len, StandardCharsets.UTF_8) + assert (decryptedStr === plainStr) } } From 928a59bc4566ec40e6caeccbc628369f050c31c9 Mon Sep 17 00:00:00 2001 From: Ferdinand Xu Date: Tue, 30 Aug 2016 02:15:23 +0800 Subject: [PATCH 9/9] Address further comments --- .../org/apache/spark/deploy/yarn/IOEncryptionSuite.scala | 6 ++---- 1 file changed, 2 
insertions(+), 4 deletions(-)

diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala
index d8bf2e3d281c..1c60315b21ae 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala
@@ -22,7 +22,6 @@ import java.security.PrivilegedExceptionAction
 import java.util.UUID
 
 import org.apache.hadoop.security.{Credentials, UserGroupInformation}
-import org.mockito.MockitoAnnotations
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers}
 
 import org.apache.spark._
@@ -57,7 +56,6 @@ class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterA
 
   override def beforeEach(): Unit = {
     super.beforeEach()
-    MockitoAnnotations.initMocks(this)
   }
 
   override def afterEach(): Unit = {
@@ -98,13 +96,13 @@ class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterA
 
     val encryptedBytes = outputStream.toByteArray
     val encryptedStr = new String(encryptedBytes)
-    assert (plainStr !== encryptedStr)
+    assert(plainStr !== encryptedStr)
 
     val inputStream = new ByteArrayInputStream(encryptedBytes)
     val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream)
     val decryptedBytes = new Array[Byte](1024)
     val len = wrappedInputStream.read(decryptedBytes)
     val decryptedStr = new String(decryptedBytes, 0, len, StandardCharsets.UTF_8)
-    assert (decryptedStr === plainStr)
+    assert(decryptedStr === plainStr)
   }
 }
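
For reference, below is a minimal standalone Scala sketch of the write/read round trip that CryptoStreamUtils implements in these patches: the initialization vector is written as a plaintext header of the stream, and the remainder is wrapped by the commons-crypto CryptoOutputStream/CryptoInputStream using the default AES/CTR/NoPadding transformation. It assumes the Apache Commons Crypto artifact is on the classpath and deliberately skips the Hadoop Credentials plumbing (the key is generated and used in place); the object name is illustrative only and is not part of the patch.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets
import java.util.Properties
import javax.crypto.KeyGenerator
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

import org.apache.commons.crypto.random.CryptoRandomFactory
import org.apache.commons.crypto.stream.{CryptoInputStream, CryptoOutputStream}

object IOEncryptionRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val transformation = "AES/CTR/NoPadding" // default of spark.io.crypto.cipher.transformation
    val props = new Properties()             // extra commons-crypto options; empty uses library defaults

    // Generate a 128-bit key, as SecurityManager.initIOEncryptionKey does before storing
    // it in the Hadoop Credentials under SPARK_IO_TOKEN.
    val keyGen = KeyGenerator.getInstance("HmacSHA1")
    keyGen.init(128)
    val key = keyGen.generateKey().getEncoded

    // Write path: emit a random 16-byte IV as the stream header, then encrypt the payload.
    val iv = new Array[Byte](16)
    CryptoRandomFactory.getCryptoRandom(props).nextBytes(iv)
    val baos = new ByteArrayOutputStream()
    baos.write(iv)
    val out = new CryptoOutputStream(transformation, props, baos,
      new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
    out.write("hello world".getBytes(StandardCharsets.UTF_8))
    out.close()

    // Read path: recover the IV header, then decrypt the remainder of the stream.
    val bais = new ByteArrayInputStream(baos.toByteArray)
    val readIv = new Array[Byte](16)
    bais.read(readIv, 0, readIv.length)
    val in = new CryptoInputStream(transformation, props, bais,
      new SecretKeySpec(key, "AES"), new IvParameterSpec(readIv))
    val buf = new Array[Byte](1024)
    val n = in.read(buf)
    in.close()
    assert(new String(buf, 0, n, StandardCharsets.UTF_8) == "hello world")
  }
}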
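
Likewise, a sketch of the configuration surface these patches expose, with option names taken from internal/config/package.scala and docs/configuration.md above. The values shown are examples rather than recommendations; IO encryption is only honoured in YARN mode (SparkContext throws otherwise), and anything under the spark.io.encryption.commons.config. prefix is re-prefixed to commons.crypto. by CryptoStreamUtils.toCryptoConf before being handed to the commons-crypto library. The object name and the a.b.c key are illustrative only.

import org.apache.spark.SparkConf

object IOEncryptionConfSketch {
  val conf = new SparkConf()
    .set("spark.io.encryption.enabled", "true")               // default: false; YARN mode only
    .set("spark.io.encryption.keySizeBits", "256")            // one of 128, 192, 256 (default 128)
    .set("spark.io.encryption.keygen.algorithm", "HmacSHA1")  // default
    // Forwarded to commons-crypto as "commons.crypto.a.b.c" after prefix rewriting.
    .set("spark.io.encryption.commons.config.a.b.c", "val1")
}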