2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -184,7 +184,7 @@ import org.apache.spark.util.Utils

private[spark] class SecurityManager(
sparkConf: SparkConf,
ioEncryptionKey: Option[Array[Byte]] = None)
val ioEncryptionKey: Option[Array[Byte]] = None)
extends Logging with SecretKeyHolder {

import SecurityManager._
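The only functional change above is promoting the ioEncryptionKey constructor parameter to a val, which exposes the key through SecurityManager so other components can consult it. A minimal sketch of the access this enables (assuming code that lives inside Spark's own packages, since SecurityManager and the typed config setter are private[spark]):

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.internal.config._
    import org.apache.spark.security.CryptoStreamUtils

    val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, true)
    val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf)))

    // With `val ioEncryptionKey`, callers such as BlockManager.putBytes can check for
    // the key and encrypt data that arrived un-encrypted (for example, WAL blocks).
    if (securityManager.ioEncryptionKey.isDefined) {
      // ... encrypt before handing the bytes to the block manager ...
    }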
@@ -171,20 +171,26 @@ private[spark] class SerializerManager(
}

/** Serializes into a chunked byte buffer. */
def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
def dataSerialize[T: ClassTag](
blockId: BlockId,
values: Iterator[T],
allowEncryption: Boolean = true): ChunkedByteBuffer = {
dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]],
allowEncryption = allowEncryption)
}

/** Serializes into a chunked byte buffer. */
def dataSerializeWithExplicitClassTag(
blockId: BlockId,
values: Iterator[_],
classTag: ClassTag[_]): ChunkedByteBuffer = {
classTag: ClassTag[_],
allowEncryption: Boolean = true): ChunkedByteBuffer = {
val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
val byteStream = new BufferedOutputStream(bbos)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val ser = getSerializer(classTag, autoPick).newInstance()
ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream
ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close()
bbos.toChunkedByteBuffer
}

@@ -194,13 +200,15 @@ private[spark] class SerializerManager(
*/
def dataDeserializeStream[T](
blockId: BlockId,
inputStream: InputStream)
inputStream: InputStream,
maybeEncrypted: Boolean = true)
(classTag: ClassTag[T]): Iterator[T] = {
val stream = new BufferedInputStream(inputStream)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream
getSerializer(classTag, autoPick)
.newInstance()
.deserializeStream(wrapStream(blockId, stream))
.deserializeStream(wrapForCompression(blockId, decrypted))
.asIterator.asInstanceOf[Iterator[T]]
}
}
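For orientation, a minimal usage sketch of the two new flags (again assuming code inside Spark's own packages): data bound for the write-ahead log is serialized with allowEncryption = false, and when those bytes are later read straight back from the WAL they are deserialized with maybeEncrypted = false, since they were never encrypted.

    import scala.reflect.ClassTag
    import org.apache.spark.SparkConf
    import org.apache.spark.internal.config._
    import org.apache.spark.security.CryptoStreamUtils
    import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
    import org.apache.spark.storage.StreamBlockId

    val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, true)
    val key = Some(CryptoStreamUtils.createKey(conf))
    val serializerManager = new SerializerManager(new KryoSerializer(conf), conf, key)

    // WAL-bound records: skip encryption, because the WAL stores plaintext.
    val blockId = StreamBlockId(1, 42L)
    val bytes = serializerManager.dataSerialize(blockId, Seq("a", "b", "c").iterator,
      allowEncryption = false)

    // Reading the same bytes back from the WAL: skip decryption as well.
    val recovered = serializerManager.dataDeserializeStream(
      blockId, bytes.toInputStream(), maybeEncrypted = false)(ClassTag.Any).toList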
35 changes: 33 additions & 2 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -28,6 +28,8 @@ import scala.reflect.ClassTag
import scala.util.Random
import scala.util.control.NonFatal

import com.google.common.io.ByteStreams

import org.apache.spark._
import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
import org.apache.spark.internal.Logging
@@ -38,6 +40,7 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage.memory._
@@ -752,15 +755,43 @@ private[spark] class BlockManager(
/**
* Put a new block of serialized bytes to the block manager.
*
* @param encrypt If true, asks the block manager to encrypt the data block before storing,
* when I/O encryption is enabled. This is required for blocks that have been
* read from unencrypted sources, since all the BlockManager read APIs
* automatically do decryption.
* @return true if the block was stored or false if an error occurred.
*/
def putBytes[T: ClassTag](
blockId: BlockId,
bytes: ChunkedByteBuffer,
level: StorageLevel,
tellMaster: Boolean = true): Boolean = {
tellMaster: Boolean = true,
encrypt: Boolean = false): Boolean = {
Review comment (Contributor): I think it's worth documenting this param. At first I was going to suggest that it should be called allowEncryption like the other one, but I realize it's more complicated than that. Maybe something like:

If true, the given bytes should be encrypted before they are stored. Note that in most cases, the given bytes will already be encrypted if encryption is on. An important exception to this is with the streaming WAL. Since the WAL does not support encryption, those bytes are generated un-encrypted. But we still encrypt those bytes before storing in the block manager.

Maybe too wordy, but I think it's worth documenting.

Review comment (Member): +1

require(bytes != null, "Bytes is null")
doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster)

val bytesToStore =
if (encrypt && securityManager.ioEncryptionKey.isDefined) {
try {
val data = bytes.toByteBuffer
val in = new ByteBufferInputStream(data, true)
val byteBufOut = new ByteBufferOutputStream(data.remaining())
val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf,
securityManager.ioEncryptionKey.get)
try {
ByteStreams.copy(in, out)
} finally {
in.close()
out.close()
}
new ChunkedByteBuffer(byteBufOut.toByteBuffer)
} finally {
bytes.dispose()
}
} else {
bytes
}

doPutBytes(blockId, bytesToStore, level, implicitly[ClassTag[T]], tellMaster)
}

/**
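For completeness, a rough sketch of the symmetric decryption step that the BlockManager read path performs for such blocks. This assumes CryptoStreamUtils also exposes a matching createCryptoInputStream helper; the function below is illustrative only, not the actual read-path code.

    import java.io.InputStream
    import org.apache.spark.SparkConf
    import org.apache.spark.security.CryptoStreamUtils

    // Illustrative only: wrap a stored block's stream in a decrypting stream, mirroring
    // the createCryptoOutputStream call used above when the block was stored.
    def decryptingStream(stored: InputStream, conf: SparkConf, key: Array[Byte]): InputStream =
      CryptoStreamUtils.createCryptoInputStream(stored, conf, key)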
3 changes: 3 additions & 0 deletions docs/streaming-programming-guide.md
@@ -2017,6 +2017,9 @@ To run a Spark Streaming applications, you need to have the following.
`spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and
`spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See
[Spark Streaming Configuration](configuration.html#spark-streaming) for more details.
Note that Spark will not encrypt data written to the write ahead log when I/O encryption is
enabled. If encryption of the write ahead log data is desired, it should be stored in a file
system that supports encryption natively.

- *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming
application to process data as fast as it is being received, the receivers can be rate limited
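As a concrete illustration, a configuration sketch for an application in this situation (the keys below are assumed to be the standard I/O-encryption and receiver-WAL settings): with I/O encryption on, Spark encrypts shuffle and block data, while the write-ahead log must rely on file-system level encryption.

    import org.apache.spark.SparkConf

    // Sketch: Spark encrypts shuffle and block data, but NOT the receiver write-ahead log,
    // so the checkpoint/WAL directory should sit on a file system that encrypts natively
    // (for example, an encrypted HDFS zone).
    val conf = new SparkConf()
      .setAppName("streaming-with-io-encryption")
      .set("spark.io.encryption.enabled", "true")
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")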
@@ -27,7 +27,7 @@ import org.apache.spark._
import org.apache.spark.rdd.BlockRDD
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.streaming.util._
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util._
import org.apache.spark.util.io.ChunkedByteBuffer

/**
@@ -158,13 +158,16 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
logInfo(s"Read partition data of $this from write ahead log, record handle " +
partition.walRecordHandle)
if (storeInBlockManager) {
blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel)
blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel,
Review comment (Member): Why should encrypt be true here? In the following code, it just reads the block using maybeEncrypted = false.

Review comment (Contributor, author): This is explained in the summary. The code that uses maybeEncrypted = false is deserializing data read from the WAL. This code is adding the block to the block manager, which is later read with getBlockFromBlockManager; that calls blockManager.get, which does decryption automatically.

Review comment (Member): Got it.

encrypt = true)
logDebug(s"Stored partition data of $this into block manager with level $storageLevel")
dataRead.rewind()
}
serializerManager
.dataDeserializeStream(
blockId, new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag)
blockId,
new ChunkedByteBuffer(dataRead).toInputStream(),
maybeEncrypted = false)(elementClassTag)
.asInstanceOf[Iterator[T]]
}

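To make the asymmetry discussed in the comments above concrete, a rough sketch of the recovery flow: a hypothetical helper mirroring what compute does, with all parameters standing in for values the RDD already has. The WAL bytes are plaintext, so they are re-stored with encrypt = true but deserialized with maybeEncrypted = false.

    import java.nio.ByteBuffer
    import scala.reflect.ClassTag
    import org.apache.spark.serializer.SerializerManager
    import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel}
    import org.apache.spark.util.io.ChunkedByteBuffer

    def recoverWalPartition[T: ClassTag](
        blockManager: BlockManager,
        serializerManager: SerializerManager,
        blockId: BlockId,
        storageLevel: StorageLevel,
        dataRead: ByteBuffer): Iterator[T] = {
      // Re-insert the block and ask the block manager to encrypt it, because the normal
      // BlockManager read path (blockManager.get) decrypts automatically.
      blockManager.putBytes(
        blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel, encrypt = true)
      dataRead.rewind()
      // Deserialize the WAL bytes directly, skipping decryption: they were never encrypted,
      // having been serialized with allowEncryption = false before being written to the WAL.
      serializerManager.dataDeserializeStream(
        blockId, new ChunkedByteBuffer(dataRead).toInputStream(),
        maybeEncrypted = false)(implicitly[ClassTag[T]])
    }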
@@ -87,7 +87,8 @@ private[streaming] class BlockManagerBasedBlockHandler(
putResult
case ByteBufferBlock(byteBuffer) =>
blockManager.putBytes(
blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true)
blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true,
encrypt = true)
case o =>
throw new SparkException(
s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}")
@@ -175,10 +176,11 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
val serializedBlock = block match {
case ArrayBufferBlock(arrayBuffer) =>
numRecords = Some(arrayBuffer.size.toLong)
serializerManager.dataSerialize(blockId, arrayBuffer.iterator)
serializerManager.dataSerialize(blockId, arrayBuffer.iterator, allowEncryption = false)
case IteratorBlock(iterator) =>
val countIterator = new CountingIterator(iterator)
val serializedBlock = serializerManager.dataSerialize(blockId, countIterator)
val serializedBlock = serializerManager.dataSerialize(blockId, countIterator,
allowEncryption = false)
numRecords = countIterator.count
serializedBlock
case ByteBufferBlock(byteBuffer) =>
@@ -193,7 +195,8 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
blockId,
serializedBlock,
effectiveStorageLevel,
tellMaster = true)
tellMaster = true,
encrypt = true)
if (!putSucceeded) {
throw new SparkException(
s"Could not store $blockId to block manager with storage level $storageLevel")
@@ -32,10 +32,12 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark._
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.memory.StaticMemoryManager
import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.storage._
@@ -44,7 +46,7 @@ import org.apache.spark.streaming.util._
import org.apache.spark.util.{ManualClock, Utils}
import org.apache.spark.util.io.ChunkedByteBuffer

class ReceivedBlockHandlerSuite
abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
extends SparkFunSuite
with BeforeAndAfter
with Matchers
@@ -57,14 +59,22 @@ class ReceivedBlockHandlerSuite
val conf = new SparkConf()
.set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1")
.set("spark.app.id", "streaming-test")
.set(IO_ENCRYPTION_ENABLED, enableEncryption)
val encryptionKey =
if (enableEncryption) {
Some(CryptoStreamUtils.createKey(conf))
} else {
None
}

val hadoopConf = new Configuration()
val streamId = 1
val securityMgr = new SecurityManager(conf)
val securityMgr = new SecurityManager(conf, encryptionKey)
val broadcastManager = new BroadcastManager(true, conf, securityMgr)
val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true)
val shuffleManager = new SortShuffleManager(conf)
val serializer = new KryoSerializer(conf)
var serializerManager = new SerializerManager(serializer, conf)
var serializerManager = new SerializerManager(serializer, conf, encryptionKey)
val manualClock = new ManualClock
val blockManagerSize = 10000000
val blockManagerBuffer = new ArrayBuffer[BlockManager]()
@@ -164,7 +174,9 @@ class ReceivedBlockHandlerSuite
val bytes = reader.read(fileSegment)
reader.close()
serializerManager.dataDeserializeStream(
generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList
generateBlockId(),
new ChunkedByteBuffer(bytes).toInputStream(),
maybeEncrypted = false)(ClassTag.Any).toList
}
loggedData shouldEqual data
}
@@ -208,6 +220,8 @@ class ReceivedBlockHandlerSuite
sparkConf.set("spark.storage.unrollMemoryThreshold", "512")
// spark.storage.unrollFraction set to 0.4 for BlockManager
sparkConf.set("spark.storage.unrollFraction", "0.4")

sparkConf.set(IO_ENCRYPTION_ENABLED, enableEncryption)
// Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll
blockManager = createBlockManager(12000, sparkConf)

@@ -343,7 +357,7 @@ class ReceivedBlockHandlerSuite
}

def dataToByteBuffer(b: Seq[String]) =
serializerManager.dataSerialize(generateBlockId, b.iterator)
serializerManager.dataSerialize(generateBlockId, b.iterator, allowEncryption = false)

val blocks = data.grouped(10).toSeq

@@ -418,3 +432,6 @@ class ReceivedBlockHandlerSuite
private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong)
}

class ReceivedBlockHandlerSuite extends BaseReceivedBlockHandlerSuite(false)

class ReceivedBlockHandlerWithEncryptionSuite extends BaseReceivedBlockHandlerSuite(true)
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.internal.config._
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter}
@@ -45,6 +46,7 @@ class WriteAheadLogBackedBlockRDDSuite

override def beforeEach(): Unit = {
super.beforeEach()
initSparkContext()
dir = Utils.createTempDir()
}

@@ -56,22 +58,33 @@
}
}

override def beforeAll(): Unit = {
super.beforeAll()
sparkContext = new SparkContext(conf)
blockManager = sparkContext.env.blockManager
serializerManager = sparkContext.env.serializerManager
override def afterAll(): Unit = {
try {
stopSparkContext()
} finally {
super.afterAll()
}
}

override def afterAll(): Unit = {
private def initSparkContext(_conf: Option[SparkConf] = None): Unit = {
if (sparkContext == null) {
sparkContext = new SparkContext(_conf.getOrElse(conf))
blockManager = sparkContext.env.blockManager
serializerManager = sparkContext.env.serializerManager
}
}

private def stopSparkContext(): Unit = {
// Copied from LocalSparkContext; simpler than introducing test dependencies into core tests.
try {
sparkContext.stop()
if (sparkContext != null) {
sparkContext.stop()
}
System.clearProperty("spark.driver.port")
blockManager = null
serializerManager = null
} finally {
super.afterAll()
sparkContext = null
}
}

@@ -106,6 +119,17 @@ class WriteAheadLogBackedBlockRDDSuite
numPartitions = 5, numPartitionsInBM = 0, numPartitionsInWAL = 5, testStoreInBM = true)
}

test("read data in block manager and WAL with encryption on") {
stopSparkContext()
try {
val testConf = conf.clone().set(IO_ENCRYPTION_ENABLED, true)
initSparkContext(Some(testConf))
testRDD(numPartitions = 5, numPartitionsInBM = 3, numPartitionsInWAL = 2)
} finally {
stopSparkContext()
}
}

/**
* Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager
* and the rest to a write ahead log, and then reading it all back using the RDD.
@@ -226,7 +250,8 @@ class WriteAheadLogBackedBlockRDDSuite
require(blockData.size === blockIds.size)
val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf)
val segments = blockData.zip(blockIds).map { case (data, id) =>
writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer)
writer.write(serializerManager.dataSerialize(id, data.iterator, allowEncryption = false)
.toByteBuffer)
}
writer.close()
segments