Skip to content

Commit 2da3a9e

Browse files
author
Marcelo Vanzin
committed
[SPARK-10004] [SHUFFLE] Perform auth checks when clients read shuffle data.
To correctly isolate applications, when requests to read shuffle data arrive at the shuffle service, proper authorization checks need to be performed. This change makes sure that only the application that created the shuffle data can read from it. Such checks are performed only when "spark.authenticate" is enabled; otherwise there's no secure way to make sure that the client is really who it says it is. Author: Marcelo Vanzin <[email protected]> Closes #8218 from vanzin/SPARK-10004.
1 parent fc48307 commit 2da3a9e

File tree

13 files changed

+221
-36
lines changed

13 files changed

+221
-36
lines changed

core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel}
3838
* is equivalent to one Spark-level shuffle block.
3939
*/
4040
class NettyBlockRpcServer(
41+
appId: String,
4142
serializer: Serializer,
4243
blockManager: BlockDataManager)
4344
extends RpcHandler with Logging {
@@ -55,7 +56,7 @@ class NettyBlockRpcServer(
5556
case openBlocks: OpenBlocks =>
5657
val blocks: Seq[ManagedBuffer] =
5758
openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
58-
val streamId = streamManager.registerStream(blocks.iterator.asJava)
59+
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
5960
logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
6061
responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
6162

core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
4949
private[this] var appId: String = _
5050

5151
override def init(blockDataManager: BlockDataManager): Unit = {
52-
val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
52+
val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
5353
var serverBootstrap: Option[TransportServerBootstrap] = None
5454
var clientBootstrap: Option[TransportClientBootstrap] = None
5555
if (authEnabled) {

network/common/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
<artifactId>slf4j-api</artifactId>
4949
<scope>provided</scope>
5050
</dependency>
51+
<dependency>
52+
<groupId>com.google.code.findbugs</groupId>
53+
<artifactId>jsr305</artifactId>
54+
</dependency>
5155
<!--
5256
Promote Guava to "compile" so that maven-shade-plugin picks it up (for packaging the Optional
5357
class exposed in the Java API). The plugin will then remove this dependency from the published

network/common/src/main/java/org/apache/spark/network/client/TransportClient.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.UUID;
2424
import java.util.concurrent.ExecutionException;
2525
import java.util.concurrent.TimeUnit;
26+
import javax.annotation.Nullable;
2627

2728
import com.google.common.base.Objects;
2829
import com.google.common.base.Preconditions;
@@ -70,6 +71,7 @@ public class TransportClient implements Closeable {
7071

7172
private final Channel channel;
7273
private final TransportResponseHandler handler;
74+
@Nullable private String clientId;
7375

7476
public TransportClient(Channel channel, TransportResponseHandler handler) {
7577
this.channel = Preconditions.checkNotNull(channel);
@@ -84,6 +86,25 @@ public SocketAddress getSocketAddress() {
8486
return channel.remoteAddress();
8587
}
8688

89+
/**
90+
* Returns the ID used by the client to authenticate itself when authentication is enabled.
91+
*
92+
* @return The client ID, or null if authentication is disabled.
93+
*/
94+
public String getClientId() {
95+
return clientId;
96+
}
97+
98+
/**
99+
* Sets the authenticated client ID. This is meant to be used by the authentication layer.
100+
*
101+
* Trying to set the client ID again once it has been set (even to the same value) will result in an exception.
102+
*/
103+
public void setClientId(String id) {
104+
Preconditions.checkState(clientId == null, "Client ID has already been set.");
105+
this.clientId = id;
106+
}
107+
87108
/**
88109
* Requests a single chunk from the remote side, from the pre-negotiated streamId.
89110
*
@@ -207,6 +228,7 @@ public void close() {
207228
public String toString() {
208229
return Objects.toStringHelper(this)
209230
.add("remoteAdress", channel.remoteAddress())
231+
.add("clientId", clientId)
210232
.add("isActive", isActive())
211233
.toString();
212234
}

network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ public void doBootstrap(TransportClient client, Channel channel) {
7777
payload = saslClient.response(response);
7878
}
7979

80+
client.setClientId(appId);
81+
8082
if (encrypt) {
8183
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
8284
throw new RuntimeException(

network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback
8181

8282
if (saslServer == null) {
8383
// First message in the handshake, setup the necessary state.
84+
client.setClientId(saslMessage.appId);
8485
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
8586
conf.saslServerAlwaysEncrypt());
8687
}

network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
import java.util.concurrent.ConcurrentHashMap;
2525
import java.util.concurrent.atomic.AtomicLong;
2626

27+
import com.google.common.base.Preconditions;
2728
import io.netty.channel.Channel;
2829
import org.slf4j.Logger;
2930
import org.slf4j.LoggerFactory;
3031

3132
import org.apache.spark.network.buffer.ManagedBuffer;
32-
33-
import com.google.common.base.Preconditions;
33+
import org.apache.spark.network.client.TransportClient;
3434

3535
/**
3636
* StreamManager which allows registration of an Iterator&lt;ManagedBuffer&gt;, which are individually
@@ -44,6 +44,7 @@ public class OneForOneStreamManager extends StreamManager {
4444

4545
/** State of a single stream. */
4646
private static class StreamState {
47+
final String appId;
4748
final Iterator<ManagedBuffer> buffers;
4849

4950
// The channel associated to the stream
@@ -53,7 +54,8 @@ private static class StreamState {
5354
// that the caller only requests each chunk one at a time, in order.
5455
int curChunk = 0;
5556

56-
StreamState(Iterator<ManagedBuffer> buffers) {
57+
StreamState(String appId, Iterator<ManagedBuffer> buffers) {
58+
this.appId = appId;
5759
this.buffers = Preconditions.checkNotNull(buffers);
5860
}
5961
}
@@ -109,15 +111,34 @@ public void connectionTerminated(Channel channel) {
109111
}
110112
}
111113

114+
@Override
115+
public void checkAuthorization(TransportClient client, long streamId) {
116+
if (client.getClientId() != null) {
117+
StreamState state = streams.get(streamId);
118+
Preconditions.checkArgument(state != null, "Unknown stream ID.");
119+
if (!client.getClientId().equals(state.appId)) {
120+
throw new SecurityException(String.format(
121+
"Client %s not authorized to read stream %d (app %s).",
122+
client.getClientId(),
123+
streamId,
124+
state.appId));
125+
}
126+
}
127+
}
128+
112129
/**
113130
* Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
114131
* callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
115132
* client connection is closed before the iterator is fully drained, then the remaining buffers
116133
* will all be release()'d.
134+
*
135+
* If an app ID is provided, only callers who've authenticated with the given app ID will be
136+
* allowed to fetch from this stream.
117137
*/
118-
public long registerStream(Iterator<ManagedBuffer> buffers) {
138+
public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
119139
long myStreamId = nextStreamId.getAndIncrement();
120-
streams.put(myStreamId, new StreamState(buffers));
140+
streams.put(myStreamId, new StreamState(appId, buffers));
121141
return myStreamId;
122142
}
143+
123144
}

network/common/src/main/java/org/apache/spark/network/server/StreamManager.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import io.netty.channel.Channel;
2121

2222
import org.apache.spark.network.buffer.ManagedBuffer;
23+
import org.apache.spark.network.client.TransportClient;
2324

2425
/**
2526
* The StreamManager is used to fetch individual chunks from a stream. This is used in
@@ -60,4 +61,12 @@ public void registerChannel(Channel channel, long streamId) { }
6061
* to read from the associated streams again, so any state can be cleaned up.
6162
*/
6263
public void connectionTerminated(Channel channel) { }
64+
65+
/**
66+
* Verify that the client is authorized to read from the given stream.
67+
*
68+
* @throws SecurityException If client is not authorized.
69+
*/
70+
public void checkAuthorization(TransportClient client, long streamId) { }
71+
6372
}

network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ private void processFetchRequest(final ChunkFetchRequest req) {
9797

9898
ManagedBuffer buf;
9999
try {
100+
streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
100101
streamManager.registerChannel(channel, req.streamChunkId.streamId);
101102
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
102103
} catch (Exception e) {

network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFi
5858

5959
/** Enables mocking out the StreamManager and BlockManager. */
6060
@VisibleForTesting
61-
ExternalShuffleBlockHandler(
61+
public ExternalShuffleBlockHandler(
6262
OneForOneStreamManager streamManager,
6363
ExternalShuffleBlockResolver blockManager) {
6464
this.streamManager = streamManager;
@@ -77,17 +77,19 @@ protected void handleMessage(
7777
RpcResponseCallback callback) {
7878
if (msgObj instanceof OpenBlocks) {
7979
OpenBlocks msg = (OpenBlocks) msgObj;
80-
List<ManagedBuffer> blocks = Lists.newArrayList();
80+
checkAuth(client, msg.appId);
8181

82+
List<ManagedBuffer> blocks = Lists.newArrayList();
8283
for (String blockId : msg.blockIds) {
8384
blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
8485
}
85-
long streamId = streamManager.registerStream(blocks.iterator());
86+
long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
8687
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
8788
callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
8889

8990
} else if (msgObj instanceof RegisterExecutor) {
9091
RegisterExecutor msg = (RegisterExecutor) msgObj;
92+
checkAuth(client, msg.appId);
9193
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
9294
callback.onSuccess(new byte[0]);
9395

@@ -126,4 +128,12 @@ public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executor
126128
public void close() {
127129
blockManager.close();
128130
}
131+
132+
private void checkAuth(TransportClient client, String appId) {
133+
if (client.getClientId() != null && !client.getClientId().equals(appId)) {
134+
throw new SecurityException(String.format(
135+
"Client for %s not authorized for application %s.", client.getClientId(), appId));
136+
}
137+
}
138+
129139
}

0 commit comments

Comments
 (0)