Skip to content

Commit 17eb187

Browse files
author
Marcelo Vanzin
committed
[SPARK-10004] [shuffle] Perform auth checks when clients read shuffle data.
To correctly isolate applications, when requests to read shuffle data arrive at the shuffle service, proper authorization checks need to be performed. This change makes sure that only the application that created the shuffle data can read from it. Such checks are only enabled when "spark.authenticate" is enabled, otherwise there's no secure way to make sure that the client is really who it says it is.
1 parent 8187b3a commit 17eb187

File tree

11 files changed

+210
-36
lines changed

11 files changed

+210
-36
lines changed

core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel}
3838
* is equivalent to one Spark-level shuffle block.
3939
*/
4040
class NettyBlockRpcServer(
41+
appId: String,
4142
serializer: Serializer,
4243
blockManager: BlockDataManager)
4344
extends RpcHandler with Logging {
@@ -55,7 +56,7 @@ class NettyBlockRpcServer(
5556
case openBlocks: OpenBlocks =>
5657
val blocks: Seq[ManagedBuffer] =
5758
openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
58-
val streamId = streamManager.registerStream(blocks.iterator)
59+
val streamId = streamManager.registerStream(appId, blocks.iterator)
5960
logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
6061
responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
6162

core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
4949
private[this] var appId: String = _
5050

5151
override def init(blockDataManager: BlockDataManager): Unit = {
52-
val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
52+
val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
5353
var serverBootstrap: Option[TransportServerBootstrap] = None
5454
var clientBootstrap: Option[TransportClientBootstrap] = None
5555
if (authEnabled) {

network/common/src/main/java/org/apache/spark/network/client/TransportClient.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ public class TransportClient implements Closeable {
7070

7171
private final Channel channel;
7272
private final TransportResponseHandler handler;
73+
private String clientId;
7374

7475
public TransportClient(Channel channel, TransportResponseHandler handler) {
7576
this.channel = Preconditions.checkNotNull(channel);
@@ -84,6 +85,24 @@ public SocketAddress getSocketAddress() {
8485
return channel.remoteAddress();
8586
}
8687

88+
/**
89+
* Returns the ID used by the client to authenticate itself when authentication is enabled.
90+
*
91+
* @return The client ID.
92+
*/
93+
public String getClientId() {
94+
return clientId;
95+
}
96+
97+
/**
98+
* Sets the authenticated client ID. This is meant to be used by the authentication layer;
99+
* trying to set a different client ID after it's been set will result in an exception.
100+
*/
101+
public void setClientId(String id) {
102+
Preconditions.checkState(clientId == null, "Client ID has already been set.");
103+
this.clientId = id;
104+
}
105+
87106
/**
88107
* Requests a single chunk from the remote side, from the pre-negotiated streamId.
89108
*
@@ -207,6 +226,7 @@ public void close() {
207226
public String toString() {
208227
return Objects.toStringHelper(this)
209228
.add("remoteAdress", channel.remoteAddress())
229+
.add("clientId", clientId)
210230
.add("isActive", isActive())
211231
.toString();
212232
}

network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ public void doBootstrap(TransportClient client, Channel channel) {
7777
payload = saslClient.response(response);
7878
}
7979

80+
client.setClientId(appId);
81+
8082
if (encrypt) {
8183
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
8284
throw new RuntimeException(

network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback
8181

8282
if (saslServer == null) {
8383
// First message in the handshake, setup the necessary state.
84+
client.setClientId(saslMessage.appId);
8485
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
8586
conf.saslServerAlwaysEncrypt());
8687
}

network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
import java.util.concurrent.ConcurrentHashMap;
2525
import java.util.concurrent.atomic.AtomicLong;
2626

27+
import com.google.common.base.Preconditions;
2728
import io.netty.channel.Channel;
2829
import org.slf4j.Logger;
2930
import org.slf4j.LoggerFactory;
3031

3132
import org.apache.spark.network.buffer.ManagedBuffer;
32-
33-
import com.google.common.base.Preconditions;
33+
import org.apache.spark.network.client.TransportClient;
3434

3535
/**
3636
* StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
@@ -44,6 +44,7 @@ public class OneForOneStreamManager extends StreamManager {
4444

4545
/** State of a single stream. */
4646
private static class StreamState {
47+
final String appId;
4748
final Iterator<ManagedBuffer> buffers;
4849

4950
// The channel associated to the stream
@@ -53,7 +54,8 @@ private static class StreamState {
5354
// that the caller only requests each chunk one at a time, in order.
5455
int curChunk = 0;
5556

56-
StreamState(Iterator<ManagedBuffer> buffers) {
57+
StreamState(String appId, Iterator<ManagedBuffer> buffers) {
58+
this.appId = appId;
5759
this.buffers = Preconditions.checkNotNull(buffers);
5860
}
5961
}
@@ -109,15 +111,34 @@ public void connectionTerminated(Channel channel) {
109111
}
110112
}
111113

114+
@Override
115+
public void checkAuthorization(TransportClient client, long streamId) {
116+
if (client.getClientId() != null) {
117+
StreamState state = streams.get(streamId);
118+
Preconditions.checkArgument(state != null, "Unknown stream ID.");
119+
if (!client.getClientId().equals(state.appId)) {
120+
throw new SecurityException(String.format(
121+
"Client %s not authorized to read stream %d (app %s).",
122+
client.getClientId(),
123+
streamId,
124+
state.appId));
125+
}
126+
}
127+
}
128+
112129
/**
113130
* Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
114131
* callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
115132
* client connection is closed before the iterator is fully drained, then the remaining buffers
116133
* will all be release()'d.
134+
*
135+
* If an app ID is provided, only callers who've authenticated with the given app ID will be
136+
* allowed to fetch from this stream.
117137
*/
118-
public long registerStream(Iterator<ManagedBuffer> buffers) {
138+
public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
119139
long myStreamId = nextStreamId.getAndIncrement();
120-
streams.put(myStreamId, new StreamState(buffers));
140+
streams.put(myStreamId, new StreamState(appId, buffers));
121141
return myStreamId;
122142
}
143+
123144
}

network/common/src/main/java/org/apache/spark/network/server/StreamManager.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import io.netty.channel.Channel;
2121

2222
import org.apache.spark.network.buffer.ManagedBuffer;
23+
import org.apache.spark.network.client.TransportClient;
2324

2425
/**
2526
* The StreamManager is used to fetch individual chunks from a stream. This is used in
@@ -60,4 +61,12 @@ public void registerChannel(Channel channel, long streamId) { }
6061
* to read from the associated streams again, so any state can be cleaned up.
6162
*/
6263
public void connectionTerminated(Channel channel) { }
64+
65+
/**
66+
* Verify that the client is authorized to read from the given stream.
67+
*
68+
* @throws SecurityException If client is not authorized.
69+
*/
70+
public void checkAuthorization(TransportClient client, long streamId) { }
71+
6372
}

network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ private void processFetchRequest(final ChunkFetchRequest req) {
9797

9898
ManagedBuffer buf;
9999
try {
100+
streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
100101
streamManager.registerChannel(channel, req.streamChunkId.streamId);
101102
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
102103
} catch (Exception e) {

network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public ExternalShuffleBlockHandler(TransportConf conf) {
5555

5656
/** Enables mocking out the StreamManager and BlockManager. */
5757
@VisibleForTesting
58-
ExternalShuffleBlockHandler(
58+
public ExternalShuffleBlockHandler(
5959
OneForOneStreamManager streamManager,
6060
ExternalShuffleBlockResolver blockManager) {
6161
this.streamManager = streamManager;
@@ -74,17 +74,19 @@ protected void handleMessage(
7474
RpcResponseCallback callback) {
7575
if (msgObj instanceof OpenBlocks) {
7676
OpenBlocks msg = (OpenBlocks) msgObj;
77-
List<ManagedBuffer> blocks = Lists.newArrayList();
77+
checkAuth(client, msg.appId);
7878

79+
List<ManagedBuffer> blocks = Lists.newArrayList();
7980
for (String blockId : msg.blockIds) {
8081
blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
8182
}
82-
long streamId = streamManager.registerStream(blocks.iterator());
83+
long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
8384
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
8485
callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
8586

8687
} else if (msgObj instanceof RegisterExecutor) {
8788
RegisterExecutor msg = (RegisterExecutor) msgObj;
89+
checkAuth(client, msg.appId);
8890
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
8991
callback.onSuccess(new byte[0]);
9092

@@ -105,4 +107,12 @@ public StreamManager getStreamManager() {
105107
public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
106108
blockManager.applicationRemoved(appId, cleanupLocalDirs);
107109
}
110+
111+
private void checkAuth(TransportClient client, String appId) {
112+
if (client.getClientId() != null && !client.getClientId().equals(appId)) {
113+
throw new SecurityException(String.format(
114+
"Client for %s not authorized for application %s.", client.getClientId(), appId));
115+
}
116+
}
117+
108118
}

0 commit comments

Comments
 (0)