diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java index 056505ef53356..64fdb32a67ada 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -159,15 +159,21 @@ public void close() throws IOException { // accurately report the errors when they happen. RuntimeException error = null; byte[] dummy = new byte[8]; - try { - doCipherOp(encryptor, dummy, true); - } catch (Exception e) { - error = new RuntimeException(e); + if (encryptor != null) { + try { + doCipherOp(Cipher.ENCRYPT_MODE, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + encryptor = null; } - try { - doCipherOp(decryptor, dummy, true); - } catch (Exception e) { - error = new RuntimeException(e); + if (decryptor != null) { + try { + doCipherOp(Cipher.DECRYPT_MODE, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + decryptor = null; } random.close(); @@ -189,11 +195,11 @@ byte[] rawResponse(byte[] challenge) { } private byte[] decrypt(byte[] in) throws GeneralSecurityException { - return doCipherOp(decryptor, in, false); + return doCipherOp(Cipher.DECRYPT_MODE, in, false); } private byte[] encrypt(byte[] in) throws GeneralSecurityException { - return doCipherOp(encryptor, in, false); + return doCipherOp(Cipher.ENCRYPT_MODE, in, false); } private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key) @@ -205,11 +211,13 @@ private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key) byte[] iv = new byte[conf.ivLength()]; System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length)); - encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); - encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + CryptoCipher _encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + _encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + this.encryptor = _encryptor; - decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); - decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); + CryptoCipher _decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + _decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); + this.decryptor = _decryptor; } /** @@ -241,29 +249,52 @@ private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int k return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm()); } - private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal) + private byte[] doCipherOp(int mode, byte[] in, boolean isFinal) throws GeneralSecurityException { - Preconditions.checkState(cipher != null); + CryptoCipher cipher; + switch (mode) { + case Cipher.ENCRYPT_MODE: + cipher = encryptor; + break; + case Cipher.DECRYPT_MODE: + cipher = decryptor; + break; + default: + throw new IllegalArgumentException(String.valueOf(mode)); + } - int scale = 1; - while (true) { - int size = in.length * scale; - byte[] buffer = new byte[size]; - try { - int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0) - : cipher.update(in, 0, in.length, buffer, 0); - if (outSize != buffer.length) { - byte[] output = new byte[outSize]; - System.arraycopy(buffer, 0, output, 0, output.length); - return output; - } else { - return buffer; + Preconditions.checkState(cipher != null, "Cipher is invalid because of previous error."); + + try { + int scale = 1; + while (true) { + int size = in.length * scale; + byte[] buffer = new byte[size]; + try { + int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0) + : cipher.update(in, 0, in.length, buffer, 0); + if (outSize != buffer.length) { + byte[] output = new byte[outSize]; + System.arraycopy(buffer, 0, output, 0, output.length); + return output; + } else { + return buffer; + } + } catch (ShortBufferException e) { + // Try again with a bigger buffer. + scale *= 2; } - } catch (ShortBufferException e) { - // Try again with a bigger buffer. - scale *= 2; } + } catch (InternalError ie) { + // SPARK-25535. The commons-cryto library will throw InternalError if something goes wrong, + // and leave bad state behind in the Java wrappers, so it's not safe to use them afterwards. + if (mode == Cipher.ENCRYPT_MODE) { + this.encryptor = null; + } else { + this.decryptor = null; + } + throw ie; } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index 0b674cc620231..1e0d27c027683 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -110,10 +110,12 @@ public void addToChannel(Channel ch) throws IOException { static class EncryptionHandler extends ChannelOutboundHandlerAdapter { private final ByteArrayWritableChannel byteChannel; private final CryptoOutputStream cos; + private boolean isCipherValid; EncryptionHandler(TransportCipher cipher) throws IOException { byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); cos = cipher.createOutputStream(byteChannel); + isCipherValid = true; } @Override @@ -124,36 +126,61 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) @VisibleForTesting EncryptedMessage createEncryptedMessage(Object msg) { - return new EncryptedMessage(cos, msg, byteChannel); + return new EncryptedMessage(this, cos, msg, byteChannel); } @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { try { - cos.close(); + if (isCipherValid) { + cos.close(); + } } finally { super.close(ctx, promise); } } + + /** + * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with the underlying cipher + * after an error occurs. + */ + void reportError() { + this.isCipherValid = false; + } + + boolean isCipherValid() { + return isCipherValid; + } } private static class DecryptionHandler extends ChannelInboundHandlerAdapter { private final CryptoInputStream cis; private final ByteArrayReadableChannel byteChannel; + private boolean isCipherValid; DecryptionHandler(TransportCipher cipher) throws IOException { byteChannel = new ByteArrayReadableChannel(); cis = cipher.createInputStream(byteChannel); + isCipherValid = true; } @Override public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + if (!isCipherValid) { + throw new IOException("Cipher is in invalid state."); + } byteChannel.feedData((ByteBuf) data); byte[] decryptedData = new byte[byteChannel.readableBytes()]; int offset = 0; while (offset < decryptedData.length) { - offset += cis.read(decryptedData, offset, decryptedData.length - offset); + // SPARK-25535: workaround for CRYPTO-141. + try { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } catch (InternalError ie) { + isCipherValid = false; + throw ie; + } } ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); @@ -162,7 +189,9 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { - cis.close(); + if (isCipherValid) { + cis.close(); + } } finally { super.channelInactive(ctx); } @@ -175,8 +204,9 @@ static class EncryptedMessage extends AbstractFileRegion { private final ByteBuf buf; private final FileRegion region; private final long count; + private final CryptoOutputStream cos; + private final EncryptionHandler handler; private long transferred; - private CryptoOutputStream cos; // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data @@ -186,9 +216,14 @@ static class EncryptedMessage extends AbstractFileRegion { private ByteBuffer currentEncrypted; - EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) { + EncryptedMessage( + EncryptionHandler handler, + CryptoOutputStream cos, + Object msg, + ByteArrayWritableChannel ch) { Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, "Unrecognized message type: %s", msg.getClass().getName()); + this.handler = handler; this.isByteBuf = msg instanceof ByteBuf; this.buf = isByteBuf ? (ByteBuf) msg : null; this.region = isByteBuf ? null : (FileRegion) msg; @@ -288,6 +323,9 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep } private void encryptMore() throws IOException { + if (!handler.isCipherValid()) { + throw new IOException("Cipher is in invalid state."); + } byteRawChannel.reset(); if (isByteBuf) { @@ -296,8 +334,14 @@ private void encryptMore() throws IOException { } else { region.transferTo(byteRawChannel, region.transferred()); } - cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); - cos.flush(); + + try { + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + } catch (InternalError ie) { + handler.reportError(); + throw ie; + } currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), 0, byteEncChannel.length()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java index 46b6305363bd0..382b7337d715f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -20,10 +20,13 @@ import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; import java.util.Arrays; +import java.util.Map; +import java.security.InvalidKeyException; import java.util.Random; import static java.nio.charset.StandardCharsets.UTF_8; +import com.google.common.collect.ImmutableMap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.FileRegion; @@ -189,4 +192,18 @@ public Long answer(InvocationOnMock invocationOnMock) throws Throwable { server.close(); } } + + @Test(expected = InvalidKeyException.class) + public void testBadKeySize() throws Exception { + Map mconf = ImmutableMap.of("spark.network.crypto.keyLength", "42"); + TransportConf conf = new TransportConf("rpc", new MapConfigProvider(mconf)); + + try (AuthEngine engine = new AuthEngine("appId", "secret", conf)) { + engine.challenge(); + fail("Should have failed to create challenge message."); + + // Call close explicitly to make sure it's idempotent. + engine.close(); + } + } } diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 00621976b77f4..18b735b8035ab 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.security -import java.io.{InputStream, OutputStream} +import java.io.{Closeable, InputStream, IOException, OutputStream} import java.nio.ByteBuffer import java.nio.channels.{ReadableByteChannel, WritableByteChannel} import java.util.Properties @@ -54,8 +54,10 @@ private[spark] object CryptoStreamUtils extends Logging { val params = new CryptoParams(key, sparkConf) val iv = createInitializationVector(params.conf) os.write(iv) - new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingOutputStream( + new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, + new IvParameterSpec(iv)), + os) } /** @@ -70,8 +72,10 @@ private[spark] object CryptoStreamUtils extends Logging { val helper = new CryptoHelperChannel(channel) helper.write(ByteBuffer.wrap(iv)) - new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingWritableChannel( + new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, + new IvParameterSpec(iv)), + helper) } /** @@ -84,8 +88,10 @@ private[spark] object CryptoStreamUtils extends Logging { val iv = new Array[Byte](IV_LENGTH_IN_BYTES) ByteStreams.readFully(is, iv) val params = new CryptoParams(key, sparkConf) - new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingInputStream( + new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, + new IvParameterSpec(iv)), + is) } /** @@ -100,8 +106,10 @@ private[spark] object CryptoStreamUtils extends Logging { JavaUtils.readFully(channel, buf) val params = new CryptoParams(key, sparkConf) - new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingReadableChannel( + new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, + new IvParameterSpec(iv)), + channel) } def toCryptoConf(conf: SparkConf): Properties = { @@ -157,6 +165,117 @@ private[spark] object CryptoStreamUtils extends Logging { } + /** + * SPARK-25535. The commons-cryto library will throw InternalError if something goes + * wrong, and leave bad state behind in the Java wrappers, so it's not safe to use them + * afterwards. This wrapper detects that situation and avoids further calls into the + * commons-crypto code, while still allowing the underlying streams to be closed. + * + * This should be removed once CRYPTO-141 is fixed (and Spark upgrades its commons-crypto + * dependency). + */ + trait BaseErrorHandler extends Closeable { + + private var closed = false + + /** The encrypted stream that may get into an unhealthy state. */ + protected def cipherStream: Closeable + + /** + * The underlying stream that is being wrapped by the encrypted stream, so that it can be + * closed even if there's an error in the crypto layer. + */ + protected def original: Closeable + + protected def safeCall[T](fn: => T): T = { + if (closed) { + throw new IOException("Cipher stream is closed.") + } + try { + fn + } catch { + case ie: InternalError => + closed = true + original.close() + throw ie + } + } + + override def close(): Unit = { + if (!closed) { + cipherStream.close() + } + } + + } + + // Visible for testing. + class ErrorHandlingReadableChannel( + protected val cipherStream: ReadableByteChannel, + protected val original: ReadableByteChannel) + extends ReadableByteChannel with BaseErrorHandler { + + override def read(src: ByteBuffer): Int = safeCall { + cipherStream.read(src) + } + + override def isOpen(): Boolean = cipherStream.isOpen() + + } + + private class ErrorHandlingInputStream( + protected val cipherStream: InputStream, + protected val original: InputStream) + extends InputStream with BaseErrorHandler { + + override def read(b: Array[Byte]): Int = safeCall { + cipherStream.read(b) + } + + override def read(b: Array[Byte], off: Int, len: Int): Int = safeCall { + cipherStream.read(b, off, len) + } + + override def read(): Int = safeCall { + cipherStream.read() + } + } + + private class ErrorHandlingWritableChannel( + protected val cipherStream: WritableByteChannel, + protected val original: WritableByteChannel) + extends WritableByteChannel with BaseErrorHandler { + + override def write(src: ByteBuffer): Int = safeCall { + cipherStream.write(src) + } + + override def isOpen(): Boolean = cipherStream.isOpen() + + } + + private class ErrorHandlingOutputStream( + protected val cipherStream: OutputStream, + protected val original: OutputStream) + extends OutputStream with BaseErrorHandler { + + override def flush(): Unit = safeCall { + cipherStream.flush() + } + + override def write(b: Array[Byte]): Unit = safeCall { + cipherStream.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = safeCall { + cipherStream.write(b, off, len) + } + + override def write(b: Int): Unit = safeCall { + cipherStream.write(b) + } + } + private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) { val keySpec = new SecretKeySpec(key, "AES") diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 78f618f8a2163..0d3611c80b8d0 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -16,13 +16,16 @@ */ package org.apache.spark.security -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} -import java.nio.channels.Channels +import java.io._ +import java.nio.ByteBuffer +import java.nio.channels.{Channels, ReadableByteChannel} import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.Files import java.util.{Arrays, Random, UUID} import com.google.common.io.ByteStreams +import org.mockito.Matchers.any +import org.mockito.Mockito._ import org.apache.spark._ import org.apache.spark.internal.config._ @@ -164,6 +167,36 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { } } + test("error handling wrapper") { + val wrapped = mock(classOf[ReadableByteChannel]) + val decrypted = mock(classOf[ReadableByteChannel]) + val errorHandler = new CryptoStreamUtils.ErrorHandlingReadableChannel(decrypted, wrapped) + + when(decrypted.read(any(classOf[ByteBuffer]))) + .thenThrow(new IOException()) + .thenThrow(new InternalError()) + .thenReturn(1) + + val out = ByteBuffer.allocate(1) + intercept[IOException] { + errorHandler.read(out) + } + intercept[InternalError] { + errorHandler.read(out) + } + + val e = intercept[IOException] { + errorHandler.read(out) + } + assert(e.getMessage().contains("is closed")) + errorHandler.close() + + verify(decrypted, times(2)).read(any(classOf[ByteBuffer])) + verify(wrapped, never()).read(any(classOf[ByteBuffer])) + verify(decrypted, never()).close() + verify(wrapped, times(1)).close() + } + private def createConf(extra: (String, String)*): SparkConf = { val conf = new SparkConf() extra.foreach { case (k, v) => conf.set(k, v) }