Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -41,6 +46,13 @@
public class MesosExternalShuffleClient extends ExternalShuffleClient {
private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class);

private final ScheduledExecutorService heartbeaterThread =
Executors.newSingleThreadScheduledExecutor(
new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("mesos-external-shuffle-client-heartbeater")
.build());

/**
* Creates a Mesos external shuffle client that wraps the {@link ExternalShuffleClient}.
* Please refer to docs on {@link ExternalShuffleClient} for more information.
Expand All @@ -53,21 +65,59 @@ public MesosExternalShuffleClient(
super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled);
}

public void registerDriverWithShuffleService(String host, int port) throws IOException {
public void registerDriverWithShuffleService(
String host,
int port,
long heartbeatTimeoutMs,
long heartbeatIntervalMs) throws IOException {

checkInit();
ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer();
ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer();
TransportClient client = clientFactory.createClient(host, port);
client.sendRpc(registerDriver, new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
logger.info("Successfully registered app " + appId + " with external shuffle service.");
}

@Override
public void onFailure(Throwable e) {
logger.warn("Unable to register app " + appId + " with external shuffle service. " +
client.sendRpc(registerDriver, new RegisterDriverCallback(client, heartbeatIntervalMs));
}

private class RegisterDriverCallback implements RpcResponseCallback {
private final TransportClient client;
private final long heartbeatIntervalMs;

private RegisterDriverCallback(TransportClient client, long heartbeatIntervalMs) {
this.client = client;
this.heartbeatIntervalMs = heartbeatIntervalMs;
}

@Override
public void onSuccess(ByteBuffer response) {
heartbeaterThread.scheduleAtFixedRate(
new Heartbeater(client), 0, heartbeatIntervalMs, TimeUnit.MILLISECONDS);
logger.info("Successfully registered app " + appId + " with external shuffle service.");
}

@Override
public void onFailure(Throwable e) {
logger.warn("Unable to register app " + appId + " with external shuffle service. " +
"Please manually remove shuffle data after driver exit. Error: " + e);
}
});
}
}

@Override
public void close() {
heartbeaterThread.shutdownNow();
super.close();
}

private class Heartbeater implements Runnable {

private final TransportClient client;

private Heartbeater(TransportClient client) {
this.client = client;
}

@Override
public void run() {
// TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout
client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import org.apache.spark.network.protocol.Encodable;
import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat;

/**
* Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or
Expand All @@ -40,7 +41,8 @@ public abstract class BlockTransferMessage implements Encodable {

/** Preceding every serialized message is its type, which allows us to deserialize it. */
public static enum Type {
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4);
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
HEARTBEAT(5);

private final byte id;

Expand All @@ -64,6 +66,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
case 2: return RegisterExecutor.decode(buf);
case 3: return StreamHandle.decode(buf);
case 4: return RegisterDriver.decode(buf);
case 5: return ShuffleServiceHeartbeat.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,34 @@
*/
public class RegisterDriver extends BlockTransferMessage {
private final String appId;
private final long heartbeatTimeoutMs;

public RegisterDriver(String appId) {
public RegisterDriver(String appId, long heartbeatTimeoutMs) {
this.appId = appId;
this.heartbeatTimeoutMs = heartbeatTimeoutMs;
}

public String getAppId() { return appId; }

public long getHeartbeatTimeoutMs() { return heartbeatTimeoutMs; }

@Override
protected Type type() { return Type.REGISTER_DRIVER; }

@Override
public int encodedLength() {
return Encoders.Strings.encodedLength(appId);
return Encoders.Strings.encodedLength(appId) + Long.SIZE / Byte.SIZE;
}

@Override
public void encode(ByteBuf buf) {
Encoders.Strings.encode(buf, appId);
buf.writeLong(heartbeatTimeoutMs);
}

@Override
public int hashCode() {
return Objects.hashCode(appId);
return Objects.hashCode(appId, heartbeatTimeoutMs);
}

@Override
Expand All @@ -66,6 +71,7 @@ public boolean equals(Object o) {

public static RegisterDriver decode(ByteBuf buf) {
String appId = Encoders.Strings.decode(buf);
return new RegisterDriver(appId);
long heartbeatTimeout = buf.readLong();
return new RegisterDriver(appId, heartbeatTimeout);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shuffle.protocol.mesos;

import io.netty.buffer.ByteBuf;
import org.apache.spark.network.protocol.Encoders;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;

// Needed by ScalaDoc. See SPARK-7726
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;

/**
 * Periodic liveness message sent by a Spark driver to the MesosExternalShuffleService,
 * letting the service know that the application identified by {@code appId} is still alive.
 */
public class ShuffleServiceHeartbeat extends BlockTransferMessage {
  private final String appId;

  public ShuffleServiceHeartbeat(String appId) {
    this.appId = appId;
  }

  /** Returns the id of the application this heartbeat belongs to. */
  public String getAppId() {
    return appId;
  }

  @Override
  protected Type type() {
    return Type.HEARTBEAT;
  }

  @Override
  public int encodedLength() {
    // Wire size is just the length-prefixed UTF-8 app id.
    return Encoders.Strings.encodedLength(appId);
  }

  @Override
  public void encode(ByteBuf buf) {
    Encoders.Strings.encode(buf, appId);
  }

  /** Deserializes a heartbeat from the buffer; inverse of {@link #encode(ByteBuf)}. */
  public static ShuffleServiceHeartbeat decode(ByteBuf buf) {
    String appId = Encoders.Strings.decode(buf);
    return new ShuffleServiceHeartbeat(appId);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,69 +17,89 @@

package org.apache.spark.deploy.mesos

import java.net.SocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}

import scala.collection.mutable
import scala.collection.JavaConverters._

import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.ExternalShuffleService
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage
import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver
import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat}
import org.apache.spark.network.util.TransportConf
import org.apache.spark.util.ThreadUtils

/**
* An RPC endpoint that receives registration requests from Spark drivers running on Mesos.
* It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]].
*/
private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf)
private[mesos] class MesosExternalShuffleBlockHandler(
transportConf: TransportConf,
cleanerIntervalS: Long)
extends ExternalShuffleBlockHandler(transportConf, null) with Logging {

// Stores a map of driver socket addresses to app ids
private val connectedApps = new mutable.HashMap[SocketAddress, String]
ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher")
.scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervalS, TimeUnit.SECONDS)

// Stores a map of app id to app state (timeout value and last heartbeat)
private val connectedApps = new ConcurrentHashMap[String, AppState]()

protected override def handleMessage(
message: BlockTransferMessage,
client: TransportClient,
callback: RpcResponseCallback): Unit = {
message match {
case RegisterDriverParam(appId) =>
case RegisterDriverParam(appId, appState) =>
val address = client.getSocketAddress
logDebug(s"Received registration request from app $appId (remote address $address).")
if (connectedApps.contains(address)) {
val existingAppId = connectedApps(address)
if (!existingAppId.equals(appId)) {
logError(s"A new app '$appId' has connected to existing address $address, " +
s"removing previously registered app '$existingAppId'.")
applicationRemoved(existingAppId, true)
}
val timeout = appState.heartbeatTimeout
logInfo(s"Received registration request from app $appId (remote address $address, " +
s"heartbeat timeout $timeout ms).")
if (connectedApps.containsKey(appId)) {
logWarning(s"Received a registration request from app $appId, but it was already " +
s"registered")
}
connectedApps(address) = appId
connectedApps.put(appId, appState)
callback.onSuccess(ByteBuffer.allocate(0))
case Heartbeat(appId) =>
val address = client.getSocketAddress
Option(connectedApps.get(appId)) match {
case Some(existingAppState) =>
logTrace(s"Received ShuffleServiceHeartbeat from app '$appId' (remote " +
s"address $address).")
existingAppState.lastHeartbeat = System.nanoTime()
case None =>
logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " +
s"address $address, appId '$appId').")
}
case _ => super.handleMessage(message, client, callback)
}
}

/**
* On connection termination, clean up shuffle files written by the associated application.
*/
override def channelInactive(client: TransportClient): Unit = {
val address = client.getSocketAddress
if (connectedApps.contains(address)) {
val appId = connectedApps(address)
logInfo(s"Application $appId disconnected (address was $address).")
applicationRemoved(appId, true /* cleanupLocalDirs */)
connectedApps.remove(address)
} else {
logWarning(s"Unknown $address disconnected.")
}
}

/** An extractor object for matching [[RegisterDriver]] message. */
private object RegisterDriverParam {
def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId)
def unapply(r: RegisterDriver): Option[(String, AppState)] =
Some((r.getAppId, new AppState(r.getHeartbeatTimeoutMs, System.nanoTime())))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if I have a shuffle service running 1.6.0 and my Spark application is running 1.6.1? We'll get some kind of version error here right? Is the idea that this is OK since using shuffle service with Spark on Mesos in 1.6.0 is completely broken anyway?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to do the backport for 1.6 in a separate PR, since what you mention here (compatibility of patched and unpatched spark versions) needs to be considered.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. So, in the case of Shuffle Server with 1.6 and newer driver:

  • the shuffle server will see a RegisterDriver message that has an additional parameter. the old shuffle service will decode it without looking at the timeout value, and registration will continue as usual
  • heartbeat messages, if they arrive, are going to be dropped as "unknown message"

But, since Spark on Mesos in 1.6.0 is broken in this way, I would say this isn't a setup we should be very concerned with.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, things work a bit better than I thought. The shuffle service will log errors like the one below, but the connection won't timeout anymore, since the driver is constantly sending (misunderstood) heartbeat messages.

16/02/23 17:13:28 ERROR TransportRequestHandler: Error while invoking RpcHandler#receive() for one-way message.
java.lang.IllegalArgumentException: Unknown message type: 5
    at org.apache.spark.network.shuffle.protocol.BlockTransferMessage$Decoder.fromByteBuffer(BlockTransferMessage.java:67)
    at org.apache.spark.network.shuffle.ExternalShuffleBlockHandler.receive(ExternalShuffleBlockHandler.java:71)
    at org.apache.spark.network.server.RpcHandler.receive(RpcHandler.java:68)
    at org.apache.spark.network.server.TransportRequestHandler.processOneWayMessage(TransportRequestHandler.java:180)
    at org.apache.spark.network.server.TransportRequestHandler.handle(TransportRequestHandler.java:109)
    at org.apache.spark.network.server.TransportChannelHandler.channelRead0(TransportChannelHandler.java:119)
    at org.apache.spark.network.server.TransportChannelHandler.channelRead0(TransportChannelHandler.java:51)
    at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:105)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO for 1.6: Need to handle that the RegisterDriver message might not contain the timeout value. The registration of the MesosShuffleClient with the ShuffleService should fail in this case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this is the other way around: a newer shuffle service, and an older driver? That won't work, since RegisterDriver.decode would fail when trying to read the (non-existing) timeout value.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bbossy can you elaborate on your proposed fix for 1.6.2? How would it be different from this patch in its current form?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a quick look at this. What could be done is to send a more informative Throwable (using callback.onFailure) when the driver registers with the shuffle service, but the message does not contain the heartbeatTimeoutMs.

}

private object Heartbeat {
def unapply(h: ShuffleServiceHeartbeat): Option[String] = Some(h.getAppId)
}

private class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long)

private class CleanerThread extends Runnable {
override def run(): Unit = {
val now = System.nanoTime()
connectedApps.asScala.foreach { case (appId, appState) =>
if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000 * 1000) {
logInfo(s"Application $appId timed out. Removing shuffle files.")
connectedApps.remove(appId)
applicationRemoved(appId, true)
}
}
}
}
}

Expand All @@ -93,7 +113,8 @@ private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManage

protected override def newShuffleBlockHandler(
conf: TransportConf): ExternalShuffleBlockHandler = {
new MesosExternalShuffleBlockHandler(conf)
val cleanerIntervalS = this.conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s")
new MesosExternalShuffleBlockHandler(conf, cleanerIntervalS)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,12 @@ private[spark] class CoarseMesosSchedulerBackend(
s"host ${slave.hostname}, port $externalShufflePort for app ${conf.getAppId}")

mesosExternalShuffleClient.get
.registerDriverWithShuffleService(slave.hostname, externalShufflePort)
.registerDriverWithShuffleService(
slave.hostname,
externalShufflePort,
sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs",
s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"),
sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s"))
slave.shuffleRegistered = true
}

Expand Down Expand Up @@ -506,6 +511,9 @@ private[spark] class CoarseMesosSchedulerBackend(
+ "on the mesos nodes.")
}

// Close the mesos external shuffle client if used
mesosExternalShuffleClient.foreach(_.close())

if (mesosDriver != null) {
mesosDriver.stop()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite

val status2 = createTaskStatus("1", "s1", TaskState.TASK_RUNNING)
backend.statusUpdate(driver, status2)
verify(externalShuffleClient, times(1)).registerDriverWithShuffleService(anyString, anyInt)
verify(externalShuffleClient, times(1))
.registerDriverWithShuffleService(anyString, anyInt, anyLong, anyLong)
}

test("mesos kills an executor when told") {
Expand Down