Commit 6e9ef10

aarondav authored and Andrew Or committed
[SPARK-4277] Support external shuffle service on Standalone Worker
Author: Aaron Davidson <[email protected]>

Closes #3142 from aarondav/worker and squashes the following commits:

3780bd7 [Aaron Davidson] Address comments
2dcdfc1 [Aaron Davidson] Add private[worker]
47f49d3 [Aaron Davidson] NettyBlockTransferService shouldn't care about app ids (it's only b/t executors)
258417c [Aaron Davidson] [SPARK-4277] Support external shuffle service on executor
1 parent 96136f2 commit 6e9ef10
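
This commit is gated by two Worker-side settings that appear in the new StandaloneWorkerShuffleService below: spark.shuffle.service.enabled (default false) and spark.shuffle.service.port (default 7337). A minimal sketch of turning the service on, assuming these keys are present in the configuration the standalone Worker itself reads (e.g. its spark-defaults.conf):

// Sketch only: the keys and defaults are taken from the diff below; how the
// conf reaches the Worker (defaults file, env, etc.) is an assumption here.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")  // default: false
  .set("spark.shuffle.service.port", "7337")     // default: 7337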

6 files changed: +79 -26 lines changed

core/src/main/scala/org/apache/spark/SecurityManager.scala

Lines changed: 3 additions & 11 deletions
@@ -343,15 +343,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with
    */
   def getSecretKey(): String = secretKey
 
-  override def getSaslUser(appId: String): String = {
-    val myAppId = sparkConf.getAppId
-    require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
-    getSaslUser()
-  }
-
-  override def getSecretKey(appId: String): String = {
-    val myAppId = sparkConf.getAppId
-    require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
-    getSecretKey()
-  }
+  // Default SecurityManager only has a single secret key, so ignore appId.
+  override def getSaslUser(appId: String): String = getSaslUser()
+  override def getSecretKey(appId: String): String = getSecretKey()
 }
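
The two overrides exist because the network layer looks up the SASL user and secret by application id, so that a long-running endpoint (such as the shuffle service added in this commit) can in principle hold one secret per application. The default SecurityManager has only a single secret, hence both overrides ignore appId. As a purely hypothetical illustration of why the lookup is appId-keyed (the class name and map-based storage below are invented for this note, not part of the commit):

// Hypothetical sketch: a holder with one SASL secret per registered app,
// in contrast to the single-secret SecurityManager shown above.
import scala.collection.concurrent.TrieMap

class PerAppSecretKeyHolder(saslUser: String) {
  private val secrets = new TrieMap[String, String]()

  def registerApp(appId: String, secret: String): Unit = secrets.put(appId, secret)
  def unregisterApp(appId: String): Unit = secrets.remove(appId)

  // Same appId-keyed shape as the SecurityManager overrides in the diff above.
  def getSaslUser(appId: String): String = saslUser
  def getSecretKey(appId: String): String =
    secrets.getOrElse(appId, throw new IllegalArgumentException("Unknown app " + appId))
}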
core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import org.apache.spark.{Logging, SparkConf, SecurityManager}
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.sasl.SaslRpcHandler
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
+
+/**
+ * Provides a server from which Executors can read shuffle files (rather than reading directly from
+ * each other), to provide uninterrupted access to the files in the face of executors being turned
+ * off or killed.
+ *
+ * Optionally requires SASL authentication in order to read. See [[SecurityManager]].
+ */
+private[worker]
+class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
+  extends Logging {
+
+  private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false)
+  private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
+  private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
+
+  private val transportConf = SparkTransportConf.fromSparkConf(sparkConf)
+  private val blockHandler = new ExternalShuffleBlockHandler()
+  private val transportContext: TransportContext = {
+    val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
+    new TransportContext(transportConf, handler)
+  }
+
+  private var server: TransportServer = _
+
+  /** Starts the external shuffle service if the user has configured us to. */
+  def startIfEnabled() {
+    if (enabled) {
+      require(server == null, "Shuffle server already started")
+      logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
+      server = transportContext.createServer(port)
+    }
+  }
+
+  def stop() {
+    if (enabled && server != null) {
+      server.close()
+      server = null
+    }
+  }
+}
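
Taken on its own, the new service has a small lifecycle: construct it with a SparkConf and SecurityManager, call startIfEnabled() (a no-op unless spark.shuffle.service.enabled is true), and call stop() on shutdown. A minimal sketch of driving it directly, mirroring how the Worker wires it up in the next file; note that both classes are package-private, so a real caller would live under org.apache.spark.deploy.worker (or org.apache.spark for SecurityManager):

// Sketch assuming package-private access; behavior follows the code above.
import org.apache.spark.{SecurityManager, SparkConf}

val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")
  .set("spark.shuffle.service.port", "7337")
val securityManager = new SecurityManager(conf)

val shuffleService = new StandaloneWorkerShuffleService(conf, securityManager)
shuffleService.startIfEnabled()  // binds port 7337 only because enabled=true
// ... executors can now fetch shuffle blocks from this port ...
shuffleService.stop()            // safe even if the server was never started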

core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala

Lines changed: 7 additions & 1 deletion
@@ -111,6 +111,9 @@ private[spark] class Worker(
   val drivers = new HashMap[String, DriverRunner]
   val finishedDrivers = new HashMap[String, DriverRunner]
 
+  // The shuffle service is not actually started unless configured.
+  val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr)
+
   val publicAddress = {
     val envVar = System.getenv("SPARK_PUBLIC_DNS")
     if (envVar != null) envVar else host
@@ -154,6 +157,7 @@ private[spark] class Worker(
     logInfo("Spark home: " + sparkHome)
     createWorkDir()
     context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+    shuffleService.startIfEnabled()
     webUi = new WorkerWebUI(this, workDir, webUiPort)
     webUi.bind()
     registerWithMaster()
@@ -419,6 +423,7 @@ private[spark] class Worker(
     registrationRetryTimer.foreach(_.cancel())
     executors.values.foreach(_.kill())
     drivers.values.foreach(_.kill())
+    shuffleService.stop()
     webUi.stop()
     metricsSystem.stop()
   }
@@ -441,7 +446,8 @@ private[spark] object Worker extends Logging {
       cores: Int,
       memory: Int,
       masterUrls: Array[String],
-      workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+      workDir: String,
+      workerNumber: Option[Int] = None): (ActorSystem, Int) = {
 
     // The LocalSparkCluster runs multiple local sparkWorkerX actor systems
     val conf = new SparkConf

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ final class ShuffleBlockFetcherIterator(
    * Current [[FetchResult]] being processed. We track this so we can release the current buffer
    * in case of a runtime exception when processing the current buffer.
    */
-  private[this] var currentResult: FetchResult = null
+  @volatile private[this] var currentResult: FetchResult = null
 
   /**
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
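
The single-word change above is a visibility fix: currentResult is written by the thread consuming the iterator, but the buffer it points to may need to be released from elsewhere (for example a cleanup callback that can run on a different thread), and without @volatile that reader could see a stale reference. An illustrative, self-contained sketch of the pattern (names invented for this note, not Spark's):

// One thread processes buffers; a cleanup hook on another thread may need to
// release the buffer currently in flight. @volatile makes the latest write
// to `current` visible to that other thread.
trait Releasable { def release(): Unit }

class CurrentBufferTracker {
  @volatile private var current: Releasable = null

  def process(buf: Releasable): Unit = {
    current = buf          // written by the consuming thread
    try {
      // ... use buf ...
    } finally {
      current = null
    }
  }

  // Invoked from e.g. a task-completion callback, possibly on another thread.
  def cleanup(): Unit = {
    val c = current
    if (c != null) c.release()
  }
}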

core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala

Lines changed: 0 additions & 12 deletions
@@ -89,18 +89,6 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh
     }
   }
 
-  test("security mismatch app ids") {
-    val conf0 = new SparkConf()
-      .set("spark.authenticate", "true")
-      .set("spark.authenticate.secret", "good")
-      .set("spark.app.id", "app-id")
-    val conf1 = conf0.clone.set("spark.app.id", "other-id")
-    testConnection(conf0, conf1) match {
-      case Success(_) => fail("Should have failed")
-      case Failure(t) => t.getMessage should include ("SASL appId app-id did not match")
-    }
-  }
-
   /**
    * Creates two servers with different configurations and sees if they can talk.
    * Returns Success() if they can transfer a block, and Failure() if the block transfer was failed

network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java

Lines changed: 2 additions & 1 deletion
@@ -58,7 +58,8 @@ public void encode(ByteBuf buf) {
 
   public static SaslMessage decode(ByteBuf buf) {
     if (buf.readByte() != TAG_BYTE) {
-      throw new IllegalStateException("Expected SaslMessage, received something else");
+      throw new IllegalStateException("Expected SaslMessage, received something else"
+        + " (maybe your client does not have SASL enabled?)");
     }
 
     int idLength = buf.readInt();
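
The longer error message is a debugging hint: the tag byte is only present when the sender wrapped its payload in a SaslMessage, so tripping this check usually means one side has SASL enabled and the other does not. A hedged sketch of keeping the two sides consistent, reusing the authentication keys that already appear elsewhere in this commit (how the secret is distributed to workers is not covered here):

// Sketch: the Worker's shuffle service enables SASL when authentication is on
// (securityManager.isAuthenticationEnabled()), so applications talking to it
// should carry the same settings and shared secret.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.authenticate", "true")
  .set("spark.authenticate.secret", "shared-secret")
  .set("spark.shuffle.service.enabled", "true")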
