
Commit 236e967

Marcelo Vanzin authored and cmonkey committed
[SPARK-19520][STREAMING] Do not encrypt data written to the WAL.
Spark's I/O encryption uses an ephemeral key for each driver instance, so driver B cannot decrypt data written by driver A, since it does not have the correct key. The write ahead log is used for recovery and therefore needs to be readable by a different driver, so it cannot be encrypted by Spark's I/O encryption code.

The BlockManager APIs used by the WAL code to write data automatically encrypt it, so changes are needed so that callers can opt out of encryption. Aside from that, the "putBytes" API in the BlockManager does not do encryption, so a separate problem arose where the WAL would write unencrypted data to the BlockManager and, when those blocks were read back, decryption would fail. The WAL code therefore needs to ask the BlockManager to encrypt that data when encryption is enabled; this code is not optimal, since it results in a (temporary) second copy of the data block in memory, but should be acceptable until a more performant solution is added. The non-encryption case is not affected.

Tested with new unit tests, and by running streaming apps that recover from WAL data with I/O encryption turned on.

Author: Marcelo Vanzin <[email protected]>

Closes apache#16862 from vanzin/SPARK-19520.
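
As a reading aid for the diff below, here is a minimal sketch of how the new flags are meant to compose on the WAL path. It is not part of the patch: walWriteAndCache is a hypothetical helper and the storage level is an arbitrary choice; only the dataSerialize, dataDeserializeStream and putBytes signatures come from this commit.

    import scala.reflect.ClassTag

    import org.apache.spark.serializer.SerializerManager
    import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel}
    import org.apache.spark.util.io.ChunkedByteBuffer

    // Hypothetical helper, for illustration only.
    def walWriteAndCache[T: ClassTag](
        serializerManager: SerializerManager,
        blockManager: BlockManager,
        blockId: BlockId,
        records: Seq[T]): Seq[T] = {
      // WAL bytes must stay readable by a different driver (which has a different
      // ephemeral key), so skip Spark's stream encryption while serializing.
      val walBytes: ChunkedByteBuffer =
        serializerManager.dataSerialize(blockId, records.iterator, allowEncryption = false)
      // When reading those plaintext bytes back, tell the deserializer not to decrypt.
      val recovered = serializerManager.dataDeserializeStream(
        blockId, walBytes.toInputStream(), maybeEncrypted = false)(implicitly[ClassTag[T]]).toSeq
      // putBytes() does not encrypt on its own, so explicitly ask the BlockManager to
      // encrypt this plaintext block (when I/O encryption is enabled), because the
      // BlockManager read APIs automatically decrypt.
      blockManager.putBytes(blockId, walBytes, StorageLevel.MEMORY_AND_DISK_SER,
        tellMaster = true, encrypt = true)
      recovered
    }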
1 parent 8446355 commit 236e967

8 files changed, +120 −30 lines changed

core/src/main/scala/org/apache/spark/SecurityManager.scala

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ import org.apache.spark.util.Utils
 
 private[spark] class SecurityManager(
     sparkConf: SparkConf,
-    ioEncryptionKey: Option[Array[Byte]] = None)
+    val ioEncryptionKey: Option[Array[Byte]] = None)
   extends Logging with SecretKeyHolder {
 
   import SecurityManager._

core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala

Lines changed: 14 additions & 6 deletions
@@ -171,20 +171,26 @@ private[spark] class SerializerManager(
   }
 
   /** Serializes into a chunked byte buffer. */
-  def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
-    dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
+  def dataSerialize[T: ClassTag](
+      blockId: BlockId,
+      values: Iterator[T],
+      allowEncryption: Boolean = true): ChunkedByteBuffer = {
+    dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]],
+      allowEncryption = allowEncryption)
   }
 
   /** Serializes into a chunked byte buffer. */
   def dataSerializeWithExplicitClassTag(
       blockId: BlockId,
       values: Iterator[_],
-      classTag: ClassTag[_]): ChunkedByteBuffer = {
+      classTag: ClassTag[_],
+      allowEncryption: Boolean = true): ChunkedByteBuffer = {
     val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
     val byteStream = new BufferedOutputStream(bbos)
     val autoPick = !blockId.isInstanceOf[StreamBlockId]
     val ser = getSerializer(classTag, autoPick).newInstance()
-    ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
+    val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream
+    ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close()
     bbos.toChunkedByteBuffer
   }
 
@@ -194,13 +200,15 @@ private[spark] class SerializerManager(
    */
   def dataDeserializeStream[T](
       blockId: BlockId,
-      inputStream: InputStream)
+      inputStream: InputStream,
+      maybeEncrypted: Boolean = true)
       (classTag: ClassTag[T]): Iterator[T] = {
     val stream = new BufferedInputStream(inputStream)
     val autoPick = !blockId.isInstanceOf[StreamBlockId]
+    val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream
     getSerializer(classTag, autoPick)
       .newInstance()
-      .deserializeStream(wrapStream(blockId, stream))
+      .deserializeStream(wrapForCompression(blockId, decrypted))
       .asIterator.asInstanceOf[Iterator[T]]
   }
 }

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 33 additions & 2 deletions
@@ -28,6 +28,8 @@ import scala.reflect.ClassTag
 import scala.util.Random
 import scala.util.control.NonFatal
 
+import com.google.common.io.ByteStreams
+
 import org.apache.spark._
 import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
 import org.apache.spark.internal.Logging
@@ -38,6 +40,7 @@ import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExternalShuffleClient
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.storage.memory._
@@ -752,15 +755,43 @@ private[spark] class BlockManager(
   /**
    * Put a new block of serialized bytes to the block manager.
    *
+   * @param encrypt If true, asks the block manager to encrypt the data block before storing,
+   *                when I/O encryption is enabled. This is required for blocks that have been
+   *                read from unencrypted sources, since all the BlockManager read APIs
+   *                automatically do decryption.
    * @return true if the block was stored or false if an error occurred.
    */
   def putBytes[T: ClassTag](
       blockId: BlockId,
      bytes: ChunkedByteBuffer,
       level: StorageLevel,
-      tellMaster: Boolean = true): Boolean = {
+      tellMaster: Boolean = true,
+      encrypt: Boolean = false): Boolean = {
     require(bytes != null, "Bytes is null")
-    doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster)
+
+    val bytesToStore =
+      if (encrypt && securityManager.ioEncryptionKey.isDefined) {
+        try {
+          val data = bytes.toByteBuffer
+          val in = new ByteBufferInputStream(data, true)
+          val byteBufOut = new ByteBufferOutputStream(data.remaining())
+          val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf,
+            securityManager.ioEncryptionKey.get)
+          try {
+            ByteStreams.copy(in, out)
+          } finally {
+            in.close()
+            out.close()
+          }
+          new ChunkedByteBuffer(byteBufOut.toByteBuffer)
+        } finally {
+          bytes.dispose()
+        }
+      } else {
+        bytes
+      }
+
+    doPutBytes(blockId, bytesToStore, level, implicitly[ClassTag[T]], tellMaster)
   }
 
   /**
docs/streaming-programming-guide.md

Lines changed: 3 additions & 0 deletions
@@ -2017,6 +2017,9 @@ To run a Spark Streaming applications, you need to have the following.
   `spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and
   `spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See
   [Spark Streaming Configuration](configuration.html#spark-streaming) for more details.
+  Note that Spark will not encrypt data written to the write ahead log when I/O encryption is
+  enabled. If encryption of the write ahead log data is desired, it should be stored in a file
+  system that supports encryption natively.
 
 - *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming
   application to process data as fast as it is being received, the receivers can be rate limited

streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala

Lines changed: 6 additions & 3 deletions
@@ -27,7 +27,7 @@ import org.apache.spark._
 import org.apache.spark.rdd.BlockRDD
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.streaming.util._
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util._
 import org.apache.spark.util.io.ChunkedByteBuffer
 
 /**
@@ -158,13 +158,16 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
       logInfo(s"Read partition data of $this from write ahead log, record handle " +
         partition.walRecordHandle)
       if (storeInBlockManager) {
-        blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel)
+        blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel,
+          encrypt = true)
         logDebug(s"Stored partition data of $this into block manager with level $storageLevel")
         dataRead.rewind()
       }
       serializerManager
         .dataDeserializeStream(
-          blockId, new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag)
+          blockId,
+          new ChunkedByteBuffer(dataRead).toInputStream(),
+          maybeEncrypted = false)(elementClassTag)
         .asInstanceOf[Iterator[T]]
     }
 
streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala

Lines changed: 7 additions & 4 deletions
@@ -87,7 +87,8 @@ private[streaming] class BlockManagerBasedBlockHandler(
         putResult
       case ByteBufferBlock(byteBuffer) =>
         blockManager.putBytes(
-          blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true)
+          blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true,
+          encrypt = true)
       case o =>
         throw new SparkException(
           s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}")
@@ -175,10 +176,11 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
     val serializedBlock = block match {
       case ArrayBufferBlock(arrayBuffer) =>
         numRecords = Some(arrayBuffer.size.toLong)
-        serializerManager.dataSerialize(blockId, arrayBuffer.iterator)
+        serializerManager.dataSerialize(blockId, arrayBuffer.iterator, allowEncryption = false)
       case IteratorBlock(iterator) =>
         val countIterator = new CountingIterator(iterator)
-        val serializedBlock = serializerManager.dataSerialize(blockId, countIterator)
+        val serializedBlock = serializerManager.dataSerialize(blockId, countIterator,
+          allowEncryption = false)
         numRecords = countIterator.count
         serializedBlock
       case ByteBufferBlock(byteBuffer) =>
@@ -193,7 +195,8 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
         blockId,
         serializedBlock,
         effectiveStorageLevel,
-        tellMaster = true)
+        tellMaster = true,
+        encrypt = true)
       if (!putSucceeded) {
         throw new SparkException(
           s"Could not store $blockId to block manager with storage level $storageLevel")
streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala

Lines changed: 22 additions & 5 deletions
@@ -32,10 +32,12 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark._
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
 import org.apache.spark.memory.StaticMemoryManager
 import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.storage._
@@ -44,7 +46,7 @@ import org.apache.spark.streaming.util._
 import org.apache.spark.util.{ManualClock, Utils}
 import org.apache.spark.util.io.ChunkedByteBuffer
 
-class ReceivedBlockHandlerSuite
+abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
   extends SparkFunSuite
   with BeforeAndAfter
   with Matchers
@@ -57,14 +59,22 @@ class ReceivedBlockHandlerSuite
   val conf = new SparkConf()
     .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1")
     .set("spark.app.id", "streaming-test")
+    .set(IO_ENCRYPTION_ENABLED, enableEncryption)
+  val encryptionKey =
+    if (enableEncryption) {
+      Some(CryptoStreamUtils.createKey(conf))
+    } else {
+      None
+    }
+
   val hadoopConf = new Configuration()
   val streamId = 1
-  val securityMgr = new SecurityManager(conf)
+  val securityMgr = new SecurityManager(conf, encryptionKey)
   val broadcastManager = new BroadcastManager(true, conf, securityMgr)
   val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true)
   val shuffleManager = new SortShuffleManager(conf)
   val serializer = new KryoSerializer(conf)
-  var serializerManager = new SerializerManager(serializer, conf)
+  var serializerManager = new SerializerManager(serializer, conf, encryptionKey)
   val manualClock = new ManualClock
   val blockManagerSize = 10000000
   val blockManagerBuffer = new ArrayBuffer[BlockManager]()
@@ -164,7 +174,9 @@ class ReceivedBlockHandlerSuite
         val bytes = reader.read(fileSegment)
         reader.close()
         serializerManager.dataDeserializeStream(
-          generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList
+          generateBlockId(),
+          new ChunkedByteBuffer(bytes).toInputStream(),
+          maybeEncrypted = false)(ClassTag.Any).toList
       }
       loggedData shouldEqual data
     }
@@ -208,6 +220,8 @@ class ReceivedBlockHandlerSuite
     sparkConf.set("spark.storage.unrollMemoryThreshold", "512")
     // spark.storage.unrollFraction set to 0.4 for BlockManager
     sparkConf.set("spark.storage.unrollFraction", "0.4")
+
+    sparkConf.set(IO_ENCRYPTION_ENABLED, enableEncryption)
     // Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll
     blockManager = createBlockManager(12000, sparkConf)
 
@@ -343,7 +357,7 @@ class ReceivedBlockHandlerSuite
     }
 
     def dataToByteBuffer(b: Seq[String]) =
-      serializerManager.dataSerialize(generateBlockId, b.iterator)
+      serializerManager.dataSerialize(generateBlockId, b.iterator, allowEncryption = false)
 
     val blocks = data.grouped(10).toSeq
 
@@ -418,3 +432,6 @@ class ReceivedBlockHandlerSuite
   private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong)
 }
 
+class ReceivedBlockHandlerSuite extends BaseReceivedBlockHandlerSuite(false)
+
+class ReceivedBlockHandlerWithEncryptionSuite extends BaseReceivedBlockHandlerSuite(true)

streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala

Lines changed: 34 additions & 9 deletions
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
 
 import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.internal.config._
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
 import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter}
@@ -45,6 +46,7 @@ class WriteAheadLogBackedBlockRDDSuite
 
   override def beforeEach(): Unit = {
     super.beforeEach()
+    initSparkContext()
     dir = Utils.createTempDir()
   }
 
@@ -56,22 +58,33 @@ class WriteAheadLogBackedBlockRDDSuite
     }
   }
 
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    sparkContext = new SparkContext(conf)
-    blockManager = sparkContext.env.blockManager
-    serializerManager = sparkContext.env.serializerManager
+  override def afterAll(): Unit = {
+    try {
+      stopSparkContext()
+    } finally {
+      super.afterAll()
+    }
   }
 
-  override def afterAll(): Unit = {
+  private def initSparkContext(_conf: Option[SparkConf] = None): Unit = {
+    if (sparkContext == null) {
+      sparkContext = new SparkContext(_conf.getOrElse(conf))
+      blockManager = sparkContext.env.blockManager
+      serializerManager = sparkContext.env.serializerManager
+    }
+  }
+
+  private def stopSparkContext(): Unit = {
     // Copied from LocalSparkContext, simpler than to introduced test dependencies to core tests.
     try {
-      sparkContext.stop()
+      if (sparkContext != null) {
+        sparkContext.stop()
+      }
       System.clearProperty("spark.driver.port")
       blockManager = null
       serializerManager = null
     } finally {
-      super.afterAll()
+      sparkContext = null
    }
   }
 
@@ -106,6 +119,17 @@ class WriteAheadLogBackedBlockRDDSuite
       numPartitions = 5, numPartitionsInBM = 0, numPartitionsInWAL = 5, testStoreInBM = true)
   }
 
+  test("read data in block manager and WAL with encryption on") {
+    stopSparkContext()
+    try {
+      val testConf = conf.clone().set(IO_ENCRYPTION_ENABLED, true)
+      initSparkContext(Some(testConf))
+      testRDD(numPartitions = 5, numPartitionsInBM = 3, numPartitionsInWAL = 2)
+    } finally {
+      stopSparkContext()
+    }
+  }
+
   /**
    * Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager
    * and the rest to a write ahead log, and then reading it all back using the RDD.
@@ -226,7 +250,8 @@ class WriteAheadLogBackedBlockRDDSuite
     require(blockData.size === blockIds.size)
     val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf)
     val segments = blockData.zip(blockIds).map { case (data, id) =>
-      writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer)
+      writer.write(serializerManager.dataSerialize(id, data.iterator, allowEncryption = false)
+        .toByteBuffer)
     }
     writer.close()
     segments