4 changes: 4 additions & 0 deletions core/pom.xml
@@ -332,6 +332,10 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-crypto</artifactId>
</dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
UnsafeSorterSpillReader.java
@@ -72,7 +72,7 @@ public UnsafeSorterSpillReader(
final BufferedInputStream bs =
new BufferedInputStream(new FileInputStream(file), (int) bufferSizeBytes);
try {
- this.in = serializerManager.wrapForCompression(blockId, bs);
+ this.in = serializerManager.wrapStream(blockId, bs);
this.din = new DataInputStream(this.in);
numRecords = numRecordsRemaining = din.readInt();
} catch (IOException e) {
20 changes: 20 additions & 0 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -21,15 +21,19 @@ import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
import java.security.{KeyStore, SecureRandom}
import java.security.cert.X509Certificate
import javax.crypto.KeyGenerator
import javax.net.ssl._

import com.google.common.hash.HashCodes
import com.google.common.io.Files
import org.apache.hadoop.io.Text
import org.apache.hadoop.security.Credentials

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.sasl.SecretKeyHolder
import org.apache.spark.security.CryptoStreamUtils._
import org.apache.spark.util.Utils

/**
@@ -554,4 +558,20 @@ private[spark] object SecurityManager {

// key used to store the spark secret in the Hadoop UGI
val SECRET_LOOKUP_KEY = "sparkCookie"

/**
* Set up the cryptographic key used by IO encryption in credentials. The key is generated using
* [[KeyGenerator]]. The algorithm and key length are specified by the [[SparkConf]].
*/
def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = {
if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) {
val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
keyGen.init(keyLen)

val ioKey = keyGen.generateKey()
credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded)
}
}
}
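
For orientation, here is a minimal sketch of how this helper is meant to be driven. It is not part of the diff: SecurityManager and CryptoStreamUtils are private[spark], so it only compiles inside Spark itself, and the SparkConf values are illustrative.

import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.security.CryptoStreamUtils

val conf = new SparkConf()
  .set("spark.io.encryption.enabled", "true")
  .set("spark.io.encryption.keySizeBits", "256")

// Generate the key once on the driver and stash it in the current user's Hadoop
// credentials so executors of the same YARN application can read it back.
val ugi = UserGroupInformation.getCurrentUser
val credentials = ugi.getCredentials
SecurityManager.initIOEncryptionKey(conf, credentials)
ugi.addCredentials(credentials)

// Any component running as this user can later recover the raw key bytes.
val keyBytes: Array[Byte] = credentials.getSecretKey(CryptoStreamUtils.SPARK_IO_TOKEN)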
5 changes: 5 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -50,6 +50,7 @@ import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat,
WholeTextFileInputFormat}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
@@ -413,6 +414,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}

if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true")
if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) {
throw new SparkException("IO encryption is only supported in YARN mode, please disable it " +
s"by setting ${IO_ENCRYPTION_ENABLED.key} to false")
}

// "_jobProgressListener" should be set up before creating SparkEnv because when creating
// "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them.
20 changes: 20 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -119,4 +119,24 @@ package object config {
private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks")
.intConf
.createWithDefault(100000)

private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled")
.booleanConf
.createWithDefault(false)

private[spark] val IO_ENCRYPTION_KEYGEN_ALGORITHM =
ConfigBuilder("spark.io.encryption.keygen.algorithm")
.stringConf
.createWithDefault("HmacSHA1")

private[spark] val IO_ENCRYPTION_KEY_SIZE_BITS = ConfigBuilder("spark.io.encryption.keySizeBits")
.intConf
.checkValues(Set(128, 192, 256))
.createWithDefault(128)

private[spark] val IO_CRYPTO_CIPHER_TRANSFORMATION =
ConfigBuilder("spark.io.crypto.cipher.transformation")
.internal()
.stringConf
.createWithDefaultString("AES/CTR/NoPadding")
}
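
For reference, a short sketch of how a user would switch these options on (the application name is a placeholder; only spark.io.encryption.enabled is required, the other two simply override the defaults defined above, and per the SparkContext change above this is accepted only on YARN):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("io-encryption-example")                       // placeholder
  .set("spark.io.encryption.enabled", "true")                // default: false
  .set("spark.io.encryption.keySizeBits", "192")             // allowed: 128, 192, 256
  .set("spark.io.encryption.keygen.algorithm", "HmacSHA1")   // default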
109 changes: 109 additions & 0 deletions core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.security

import java.io.{InputStream, OutputStream}
import java.util.Properties
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

import org.apache.commons.crypto.random._
import org.apache.commons.crypto.stream._
import org.apache.hadoop.io.Text

import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._

/**
* A util class for manipulating IO encryption and decryption streams.
*/
private[spark] object CryptoStreamUtils extends Logging {
/**
* Constants and variables for spark IO encryption
*/
val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN")

// The initialization vector length in bytes.
val IV_LENGTH_IN_BYTES = 16
// The prefix of IO encryption related configurations in Spark configuration.
val SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX = "spark.io.encryption.commons.config."
// The prefix for the configurations passing to Apache Commons Crypto library.
val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto."

/**
* Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption.
*/
def createCryptoOutputStream(
os: OutputStream,
sparkConf: SparkConf): OutputStream = {
val properties = toCryptoConf(sparkConf)
val iv = createInitializationVector(properties)
os.write(iv)
val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
val key = credentials.getSecretKey(SPARK_IO_TOKEN)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoOutputStream(transformationStr, properties, os,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
}

/**
* Helper method to wrap [[InputStream]] with [[CryptoInputStream]] for decryption.
*/
def createCryptoInputStream(
is: InputStream,
sparkConf: SparkConf): InputStream = {
val properties = toCryptoConf(sparkConf)
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
is.read(iv, 0, iv.length)
val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
val key = credentials.getSecretKey(SPARK_IO_TOKEN)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoInputStream(transformationStr, properties, is,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
}

/**
* Get Commons-crypto configurations from Spark configurations identified by prefix.
*/
def toCryptoConf(conf: SparkConf): Properties = {
val props = new Properties()
conf.getAll.foreach { case (k, v) =>
if (k.startsWith(SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX)) {
props.put(COMMONS_CRYPTO_CONF_PREFIX + k.substring(
SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.length()), v)
}
}
props
}

/**
* Generate an IV (Initialization Vector) using a secure random source.
*/
private[this] def createInitializationVector(properties: Properties): Array[Byte] = {
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
val initialIVStart = System.currentTimeMillis()
CryptoRandomFactory.getCryptoRandom(properties).nextBytes(iv)
Review comment — @zsxwing (Member), Aug 25, 2016:
nit: could you add a warning log to log how long it takes when this line is too slow (e.g., more than 2 seconds)? Sometimes, it may take several seconds to collect enough entropy.
val initialIVFinish = System.currentTimeMillis()
val initialIVTime = initialIVFinish - initialIVStart
if (initialIVTime > 2000) {
logWarning(s"It costs ${initialIVTime} milliseconds to create the Initialization Vector " +
s"used by CryptoStream")
}
iv
}
}
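
To illustrate the IV handling above, here is a rough encrypt/decrypt round trip. It is a sketch only: it assumes it runs inside Spark (the object is private[spark]) and that the current user's credentials already hold the SPARK_IO_TOKEN key, e.g. via SecurityManager.initIOEncryptionKey. Note also that any spark.io.encryption.commons.config.* setting is forwarded to commons-crypto under the commons.crypto. prefix by toCryptoConf.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils

val conf = new SparkConf()
val plaintext = "spilled shuffle bytes".getBytes("UTF-8")

// Encrypt: createCryptoOutputStream writes the random IV to the sink before any ciphertext.
val sink = new ByteArrayOutputStream()
val encOut = CryptoStreamUtils.createCryptoOutputStream(sink, conf)
encOut.write(plaintext)
encOut.close()

// Decrypt: createCryptoInputStream first reads the IV back from the head of the stream.
val decIn = CryptoStreamUtils.createCryptoInputStream(
  new ByteArrayInputStream(sink.toByteArray), conf)
val roundTrip = new Array[Byte](plaintext.length)
decIn.read(roundTrip)
decIn.close()
// roundTrip should now equal plaintext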
SerializerManager.scala
@@ -23,13 +23,15 @@ import java.nio.ByteBuffer
import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.storage._
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

/**
- * Component which configures serialization and compression for various Spark components, including
- * automatic selection of which [[Serializer]] to use for shuffles.
+ * Component which configures serialization, compression and encryption for various Spark
+ * components, including automatic selection of which [[Serializer]] to use for shuffles.
*/
private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {

@@ -61,6 +63,9 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
// Whether to compress shuffle output temporarily spilled to disk
private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)

// Whether to enable IO encryption
private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED)

/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -102,17 +107,45 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
}
}

/**
* Wrap an input stream for encryption and compression
*/
def wrapStream(blockId: BlockId, s: InputStream): InputStream = {
wrapForCompression(blockId, wrapForEncryption(s))
}

/**
* Wrap an output stream for encryption and compression
*/
def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = {
wrapForCompression(blockId, wrapForEncryption(s))
}

/**
* Wrap an input stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: InputStream): InputStream = {
if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s
}

/**
* Wrap an output stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s
}

/**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
+ private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}

/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
+ private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}

@@ -123,7 +156,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
values: Iterator[T]): Unit = {
val byteStream = new BufferedOutputStream(outputStream)
val ser = getSerializer(implicitly[ClassTag[T]]).newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
}

/** Serializes into a chunked byte buffer. */
Expand All @@ -139,7 +172,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
val byteStream = new BufferedOutputStream(bbos)
val ser = getSerializer(classTag).newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
bbos.toChunkedByteBuffer
}

@@ -153,7 +186,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
val stream = new BufferedInputStream(inputStream)
getSerializer(implicitly[ClassTag[T]])
.newInstance()
- .deserializeStream(wrapForCompression(blockId, stream))
+ .deserializeStream(wrapStream(blockId, stream))
.asIterator.asInstanceOf[Iterator[T]]
}
}
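
To spell out the composition, here is a small sketch of the ordering wrapStream establishes (serializerManager, blockId, rawOut and rawIn are assumed to exist in the calling code; they are not defined in this diff). Because wrapForCompression is applied on top of wrapForEncryption, written data is compressed first and the compressed bytes are then encrypted; reading reverses this.

import java.io.{InputStream, OutputStream}

val payload: Array[Byte] = Array(1, 2, 3)

// Writing: caller -> compression -> encryption -> rawOut
val out: OutputStream = serializerManager.wrapStream(blockId, rawOut)
out.write(payload)
out.close()

// Reading: rawIn -> decryption -> decompression -> caller
val in: InputStream = serializerManager.wrapStream(blockId, rawIn)
val firstByte = in.read()
in.close()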
BlockStoreShuffleReader.scala
@@ -51,9 +51,9 @@ private[spark] class BlockStoreShuffleReader[K, C](
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

- // Wrap the streams for compression based on configuration
+ // Wrap the streams for compression and encryption based on configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
- serializerManager.wrapForCompression(blockId, inputStream)
+ serializerManager.wrapStream(blockId, inputStream)
}

val serializerInstance = dep.serializer.newInstance()
BlockManager.scala
@@ -721,10 +721,9 @@ private[spark] class BlockManager(
serializerInstance: SerializerInstance,
bufferSize: Int,
writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
- val compressStream: OutputStream => OutputStream =
-   serializerManager.wrapForCompression(blockId, _)
+ val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
- new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream,
+ new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream,
syncWrites, writeMetrics, blockId)
}

DiskBlockObjectWriter.scala
@@ -39,7 +39,7 @@ private[spark] class DiskBlockObjectWriter(
val file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
- compressStream: OutputStream => OutputStream,
+ wrapStream: OutputStream => OutputStream,
syncWrites: Boolean,
// These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
@@ -115,7 +115,8 @@ private[spark] class DiskBlockObjectWriter(
initialize()
initialized = true
}
- bs = compressStream(mcs)
+
+ bs = wrapStream(mcs)
objOut = serializerInstance.serializeStream(bs)
streamOpen = true
this
MemoryStore.scala
@@ -330,7 +330,7 @@ private[spark] class MemoryStore(
redirectableStream.setOutputStream(bbos)
val serializationStream: SerializationStream = {
val ser = serializerManager.getSerializer(classTag).newInstance()
- ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
+ ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream))
}

// Request enough memory to begin unrolling
ExternalAppendOnlyMap.scala
@@ -486,8 +486,8 @@ class ExternalAppendOnlyMap[K, V, C](
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
- val compressedStream = serializerManager.wrapForCompression(blockId, bufferedStream)
- ser.deserializeStream(compressedStream)
+ val wrappedStream = serializerManager.wrapStream(blockId, bufferedStream)
+ ser.deserializeStream(wrappedStream)
} else {
// No more batches left
cleanup()