
Commit 3f8927f

[SPARK-5682] Add Spark encrypted shuffle using the Chimera lib
1 parent a38e23c commit 3f8927f

File tree

10 files changed: +342 additions, -9 deletions

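The change is opt-in: DiskBlockObjectWriter wraps shuffle output in a Chimera CryptoOutputStream, and ShuffleBlockFetcherIterator unwraps fetched blocks with a CryptoInputStream, both only when the flag below is set. A minimal sketch of enabling it from user code; the key names come from the diffs below, and the values shown simply spell out the defaults:

import org.apache.spark.SparkConf

// Sketch: enable the encrypted shuffle added by this commit.
// "spark.encrypted.shuffle" is the switch CryptoConf.parse reads; the other
// two keys (with their defaults) are defined in CommonConfigurationKeys.
val conf = new SparkConf()
  .set("spark.encrypted.shuffle", "true")
  .set("spark.job.encrypted-intermediate-data.buffer.kb", "128")
  .set("spark.job.encrypted-intermediate-data-key-size-bits", "128")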

core/pom.xml

Lines changed: 5 additions & 0 deletions
@@ -359,6 +359,11 @@
       <artifactId>py4j</artifactId>
       <version>0.8.2.1</version>
     </dependency>
+    <dependency>
+      <groupId>com.intel.chimera</groupId>
+      <artifactId>chimera</artifactId>
+      <version>0.0.1</version>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
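For an sbt build the equivalent coordinate would be the line below; this is illustrative only (the commit itself only touches the Maven pom), and assumes the artifact resolves from a configured repository:

// sbt equivalent of the Maven <dependency> above (not part of the commit).
libraryDependencies += "com.intel.chimera" % "chimera" % "0.0.1"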
core/src/main/scala/org/apache/spark/crypto/CommonConfigurationKeys.scala

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.crypto
+
+import org.apache.hadoop.io.Text
+
+/**
+ * Configuration key constants for the encrypted shuffle.
+ */
+private[spark] object CommonConfigurationKeys {
+  val SPARK_SHUFFLE_TOKEN: Text = new Text("SPARK_SHUFFLE_TOKEN")
+  val SPARK_ENCRYPTED_INTERMEDIATE_DATA_BUFFER_KB: String =
+    "spark.job.encrypted-intermediate-data.buffer.kb"
+  val DEFAULT_SPARK_ENCRYPTED_INTERMEDIATE_DATA_BUFFER_KB: Int = 128
+  val SPARK_ENCRYPTED_INTERMEDIATE_DATA: String =
+    "spark.job.encrypted-intermediate-data"
+  val DEFAULT_SPARK_ENCRYPTED_INTERMEDIATE_DATA: Boolean = false
+  val SPARK_ENCRYPTED_INTERMEDIATE_DATA_KEY_SIZE_BITS: String =
+    "spark.job.encrypted-intermediate-data-key-size-bits"
+  val DEFAULT_SPARK_ENCRYPTED_INTERMEDIATE_DATA_KEY_SIZE_BITS: Int = 128
+}
core/src/main/scala/org/apache/spark/crypto/CryptoConf.scala

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.crypto
+
+import java.security.NoSuchAlgorithmException
+import javax.crypto.{KeyGenerator, SecretKey}
+
+import org.apache.hadoop.security.Credentials
+
+import org.apache.spark.SparkConf
+import org.apache.spark.crypto.CommonConfigurationKeys._
+
+/**
+ * CryptoConf holds the crypto configuration for the encrypted shuffle.
+ */
+private[spark] case class CryptoConf(enabled: Boolean = false)
+
+private[spark] object CryptoConf {
+  def parse(sparkConf: SparkConf): CryptoConf = {
+    val enabled = if (sparkConf != null) {
+      sparkConf.getBoolean("spark.encrypted.shuffle", false)
+    } else {
+      false
+    }
+    new CryptoConf(enabled)
+  }
+
+  def initSparkShuffleCredentials(conf: SparkConf, credentials: Credentials) {
+    if (credentials.getSecretKey(SPARK_SHUFFLE_TOKEN) == null) {
+      // Fall back to a 64-bit key unless intermediate-data encryption is
+      // configured, in which case the configured key size is used.
+      val SHUFFLE_KEY_LENGTH: Int = 64
+      val keyLen: Int = if (conf.getBoolean(SPARK_ENCRYPTED_INTERMEDIATE_DATA,
+          DEFAULT_SPARK_ENCRYPTED_INTERMEDIATE_DATA)) {
+        conf.getInt(SPARK_ENCRYPTED_INTERMEDIATE_DATA_KEY_SIZE_BITS,
+          DEFAULT_SPARK_ENCRYPTED_INTERMEDIATE_DATA_KEY_SIZE_BITS)
+      } else {
+        SHUFFLE_KEY_LENGTH
+      }
+      val keyGen: KeyGenerator =
+        try {
+          KeyGenerator.getInstance("HmacSHA1")
+        } catch {
+          case e: NoSuchAlgorithmException =>
+            throw new RuntimeException("Error generating shuffle secret key", e)
+        }
+      keyGen.init(keyLen)
+      val shuffleKey: SecretKey = keyGen.generateKey()
+      credentials.addSecretKey(SPARK_SHUFFLE_TOKEN, shuffleKey.getEncoded)
+    }
+  }
+}
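A driver would wire these pieces together roughly as follows. This is a sketch, not code from the commit: the Credentials plumbing via Hadoop's UserGroupInformation is an assumption about how the secret would be registered so that SparkHadoopUtil.get.getCurrentUserCredentials() (used in the writer and fetcher below) can find it.

import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.SparkConf
import org.apache.spark.crypto.CryptoConf

// Sketch: generate and register the shuffle secret once per application.
val sparkConf = new SparkConf().set("spark.encrypted.shuffle", "true")
if (CryptoConf.parse(sparkConf).enabled) {
  val creds = UserGroupInformation.getCurrentUser.getCredentials
  CryptoConf.initSparkShuffleCredentials(sparkConf, creds)
  // Executors later look the key up under SPARK_SHUFFLE_TOKEN.
  UserGroupInformation.getCurrentUser.addCredentials(creds)
}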

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 1 addition & 1 deletion
@@ -651,7 +651,7 @@ private[spark] class BlockManager(
     val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
     new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites,
-      writeMetrics)
+      writeMetrics).setSparkConf(conf)
   }
 
   /**

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 37 additions & 3 deletions
@@ -20,9 +20,15 @@ package org.apache.spark.storage
 import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream}
 import java.nio.channels.FileChannel
 
-import org.apache.spark.Logging
-import org.apache.spark.serializer.{SerializationStream, Serializer}
+import org.apache.hadoop.mapreduce.security.TokenCache
+import com.intel.chimera.{CipherSuite, CryptoCodec, CryptoOutputStream}
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.crypto.CommonConfigurationKeys._
+import org.apache.spark.crypto.CryptoConf
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.serializer.{SerializationStream, Serializer}
 
 /**
  * An interface for writing JVM objects to some underlying storage. This interface allows
@@ -123,19 +129,47 @@
    */
   private var numRecordsWritten = 0
 
+  private var sparkConf: SparkConf = null
+
   override def open(): BlockObjectWriter = {
     if (hasBeenClosed) {
       throw new IllegalStateException("Writer already closed. Cannot be reopened.")
     }
     fos = new FileOutputStream(file, true)
-    ts = new TimeTrackingOutputStream(fos)
+    val cryptoConf = CryptoConf.parse(sparkConf)
+    if (cryptoConf.enabled) {
+      val cryptoCodec: CryptoCodec = CryptoCodec.getInstance(CipherSuite.AES_CTR_NOPADDING)
+      val bufferSize: Int = sparkConf.getInt(
+        SPARK_ENCRYPTED_INTERMEDIATE_DATA_BUFFER_KB,
+        DEFAULT_SPARK_ENCRYPTED_INTERMEDIATE_DATA_BUFFER_KB) * 1024
+      val iv: Array[Byte] = createInitializationVector(cryptoCodec)
+      val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
+      val key: Array[Byte] = credentials.getSecretKey(SPARK_SHUFFLE_TOKEN)
+      // Write the IV in the clear at the head of the file; readers consume it
+      // first to reconstruct the AES-CTR stream.
+      fos.write(iv)
+      val cos = new CryptoOutputStream(fos, cryptoCodec, bufferSize, key, iv, iv.length)
+      ts = new TimeTrackingOutputStream(cos)
+    } else {
+      ts = new TimeTrackingOutputStream(fos)
+    }
     channel = fos.getChannel()
     bs = compressStream(new BufferedOutputStream(ts, bufferSize))
     objOut = serializer.newInstance().serializeStream(bs)
     initialized = true
     this
   }
 
+  // Generates a random IV sized to the cipher's block size (16 bytes for AES-CTR).
+  def createInitializationVector(cryptoCodec: CryptoCodec): Array[Byte] = {
+    val iv: Array[Byte] = new Array[Byte](cryptoCodec.getCipherSuite.getAlgorithmBlockSize)
+    cryptoCodec.generateSecureRandom(iv)
+    iv
+  }
+
+  def setSparkConf(sparkConfVal: SparkConf): DiskBlockObjectWriter = {
+    sparkConf = sparkConfVal
+    this
+  }
+
   override def close() {
     if (initialized) {
       if (syncWrites) {
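The on-disk format is therefore a 16-byte plaintext IV followed by the AES-CTR ciphertext. A sketch of reading such a file back, using only the Chimera calls already present in the diff; openEncryptedShuffleFile is a hypothetical helper, with the key supplied by the caller (the SPARK_SHUFFLE_TOKEN secret), and readFully used to avoid the short-read pitfall of a bare InputStream.read:

import java.io.{DataInputStream, FileInputStream}
import com.intel.chimera.{CipherSuite, CryptoCodec, CryptoInputStream}

// Sketch: decrypt a shuffle file written by the DiskBlockObjectWriter above.
def openEncryptedShuffleFile(path: String, key: Array[Byte]): CryptoInputStream = {
  val in = new DataInputStream(new FileInputStream(path))
  val codec = CryptoCodec.getInstance(CipherSuite.AES_CTR_NOPADDING)
  val iv = new Array[Byte](codec.getCipherSuite.getAlgorithmBlockSize)
  in.readFully(iv) // the writer emitted the IV first
  // Pass iv.length as the stream offset, mirroring what the writer passed
  // to CryptoOutputStream above.
  new CryptoInputStream(in, codec, 64 * 1024, key, iv, iv.length)
}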

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 27 additions & 3 deletions
@@ -17,16 +17,21 @@
 
 package org.apache.spark.storage
 
-import java.io.{InputStream, IOException}
+import java.io.InputStream
 import java.util.concurrent.LinkedBlockingQueue
 
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
 import scala.util.{Failure, Success, Try}
 
+import com.intel.chimera.{CipherSuite, CryptoCodec, CryptoInputStream}
+
 import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.crypto.CommonConfigurationKeys._
+import org.apache.spark.crypto.CryptoConf
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.BlockTransferService
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
-import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.{CompletionIterator, Utils}
 
@@ -296,8 +301,27 @@
       // There is a chance that createInputStream can fail (e.g. fetching a local file that does
       // not exist, SPARK-4085). In that case, we should propagate the right exception so
       // the scheduler gets a FetchFailedException.
       Try(buf.createInputStream()).map { is0 =>
-        val is = blockManager.wrapForCompression(blockId, is0)
+        val sparkConf = blockManager.conf
+        val cryptoConf = CryptoConf.parse(sparkConf)
+        val is: InputStream = if (cryptoConf.enabled) {
+          val cryptoCodec: CryptoCodec = CryptoCodec.getInstance(CipherSuite.AES_CTR_NOPADDING)
+          val bufferSize: Int = sparkConf.getInt(
+            SPARK_ENCRYPTED_INTERMEDIATE_DATA_BUFFER_KB,
+            DEFAULT_SPARK_ENCRYPTED_INTERMEDIATE_DATA_BUFFER_KB) * 1024
+          // Read back the IV the writer prefixed; loop because a single
+          // read may return fewer than 16 bytes from a network stream.
+          val iv: Array[Byte] = new Array[Byte](16)
+          var ivRead = 0
+          while (ivRead < iv.length) {
+            val n = is0.read(iv, ivRead, iv.length - ivRead)
+            if (n < 0) throw new java.io.IOException("unexpected EOF while reading shuffle IV")
+            ivRead += n
+          }
+          val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
+          val key: Array[Byte] = credentials.getSecretKey(SPARK_SHUFFLE_TOKEN)
+          val cis = new CryptoInputStream(is0, cryptoCodec, bufferSize, key, iv, iv.length)
+          blockManager.wrapForCompression(blockId, cis)
+        } else {
+          blockManager.wrapForCompression(blockId, is0)
+        }
         val iter = serializer.newInstance().deserializeStream(is).asIterator
         CompletionIterator[Any, Iterator[Any]](iter, {
           // Once the iterator is exhausted, release the buffer and set currentResult to null
core/src/test/scala/org/apache/spark/crypto/JceAesCtrCryptoCodecSuite.scala

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.crypto
+
+import java.io.{BufferedOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+import java.security.SecureRandom
+
+import com.intel.chimera.{CryptoCodec, CryptoInputStream, CryptoOutputStream, JceAesCtrCryptoCodec}
+import org.scalatest.FunSuite
+
+import org.apache.spark.Logging
+
+/**
+ * Round-trip test for the JCE-based AES-CTR codec.
+ */
+class JceAesCtrCryptoCodecSuite extends FunSuite with Logging {
+
+  test("JceAesCtrCryptoCodec round-trip") {
+    val random = new SecureRandom
+    val dataLen = 10000000
+    val inputData = new Array[Byte](dataLen)
+    val outputData = new Array[Byte](dataLen)
+    random.nextBytes(inputData)
+
+    // Encrypt with a random 128-bit key and IV.
+    val codec: CryptoCodec = new JceAesCtrCryptoCodec()
+    val aos = new ByteArrayOutputStream
+    val bos = new BufferedOutputStream(aos)
+    val key = new Array[Byte](16)
+    val iv = new Array[Byte](16)
+    random.nextBytes(key)
+    random.nextBytes(iv)
+
+    val cos = new CryptoOutputStream(bos, codec, 1024, key, iv)
+    cos.write(inputData, 0, inputData.length)
+    cos.flush()
+
+    // Decrypt and verify the plaintext round-trips byte for byte.
+    val cis = new CryptoInputStream(new ByteArrayInputStream(aos.toByteArray),
+      codec, 1024, key, iv)
+    var readLen = 0
+    var outOffset = 0
+    while (readLen < dataLen) {
+      val n = cis.read(outputData, outOffset, outputData.length - outOffset)
+      assert(n >= 0, "unexpected end of stream during decryption")
+      readLen += n
+      outOffset += n
+    }
+    var mismatches = 0
+    for (i <- 0 until dataLen) {
+      if (inputData(i) != outputData(i)) {
+        mismatches += 1
+        logInfo(s"decrypt failed at byte $i")
+      }
+    }
+    assert(mismatches == 0, s"$mismatches bytes failed to round-trip")
+  }
+}
core/src/test/scala/org/apache/spark/crypto/OpensslAesCtrCryptoCodecSuite.scala

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.crypto
+
+import java.io.{BufferedOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+import java.security.SecureRandom
+
+import com.intel.chimera.{CryptoCodec, CryptoInputStream, CryptoOutputStream, OpensslAesCtrCryptoCodec}
+import org.scalatest.FunSuite
+
+import org.apache.spark.Logging
+
+/**
+ * Round-trip test for the OpenSSL-backed AES-CTR codec.
+ */
+class OpensslAesCtrCryptoCodecSuite extends FunSuite with Logging {
+
+  test("OpensslAesCtrCryptoCodec round-trip") {
+    val random = new SecureRandom
+    val dataLen = 10000000
+    val inputData = new Array[Byte](dataLen)
+    val outputData = new Array[Byte](dataLen)
+    random.nextBytes(inputData)
+
+    // Encrypt with a random 128-bit key and IV.
+    val codec: CryptoCodec = new OpensslAesCtrCryptoCodec()
+    val aos = new ByteArrayOutputStream
+    val bos = new BufferedOutputStream(aos)
+    val key = new Array[Byte](16)
+    val iv = new Array[Byte](16)
+    random.nextBytes(key)
+    random.nextBytes(iv)
+
+    val cos = new CryptoOutputStream(bos, codec, 1024, key, iv)
+    cos.write(inputData, 0, inputData.length)
+    cos.flush()
+
+    // Decrypt and verify the plaintext round-trips byte for byte.
+    val cis = new CryptoInputStream(new ByteArrayInputStream(aos.toByteArray),
+      codec, 1024, key, iv)
+    var readLen = 0
+    var outOffset = 0
+    while (readLen < dataLen) {
+      val n = cis.read(outputData, outOffset, outputData.length - outOffset)
+      assert(n >= 0, "unexpected end of stream during decryption")
+      readLen += n
+      outOffset += n
+    }
+    var mismatches = 0
+    for (i <- 0 until dataLen) {
+      if (inputData(i) != outputData(i)) {
+        mismatches += 1
+        logInfo(s"decrypt failed at byte $i")
+      }
+    }
+    assert(mismatches == 0, s"$mismatches bytes failed to round-trip")
+  }
+}
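Both suites run the same round-trip against different backends. Because an AES-CTR keystream is fully determined by the key and IV, the two codecs should also interoperate; a small cross-codec sketch under that assumption (not a test from the commit):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.security.SecureRandom

import com.intel.chimera.{CryptoInputStream, CryptoOutputStream, JceAesCtrCryptoCodec, OpensslAesCtrCryptoCodec}

// Sketch: encrypt with the JCE codec, decrypt with the OpenSSL codec.
val random = new SecureRandom
val key = new Array[Byte](16)
val iv = new Array[Byte](16)
random.nextBytes(key)
random.nextBytes(iv)

val plain = "shuffle bytes".getBytes("UTF-8")
val buf = new ByteArrayOutputStream
val enc = new CryptoOutputStream(buf, new JceAesCtrCryptoCodec(), 1024, key, iv)
enc.write(plain)
enc.close()

val dec = new CryptoInputStream(new ByteArrayInputStream(buf.toByteArray),
  new OpensslAesCtrCryptoCodec(), 1024, key, iv)
val out = new Array[Byte](plain.length)
var off = 0
while (off < out.length) {
  val n = dec.read(out, off, out.length - off)
  assert(n >= 0, "unexpected end of stream")
  off += n
}
assert(new String(out, "UTF-8") == "shuffle bytes")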
