From 17eb1872140f28c95d9027480ba18d30af98a15d Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin <vanzin@cloudera.com>
Date: Fri, 14 Aug 2015 15:11:14 -0700
Subject: [PATCH 1/6] [SPARK-10004] [shuffle] Perform auth checks when clients
 read shuffle data.

To correctly isolate applications, when requests to read shuffle data
arrive at the shuffle service, proper authorization checks need to be
performed. This change makes sure that only the application that created
the shuffle data can read from it.

Such checks are only enabled when "spark.authenticate" is enabled;
otherwise there's no secure way to make sure that the client is really
who it says it is.
---
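Note: the authorization model below is deliberately simple. The SASL
bootstrap records the authenticated app ID on the TransportClient, streams
are tagged with the app ID of their owner when registered, and every fetch
compares the two. A minimal sketch of the predicate that both new checks
reduce to (illustrative only; the real code is in
OneForOneStreamManager.checkAuthorization() and
ExternalShuffleBlockHandler.checkAuth() below):

    // Sketch, not part of the diff: a null client ID means authentication
    // is disabled, so no meaningful check can be made and access is allowed.
    static void ensureAuthorized(TransportClient client, String ownerAppId) {
      if (client.getClientId() != null && !client.getClientId().equals(ownerAppId)) {
        throw new SecurityException(String.format(
          "Client %s not authorized for app %s.", client.getClientId(), ownerAppId));
      }
    }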
 .../network/netty/NettyBlockRpcServer.scala   |   3 +-
 .../netty/NettyBlockTransferService.scala     |   2 +-
 .../spark/network/client/TransportClient.java |  20 +++
 .../network/sasl/SaslClientBootstrap.java     |   2 +
 .../spark/network/sasl/SaslRpcHandler.java    |   1 +
 .../server/OneForOneStreamManager.java        |  31 +++-
 .../spark/network/server/StreamManager.java   |   9 +
 .../server/TransportRequestHandler.java       |   1 +
 .../shuffle/ExternalShuffleBlockHandler.java  |  16 +-
 .../network/sasl/SaslIntegrationSuite.java    | 159 +++++++++++++++---
 .../ExternalShuffleBlockHandlerSuite.java     |   2 +-
 11 files changed, 210 insertions(+), 36 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index b089da8596e2b..fef771036cf40 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel}
  * is equivalent to one Spark-level shuffle block.
  */
 class NettyBlockRpcServer(
+    appId: String,
     serializer: Serializer,
     blockManager: BlockDataManager)
   extends RpcHandler with Logging {
@@ -55,7 +56,7 @@ class NettyBlockRpcServer(
       case openBlocks: OpenBlocks =>
         val blocks: Seq[ManagedBuffer] =
           openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
-        val streamId = streamManager.registerStream(blocks.iterator)
+        val streamId = streamManager.registerStream(appId, blocks.iterator)
         logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
         responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index d650d5fe73087..320a602fdae91 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
   private[this] var appId: String = _
 
   override def init(blockDataManager: BlockDataManager): Unit = {
-    val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+    val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
     var serverBootstrap: Option[TransportServerBootstrap] = None
     var clientBootstrap: Option[TransportClientBootstrap] = None
     if (authEnabled) {
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index e8e7f06247d3e..027a8f2625be3 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -70,6 +70,7 @@ public class TransportClient implements Closeable {
 
   private final Channel channel;
   private final TransportResponseHandler handler;
+  private String clientId;
 
   public TransportClient(Channel channel, TransportResponseHandler handler) {
     this.channel = Preconditions.checkNotNull(channel);
@@ -84,6 +85,24 @@ public SocketAddress getSocketAddress() {
     return channel.remoteAddress();
   }
 
+  /**
+   * Returns the ID used by the client to authenticate itself when authentication is enabled.
+   *
+   * @return The client ID.
+   */
+  public String getClientId() {
+    return clientId;
+  }
+
+  /**
+   * Sets the authenticated client ID. This is meant to be used by the authentication layer;
+   * trying to set a different client ID after it's been set will result in an exception.
+   */
+  public void setClientId(String id) {
+    Preconditions.checkState(clientId == null, "Client ID has already been set.");
+    this.clientId = id;
+  }
+
   /**
    * Requests a single chunk from the remote side, from the pre-negotiated streamId.
    *
@@ -207,6 +226,7 @@ public void close() {
   public String toString() {
     return Objects.toStringHelper(this)
       .add("remoteAdress", channel.remoteAddress())
+      .add("clientId", clientId)
       .add("isActive", isActive())
       .toString();
   }
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 185ba2ef3bb1f..69923769d44b4 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -77,6 +77,8 @@ public void doBootstrap(TransportClient client, Channel channel) {
       payload = saslClient.response(response);
     }
 
+    client.setClientId(appId);
+
     if (encrypt) {
       if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
         throw new RuntimeException(
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index be6165caf3c74..3f2ebe32887b8 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -81,6 +81,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback
 
     if (saslServer == null) {
       // First message in the handshake, setup the necessary state.
+      client.setClientId(saslMessage.appId);
       saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
         conf.saslServerAlwaysEncrypt());
     }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index c95e64e8e2cda..e671854da1cae 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -24,13 +24,13 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import com.google.common.base.Preconditions;
 import io.netty.channel.Channel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
-
-import com.google.common.base.Preconditions;
+import org.apache.spark.network.client.TransportClient;
 
 /**
  * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
@@ -44,6 +44,7 @@ public class OneForOneStreamManager extends StreamManager {
 
   /** State of a single stream. */
   private static class StreamState {
+    final String appId;
     final Iterator<ManagedBuffer> buffers;
 
     // The channel associated to the stream
@@ -53,7 +54,8 @@ private static class StreamState {
     // that the caller only requests each chunk one at a time, in order.
     int curChunk = 0;
 
-    StreamState(Iterator<ManagedBuffer> buffers) {
+    StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+      this.appId = appId;
       this.buffers = Preconditions.checkNotNull(buffers);
     }
   }
@@ -109,15 +111,34 @@ public void connectionTerminated(Channel channel) {
     }
   }
 
+  @Override
+  public void checkAuthorization(TransportClient client, long streamId) {
+    if (client.getClientId() != null) {
+      StreamState state = streams.get(streamId);
+      Preconditions.checkArgument(state != null, "Unknown stream ID.");
+      if (!client.getClientId().equals(state.appId)) {
+        throw new SecurityException(String.format(
+          "Client %s not authorized to read stream %d (app %s).",
+          client.getClientId(),
+          streamId,
+          state.appId));
+      }
+    }
+  }
+
   /**
    * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
    * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
    * client connection is closed before the iterator is fully drained, then the remaining buffers
    * will all be release()'d.
+   *
+   * If an app ID is provided, only callers who've authenticated with the given app ID will be
+   * allowed to fetch from this stream.
    */
-  public long registerStream(Iterator<ManagedBuffer> buffers) {
+  public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
     long myStreamId = nextStreamId.getAndIncrement();
-    streams.put(myStreamId, new StreamState(buffers));
+    streams.put(myStreamId, new StreamState(appId, buffers));
     return myStreamId;
   }
+
 }
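Note: a hedged usage sketch of how the two new OneForOneStreamManager entry
points compose (method names and signatures as added above; the surrounding
variables are illustrative):

    // Registration ties the stream to its owning app.
    long streamId = streamManager.registerStream("app-1", buffers);

    // Before serving a chunk, the transport layer validates the caller:
    // throws SecurityException if the client authenticated as a different
    // app; a null client ID (authentication disabled) skips the check.
    streamManager.checkAuthorization(client, streamId);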
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 929f789bf9d24..aaa677c965640 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -20,6 +20,7 @@
 import io.netty.channel.Channel;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.TransportClient;
 
 /**
  * The StreamManager is used to fetch individual chunks from a stream. This is used in
@@ -60,4 +61,12 @@ public void registerChannel(Channel channel, long streamId) { }
    * to read from the associated streams again, so any state can be cleaned up.
    */
   public void connectionTerminated(Channel channel) { }
+
+  /**
+   * Verify that the client is authorized to read from the given stream.
+   *
+   * @throws SecurityException If client is not authorized.
+   */
+  public void checkAuthorization(TransportClient client, long streamId) { }
+
 }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index e5159ab56d0d4..df6027805838d 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -97,6 +97,7 @@ private void processFetchRequest(final ChunkFetchRequest req) {
 
     ManagedBuffer buf;
     try {
+      streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
      streamManager.registerChannel(channel, req.streamChunkId.streamId);
       buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
     } catch (Exception e) {
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index db9dc4f17cee9..c7f6ead454d61 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -55,7 +55,7 @@ public ExternalShuffleBlockHandler(TransportConf conf) {
 
   /** Enables mocking out the StreamManager and BlockManager. */
   @VisibleForTesting
-  ExternalShuffleBlockHandler(
+  public ExternalShuffleBlockHandler(
       OneForOneStreamManager streamManager,
       ExternalShuffleBlockResolver blockManager) {
     this.streamManager = streamManager;
@@ -74,17 +74,19 @@ protected void handleMessage(
       RpcResponseCallback callback) {
     if (msgObj instanceof OpenBlocks) {
       OpenBlocks msg = (OpenBlocks) msgObj;
-      List<ManagedBuffer> blocks = Lists.newArrayList();
+      checkAuth(client, msg.appId);
 
+      List<ManagedBuffer> blocks = Lists.newArrayList();
       for (String blockId : msg.blockIds) {
         blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
       }
-      long streamId = streamManager.registerStream(blocks.iterator());
+      long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
       logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
       callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
 
     } else if (msgObj instanceof RegisterExecutor) {
       RegisterExecutor msg = (RegisterExecutor) msgObj;
+      checkAuth(client, msg.appId);
       blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
       callback.onSuccess(new byte[0]);
 
@@ -105,4 +107,12 @@ public StreamManager getStreamManager() {
   public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
     blockManager.applicationRemoved(appId, cleanupLocalDirs);
   }
+
+  private void checkAuth(TransportClient client, String appId) {
+    if (client.getClientId() != null && !client.getClientId().equals(appId)) {
+      throw new SecurityException(String.format(
+        "Client for %s not authorized for application %s.", client.getClientId(), appId));
+    }
+  }
+
 }
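Note: on the client side, the app ID that feeds these checks is the one
handed to SaslClientBootstrap, as exercised by the test changes below. A
minimal sketch (host, port and the conf/secretKeyHolder setup are assumed):

    TransportClientFactory factory = context.createClientFactory(
      Lists.newArrayList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
    TransportClient client = factory.createClient(host, port);
    // After the SASL handshake the server sees this connection as "app-1",
    // so OpenBlocks or RegisterExecutor messages naming any other app ID
    // now fail with a SecurityException.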
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 382f613ecbb1b..c96d9fb4ce4d9 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -19,17 +19,24 @@
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
 
 import com.google.common.collect.Lists;
 import org.junit.After;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
 
 import org.apache.spark.network.TestUtils;
 import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientBootstrap;
@@ -39,44 +46,41 @@ import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.shuffle.BlockFetchingListener;
 import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
+import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 
 public class SaslIntegrationSuite {
-  static ExternalShuffleBlockHandler handler;
+  private final Logger logger = LoggerFactory.getLogger(SaslIntegrationSuite.class);
+
   static TransportServer server;
   static TransportConf conf;
   static TransportContext context;
+  static SecretKeyHolder secretKeyHolder;
 
   TransportClientFactory clientFactory;
 
-  /** Provides a secret key holder which always returns the given secret key. */
-  static class TestSecretKeyHolder implements SecretKeyHolder {
-
-    private final String secretKey;
-
-    TestSecretKeyHolder(String secretKey) {
-      this.secretKey = secretKey;
-    }
-
-    @Override
-    public String getSaslUser(String appId) {
-      return "user";
-    }
-    @Override
-    public String getSecretKey(String appId) {
-      return secretKey;
-    }
-  }
-
   @BeforeClass
   public static void beforeAll() throws IOException {
-    SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
     conf = new TransportConf(new SystemPropertyConfigProvider());
     context = new TransportContext(conf, new TestRpcHandler());
 
+    secretKeyHolder = mock(SecretKeyHolder.class);
+    when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
+    when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
+    when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
+    when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
+    when(secretKeyHolder.getSaslUser(anyString())).thenReturn("someUser");
+    when(secretKeyHolder.getSecretKey(anyString())).thenReturn("somePassword");
+
     TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
     server = context.createServer(Arrays.asList(bootstrap));
   }
@@ -99,7 +103,7 @@ public void afterEach() {
   public void testGoodClient() throws IOException {
     clientFactory = context.createClientFactory(
       Lists.newArrayList(
-        new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
+        new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
 
     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
     String msg = "Hello, World!";
@@ -111,7 +115,7 @@ public void testBadClient() {
     clientFactory = context.createClientFactory(
       Lists.newArrayList(
-        new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
+        new SaslClientBootstrap(conf, "unknown-app", secretKeyHolder)));
 
     try {
       // Bootstrap should fail on startup.
@@ -149,7 +153,7 @@ public void testNoSaslServer() {
     TransportContext context = new TransportContext(conf, handler);
     clientFactory = context.createClientFactory(
       Lists.newArrayList(
-        new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
+        new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
     TransportServer server = context.createServer();
     try {
       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
@@ -160,6 +164,111 @@ public void testNoSaslServer() {
     }
   }
 
+  /**
+   * This test is not actually testing SASL behavior, but testing that the shuffle service
+   * performs correct authorization checks based on the SASL authentication data.
+   */
+  @Test
+  public void testAppIsolation() throws Exception {
+    // Start a new server with the correct RPC handler to serve block data.
+    ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
+    ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
+      new OneForOneStreamManager(), blockResolver);
+    TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
+    TransportContext blockServerContext = new TransportContext(conf, blockHandler);
+    TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
+
+    TransportClient client1 = null;
+    TransportClient client2 = null;
+    TransportClientFactory clientFactory2 = null;
+    try {
+      // Create a client, and make a request to fetch blocks from a different app.
+      clientFactory = blockServerContext.createClientFactory(
+        Lists.newArrayList(
+          new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
+      client1 = clientFactory.createClient(TestUtils.getLocalHost(),
+        blockServer.getPort());
+
+      final AtomicBoolean result = new AtomicBoolean(false);
+
+      BlockFetchingListener listener = new BlockFetchingListener() {
+        @Override
+        public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+          notifyAll();
+        }
+
+        @Override
+        public synchronized void onBlockFetchFailure(String blockId, Throwable exception) {
+          result.set(exception.getMessage().contains(SecurityException.class.getName()));
+          notifyAll();
+        }
+      };
+
+      String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" };
+      OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0",
+        blockIds, listener);
+      synchronized (listener) {
+        fetcher.start();
+        listener.wait();
+      }
+      assertTrue("Should have failed to fetch blocks from non-authorized app.", result.get());
+
+      // Register an executor so that the next steps work.
+      ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
+        new String[] { System.getProperty("java.io.tmpdir") }, 1,
+        "org.apache.spark.shuffle.sort.SortShuffleManager");
+      RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
+      client1.sendRpcSync(regmsg.toByteArray(), 10000);
+
+      // Make a successful request to fetch blocks, which creates a new stream. But do not actually
+      // fetch any blocks, to keep the stream open.
+      result.set(false);
+      OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
+      byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000);
+      StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+      long streamId = stream.streamId;
+
+      // Create a second client, authenticated with a different app ID, and try to read from
+      // the stream created for the previous app.
+      clientFactory2 = blockServerContext.createClientFactory(
+        Lists.newArrayList(
+          new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
+      client2 = clientFactory2.createClient(TestUtils.getLocalHost(),
+        blockServer.getPort());
+
+      ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+        @Override
+        public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+          notifyAll();
+        }
+
+        @Override
+        public synchronized void onFailure(int chunkIndex, Throwable e) {
+          result.set(e.getMessage().contains(SecurityException.class.getName()));
+          notifyAll();
+        }
+      };
+
+      result.set(false);
+      synchronized (callback) {
+        client2.fetchChunk(streamId, 0, callback);
+        callback.wait();
+      }
+      assertTrue("Should have failed to fetch blocks from non-authorized stream.", result.get());
+    } finally {
+      if (client1 != null) {
+        client1.close();
+      }
+      if (client2 != null) {
+        client2.close();
+      }
+      if (clientFactory2 != null) {
+        clientFactory2.close();
+      }
+      blockServer.close();
+    }
+  }
+
   /** RPC handler which simply responds with the message it received. */
   public static class TestRpcHandler extends RpcHandler {
     @Override
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 1d197497b7c8f..e61390cf57061 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -93,7 +93,7 @@ public void testOpenShuffleBlocks() {
     @SuppressWarnings("unchecked")
     ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
       (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
-    verify(streamManager, times(1)).registerStream(stream.capture());
+    verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
     Iterator<ManagedBuffer> buffers = stream.getValue();
     assertEquals(block0Marker, buffers.next());
     assertEquals(block1Marker, buffers.next());

From c68deabdf4f0a5070a928802c18e0330b38c3df0 Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin <vanzin@cloudera.com>
Date: Sat, 15 Aug 2015 12:03:33 -0700
Subject: [PATCH 2/6] Unblock jenkins build (mima check failed).

---
 project/MimaExcludes.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 88745dc086a04..714ce3cd9b1de 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -37,6 +37,7 @@ object MimaExcludes {
       case v if v.startsWith("1.5") =>
         Seq(
           MimaBuild.excludeSparkPackage("deploy"),
+          MimaBuild.excludeSparkPackage("network"),
           // These are needed if checking against the sbt build, since they are part of
           // the maven-generated artifacts in 1.3.
           excludePackage("org.spark-project.jetty"),

From 292a2995a0702f92111e07d316a740574da7db4e Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin <vanzin@cloudera.com>
Date: Mon, 17 Aug 2015 13:10:35 -0700
Subject: [PATCH 3/6] Feedback + fix a test that was failing to fail.
---
 .../spark/network/client/TransportClient.java |  7 ++---
 .../network/sasl/SaslIntegrationSuite.java    | 26 ++++++++++++-------
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 027a8f2625be3..0241ee8dc1e9f 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -88,15 +88,16 @@ public SocketAddress getSocketAddress() {
   /**
    * Returns the ID used by the client to authenticate itself when authentication is enabled.
    *
-   * @return The client ID.
+   * @return The client ID, or null if authentication is disabled.
    */
   public String getClientId() {
     return clientId;
   }
 
   /**
-   * Sets the authenticated client ID. This is meant to be used by the authentication layer;
-   * trying to set a different client ID after it's been set will result in an exception.
+   * Sets the authenticated client ID. This is meant to be used by the authentication layer.
+   *
+   * Trying to set a different client ID after it's been set will result in an exception.
    */
   public void setClientId(String id) {
     Preconditions.checkState(clientId == null, "Client ID has already been set.");
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index c96d9fb4ce4d9..ddef101f6c7ec 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -78,8 +78,8 @@ public static void beforeAll() throws IOException {
     when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
     when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
     when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
-    when(secretKeyHolder.getSaslUser(anyString())).thenReturn("someUser");
-    when(secretKeyHolder.getSecretKey(anyString())).thenReturn("somePassword");
+    when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+    when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password");
 
     TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
     server = context.createServer(Arrays.asList(bootstrap));
@@ -113,13 +113,17 @@ public void testGoodClient() throws IOException {
 
   @Test
   public void testBadClient() {
+    SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class);
+    when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+    when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password");
     clientFactory = context.createClientFactory(
       Lists.newArrayList(
-        new SaslClientBootstrap(conf, "unknown-app", secretKeyHolder)));
+        new SaslClientBootstrap(conf, "unknown-app", badKeyHolder)));
 
     try {
       // Bootstrap should fail on startup.
       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+      fail("Connection should have failed.");
     } catch (Exception e) {
       assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
     }
@@ -189,7 +193,7 @@ public void testAppIsolation() throws Exception {
       client1 = clientFactory.createClient(TestUtils.getLocalHost(),
         blockServer.getPort());
 
-      final AtomicBoolean result = new AtomicBoolean(false);
+      final AtomicBoolean gotSecurityException = new AtomicBoolean(false);
 
       BlockFetchingListener listener = new BlockFetchingListener() {
         @Override
@@ -199,7 +203,8 @@ public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data)
 
         @Override
         public synchronized void onBlockFetchFailure(String blockId, Throwable exception) {
-          result.set(exception.getMessage().contains(SecurityException.class.getName()));
+          gotSecurityException.set(
+            exception.getMessage().contains(SecurityException.class.getName()));
           notifyAll();
         }
       };
@@ -211,7 +216,8 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable exception
         fetcher.start();
         listener.wait();
       }
-      assertTrue("Should have failed to fetch blocks from non-authorized app.", result.get());
+      assertTrue("Should have failed to fetch blocks from non-authorized app.",
+        gotSecurityException.get());
 
       // Register an executor so that the next steps work.
       ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
@@ -222,7 +228,6 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable exception
 
       // Make a successful request to fetch blocks, which creates a new stream. But do not actually
       // fetch any blocks, to keep the stream open.
-      result.set(false);
       OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
       byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000);
       StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
@@ -244,17 +249,18 @@ public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) {
 
         @Override
         public synchronized void onFailure(int chunkIndex, Throwable e) {
-          result.set(e.getMessage().contains(SecurityException.class.getName()));
+          gotSecurityException.set(e.getMessage().contains(SecurityException.class.getName()));
           notifyAll();
         }
       };
 
-      result.set(false);
+      gotSecurityException.set(false);
       synchronized (callback) {
         client2.fetchChunk(streamId, 0, callback);
         callback.wait();
       }
-      assertTrue("Should have failed to fetch blocks from non-authorized stream.", result.get());
+      assertTrue("Should have failed to fetch blocks from non-authorized stream.",
+        gotSecurityException.get());
     } finally {
       if (client1 != null) {
         client1.close();

From fadff2790d765fba458999fb3952d626fa15214b Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin <vanzin@cloudera.com>
Date: Mon, 17 Aug 2015 14:37:04 -0700
Subject: [PATCH 4/6] Clean imports.
---
 .../org/apache/spark/network/sasl/SaslIntegrationSuite.java | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index ddef101f6c7ec..17206227f0a10 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -20,15 +20,12 @@
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicLong;
 
 import com.google.common.collect.Lists;
 import org.junit.After;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import static org.junit.Assert.*;
 import static org.mockito.Mockito.*;
@@ -59,8 +56,6 @@ import org.apache.spark.network.util.TransportConf;
 
 public class SaslIntegrationSuite {
-  private final Logger logger = LoggerFactory.getLogger(SaslIntegrationSuite.class);
-
   static TransportServer server;
   static TransportConf conf;
   static TransportContext context;

From 4d19ed51616ff6019bcb3b762d82823ced49b675 Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin <vanzin@cloudera.com>
Date: Thu, 20 Aug 2015 11:15:45 -0700
Subject: [PATCH 5/6] Better error messages when tests fail.

---
 .../network/sasl/SaslIntegrationSuite.java | 27 ++++++++++---------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 17206227f0a10..5cb0e4d4a6458 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -19,7 +19,7 @@
 
 import java.io.IOException;
 import java.util.Arrays;
-import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
 
 import com.google.common.collect.Lists;
 import org.junit.After;
@@ -188,7 +188,7 @@ public void testAppIsolation() throws Exception {
       client1 = clientFactory.createClient(TestUtils.getLocalHost(),
         blockServer.getPort());
 
-      final AtomicBoolean gotSecurityException = new AtomicBoolean(false);
+      final AtomicReference<Throwable> exception = new AtomicReference<>();
 
       BlockFetchingListener listener = new BlockFetchingListener() {
         @Override
@@ -197,9 +197,8 @@ public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data)
         }
 
         @Override
-        public synchronized void onBlockFetchFailure(String blockId, Throwable exception) {
-          gotSecurityException.set(
-            exception.getMessage().contains(SecurityException.class.getName()));
+        public synchronized void onBlockFetchFailure(String blockId, Throwable t) {
+          exception.set(t);
           notifyAll();
         }
       };
@@ -211,8 +210,7 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable exception
         fetcher.start();
         listener.wait();
       }
-      assertTrue("Should have failed to fetch blocks from non-authorized app.",
-        gotSecurityException.get());
+      checkSecurityException(exception.get());
 
       // Register an executor so that the next steps work.
       ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
@@ -243,19 +241,18 @@ public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) {
         }
 
         @Override
-        public synchronized void onFailure(int chunkIndex, Throwable e) {
-          gotSecurityException.set(e.getMessage().contains(SecurityException.class.getName()));
+        public synchronized void onFailure(int chunkIndex, Throwable t) {
+          exception.set(t);
           notifyAll();
         }
       };
 
-      gotSecurityException.set(false);
+      exception.set(null);
       synchronized (callback) {
         client2.fetchChunk(streamId, 0, callback);
         callback.wait();
       }
-      assertTrue("Should have failed to fetch blocks from non-authorized stream.",
-        gotSecurityException.get());
+      checkSecurityException(exception.get());
     } finally {
       if (client1 != null) {
        client1.close();
@@ -282,4 +279,10 @@ public StreamManager getStreamManager() {
       return new OneForOneStreamManager();
     }
   }
+
+  private void checkSecurityException(Throwable t) {
+    assertNotNull("No exception was caught.", t);
+    assertTrue("Expected SecurityException.",
+      t.getMessage().contains(SecurityException.class.getName()));
+  }
 }

From b491ac7f44880208a5e22d3dea4194eff6c53590 Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin <vanzin@cloudera.com>
Date: Fri, 28 Aug 2015 19:57:09 -0700
Subject: [PATCH 6/6] Make clientId @Nullable.

---
 network/common/pom.xml                                         | 4 ++++
 .../java/org/apache/spark/network/client/TransportClient.java | 3 ++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/network/common/pom.xml b/network/common/pom.xml
index 7dc3068ab8cb7..4141fcb8267a5 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -48,6 +48,10 @@
       <artifactId>slf4j-api</artifactId>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>com.google.code.findbugs</groupId>
+      <artifactId>jsr305</artifactId>
+    </dependency>
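Note: the TransportClient.java hunk of this last patch is truncated in this
copy. Given the subject line and the jsr305 dependency added above, the
missing change presumably imports javax.annotation.Nullable and annotates
the field introduced in patch 1, roughly (assumption, not from the source):

    import javax.annotation.Nullable;
    ...
    @Nullable
    private String clientId;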