Skip to content

Commit d1124e4

Browse files
author
Andrew Or
committed
Add security to shuffle service (INCOMPLETE)
This commit adds logic that allows the shuffle server to use SASL to authenticate shuffle requests. This does not work yet for two reasons: (1) The ExternalShuffleClient does not have the authentication logic yet; it will be implemented shortly in a separate PR. (2) All supported Yarn versions use a different version of Guava that lacks the base-encoding utility functions that the SASL server uses; this will also be fixed in a separate PR.
1 parent 5f8a96f commit d1124e4

File tree

6 files changed

+205
-25
lines changed

6 files changed

+205
-25
lines changed

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,11 @@ private[spark] class BlockManager(
9595

9696
private[spark]
9797
val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
98-
private val externalShuffleServicePort = Utils.getExternalShuffleServicePort(conf)
98+
99+
// Port used by the external shuffle service. In Yarn mode, this may already be
100+
// set through the Hadoop configuration as the server is launched in the Yarn NM.
101+
private val externalShuffleServicePort =
102+
Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt
99103

100104
// Check that we're not using external shuffle service with consolidated shuffle files.
101105
if (externalShuffleServiceEnabled

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,19 +1783,19 @@ private[spark] object Utils extends Logging {
17831783
}.getOrElse("Unknown")
17841784

17851785
/**
1786-
* Return the port used in the external shuffle service as specified through
1787-
* `spark.shuffle.service.port`. In Yarn, this is set in the Hadoop configuration.
1786+
* Return the value of a config either through the SparkConf or the Hadoop configuration
1787+
* if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf
1788+
* if the key is not set in the Hadoop configuration.
17881789
*/
1789-
def getExternalShuffleServicePort(conf: SparkConf): Int = {
1790-
val shuffleServicePortKey = "spark.shuffle.service.port"
1791-
val sparkPort = conf.getInt(shuffleServicePortKey, 7337)
1790+
def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = {
1791+
val sparkValue = conf.get(key, default)
17921792
if (SparkHadoopUtil.get.isYarnMode) {
1793-
val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
1794-
hadoopConf.getInt(shuffleServicePortKey, sparkPort)
1793+
SparkHadoopUtil.get.newConfiguration(conf).get(key, sparkValue)
17951794
} else {
1796-
sparkPort
1795+
sparkValue
17971796
}
17981797
}
1798+
17991799
}
18001800

18011801
/**
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.network.sasl;
19+
20+
import java.lang.Override;
21+
import java.nio.ByteBuffer;
22+
import java.nio.charset.Charset;
23+
import java.util.concurrent.ConcurrentHashMap;
24+
25+
import org.slf4j.Logger;
26+
import org.slf4j.LoggerFactory;
27+
28+
import org.apache.spark.network.sasl.SecretKeyHolder;
29+
30+
/**
31+
* A class that manages shuffle secret used by the external shuffle service.
32+
*/
33+
public class ShuffleSecretManager implements SecretKeyHolder {
34+
private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class);
35+
private final ConcurrentHashMap<String, String> shuffleSecretMap;
36+
37+
private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
38+
39+
// Spark user used for authenticating SASL connections
40+
// Note that this should match the value in org.apache.spark.SecurityManager
41+
private static final String SPARK_SASL_USER = "sparkSaslUser";
42+
43+
/**
44+
* Convert the given string to a byte buffer that can be converted back to a string
45+
* through {@link #bytesToString(ByteBuffer)}. This is used if the external shuffle
46+
* service represents shuffle secrets as bytes buffers instead of strings.
47+
*/
48+
public static ByteBuffer stringToBytes(String s) {
49+
return ByteBuffer.wrap(s.getBytes(UTF8_CHARSET));
50+
}
51+
52+
/**
53+
* Convert the given byte buffer to a string that can be converted back to a byte
54+
* buffer through {@link #stringToBytes(String)}. This is used if the external shuffle
55+
* service represents shuffle secrets as bytes buffers instead of strings.
56+
*/
57+
public static String bytesToString(ByteBuffer b) {
58+
return new String(b.array(), UTF8_CHARSET);
59+
}
60+
61+
public ShuffleSecretManager() {
62+
shuffleSecretMap = new ConcurrentHashMap<String, String>();
63+
}
64+
65+
/**
66+
* Register the specified application with its secret.
67+
* Executors need to first authenticate themselves with the same secret before
68+
* the fetching shuffle files written by other executors in this application.
69+
*/
70+
public void registerApp(String appId, String shuffleSecret) {
71+
if (!shuffleSecretMap.contains(appId)) {
72+
shuffleSecretMap.put(appId, shuffleSecret);
73+
logger.info("Registered shuffle secret for application {}", appId);
74+
} else {
75+
logger.debug("Application {} already registered", appId);
76+
}
77+
}
78+
79+
/**
80+
* Register the specified application with its secret specified as a byte buffer.
81+
*/
82+
public void registerApp(String appId, ByteBuffer shuffleSecret) {
83+
registerApp(appId, bytesToString(shuffleSecret));
84+
}
85+
86+
/**
87+
* Unregister the specified application along with its secret.
88+
* This is called when an application terminates.
89+
*/
90+
public void unregisterApp(String appId) {
91+
if (shuffleSecretMap.contains(appId)) {
92+
shuffleSecretMap.remove(appId);
93+
logger.info("Unregistered shuffle secret for application {}", appId);
94+
} else {
95+
logger.warn("Attempted to unregister application {} when it is not registered", appId);
96+
}
97+
}
98+
99+
/**
100+
* Return the Spark user for authenticating SASL connections.
101+
*/
102+
@Override
103+
public String getSaslUser(String appId) {
104+
return SPARK_SASL_USER;
105+
}
106+
107+
/**
108+
* Return the secret key registered with the specified application.
109+
* This key is used to authenticate the executors in the application
110+
* before they can fetch shuffle files from the external shuffle service.
111+
* If the application is not registered, return null.
112+
*/
113+
@Override
114+
public String getSecretKey(String appId) {
115+
return shuffleSecretMap.get(appId);
116+
}
117+
}

network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
import org.slf4j.LoggerFactory;
3333

3434
import org.apache.spark.network.TransportContext;
35+
import org.apache.spark.network.sasl.SaslRpcHandler;
36+
import org.apache.spark.network.sasl.ShuffleSecretManager;
3537
import org.apache.spark.network.server.RpcHandler;
3638
import org.apache.spark.network.server.TransportServer;
3739
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
@@ -44,9 +46,18 @@
4446
public class YarnShuffleService extends AuxiliaryService {
4547
private final Logger logger = LoggerFactory.getLogger(YarnShuffleService.class);
4648

49+
// Port on which the shuffle server listens for fetch requests
4750
private static final String SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port";
4851
private static final int DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337;
4952

53+
// Whether the shuffle server should authenticate fetch requests
54+
private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate";
55+
private static final boolean DEFAULT_SPARK_AUTHENTICATE = false;
56+
57+
// An entity that manages the shuffle secret per application
58+
// This is used only if authentication is enabled
59+
private ShuffleSecretManager secretManager;
60+
5061
// Actual server that serves the shuffle files
5162
private TransportServer shuffleServer = null;
5263

@@ -55,58 +66,88 @@ public YarnShuffleService() {
5566
logger.info("Initializing Yarn shuffle service for Spark");
5667
}
5768

69+
/**
70+
* Return whether authentication is enabled as specified by the configuration.
71+
* If so, fetch requests will fail unless the appropriate authentication secret
72+
* for the application is provided.
73+
*/
74+
private boolean isAuthenticationEnabled() {
75+
return secretManager != null;
76+
}
77+
5878
/**
5979
* Start the shuffle server with the given configuration.
6080
*/
6181
@Override
6282
protected void serviceInit(Configuration conf) {
6383
try {
84+
// If authentication is enabled, set up the shuffle server to use a
85+
// special RPC handler that filters out unauthenticated fetch requests
86+
boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
87+
RpcHandler rpcHandler = new ExternalShuffleBlockHandler();
88+
if (authEnabled) {
89+
secretManager = new ShuffleSecretManager();
90+
rpcHandler = new SaslRpcHandler(rpcHandler, secretManager);
91+
}
92+
6493
int port = conf.getInt(
6594
SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
6695
TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf));
67-
RpcHandler rpcHandler = new ExternalShuffleBlockHandler();
6896
TransportContext transportContext = new TransportContext(transportConf, rpcHandler);
6997
shuffleServer = transportContext.createServer(port);
70-
logger.info("Started Yarn shuffle service for Spark on port " + port);
98+
String authEnabledString = authEnabled ? "enabled" : "not enabled";
99+
logger.info("Started Yarn shuffle service for Spark on port {}. " +
100+
"Authentication is {}.", port, authEnabledString);
71101
} catch (Exception e) {
72102
logger.error("Exception in starting Yarn shuffle service for Spark", e);
73103
}
74104
}
75105

76106
@Override
77107
public void initializeApplication(ApplicationInitializationContext context) {
78-
ApplicationId appId = context.getApplicationId();
79-
logger.debug("Initializing application " + appId + "!");
108+
String appId = context.getApplicationId().toString();
109+
ByteBuffer shuffleSecret = context.getApplicationDataForService();
110+
logger.debug("Initializing application {}", appId);
111+
if (isAuthenticationEnabled()) {
112+
secretManager.registerApp(appId, shuffleSecret);
113+
}
80114
}
81115

82116
@Override
83117
public void stopApplication(ApplicationTerminationContext context) {
84-
ApplicationId appId = context.getApplicationId();
85-
logger.debug("Stopping application " + appId + "!");
86-
}
87-
88-
@Override
89-
public ByteBuffer getMetaData() {
90-
logger.debug("Getting meta data");
91-
return ByteBuffer.allocate(0);
118+
String appId = context.getApplicationId().toString();
119+
logger.debug("Stopping application {}", appId);
120+
if (isAuthenticationEnabled()) {
121+
secretManager.unregisterApp(appId);
122+
}
92123
}
93124

94125
@Override
95126
public void initializeContainer(ContainerInitializationContext context) {
96127
ContainerId containerId = context.getContainerId();
97-
logger.debug("Initializing container " + containerId + "!");
128+
logger.debug("Initializing container {}", containerId);
98129
}
99130

100131
@Override
101132
public void stopContainer(ContainerTerminationContext context) {
102133
ContainerId containerId = context.getContainerId();
103-
logger.debug("Stopping container " + containerId + "!");
134+
logger.debug("Stopping container {}", containerId);
104135
}
105136

137+
/**
138+
* Close the shuffle server to clean up any associated state.
139+
*/
106140
@Override
107141
protected void serviceStop() {
108142
if (shuffleServer != null) {
109143
shuffleServer.close();
110144
}
111145
}
146+
147+
// Not currently used
148+
@Override
149+
public ByteBuffer getMetaData() {
150+
return ByteBuffer.allocate(0);
151+
}
152+
112153
}

yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC
3636
import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils}
3737

3838
import org.apache.spark.{SecurityManager, SparkConf, Logging}
39+
import org.apache.spark.network.sasl.ShuffleSecretManager
3940

4041
@deprecated("use yarn/stable", "1.2.0")
4142
class ExecutorRunnable(
@@ -93,7 +94,15 @@ class ExecutorRunnable(
9394
// If external shuffle service is enabled, register with the
9495
// Yarn shuffle service already started on the node manager
9596
if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) {
96-
ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> ByteBuffer.allocate(0)))
97+
val secretString = securityMgr.getSecretKey()
98+
val secretBytes =
99+
if (secretString != null) {
100+
ShuffleSecretManager.stringToBytes(secretString)
101+
} else {
102+
// Authentication is not enabled, so just provide dummy metadata
103+
ByteBuffer.allocate(0)
104+
}
105+
ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes))
97106
}
98107

99108
// Send the start request to the ContainerManager

yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC
3636
import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
3737

3838
import org.apache.spark.{SecurityManager, SparkConf, Logging}
39+
import org.apache.spark.network.sasl.ShuffleSecretManager
3940

4041

4142
class ExecutorRunnable(
@@ -92,7 +93,15 @@ class ExecutorRunnable(
9293
// If external shuffle service is enabled, register with the
9394
// Yarn shuffle service already started on the node manager
9495
if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) {
95-
ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> ByteBuffer.allocate(0)))
96+
val secretString = securityMgr.getSecretKey()
97+
val secretBytes =
98+
if (secretString != null) {
99+
ShuffleSecretManager.stringToBytes(secretString)
100+
} else {
101+
// Authentication is not enabled, so just provide dummy metadata
102+
ByteBuffer.allocate(0)
103+
}
104+
ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes))
96105
}
97106

98107
// Send the start request to the ContainerManager

0 commit comments

Comments
 (0)