
Commit f165b2b

aarondav authored and rxin committed
[SPARK-4188] [Core] Perform network-level retry of shuffle file fetches
This adds a RetryingBlockFetcher to the NettyBlockTransferService which is wrapped around our typical OneForOneBlockFetcher, adding retry logic in the event of an IOException. This sort of retry allows us to avoid marking an entire executor as failed due to garbage collection or high network load.

TODO:
- [x] unit tests
- [x] put in ExternalShuffleClient too

Author: Aaron Davidson <[email protected]>

Closes #3101 from aarondav/retry and squashes the following commits:

72a2a32 [Aaron Davidson] Add that we should remove the condition around the retry thingy
c7fd107 [Aaron Davidson] Fix unit tests
e80e4c2 [Aaron Davidson] Address initial comments
6f594cd [Aaron Davidson] Fix unit test
05ff43c [Aaron Davidson] Add to external shuffle client and add unit test
66e5a24 [Aaron Davidson] [SPARK-4238] [Core] Perform network-level retry of shuffle file fetches
1 parent 6e9ef10 commit f165b2b

16 files changed: +668 -45 lines

core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 17 additions & 4 deletions

@@ -27,7 +27,7 @@ import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCal
 import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
 import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
 import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.Utils
@@ -71,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
       listener: BlockFetchingListener): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
-      val client = clientFactory.createClient(host, port)
-      new OneForOneBlockFetcher(client, blockIds.toArray, listener)
-        .start(OpenBlocks(blockIds.map(BlockId.apply)))
+      val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+        override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+          val client = clientFactory.createClient(host, port)
+          new OneForOneBlockFetcher(client, blockIds.toArray, listener)
+            .start(OpenBlocks(blockIds.map(BlockId.apply)))
+        }
+      }
+
+      val maxRetries = transportConf.maxIORetries()
+      if (maxRetries > 0) {
+        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+        // a bug in this code. We should remove the if statement once we're sure of the stability.
+        new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+      } else {
+        blockFetchStarter.createAndStart(blockIds, listener)
+      }
     } catch {
       case e: Exception =>
         logError("Exception while beginning fetchBlocks", e)

network/common/src/main/java/org/apache/spark/network/client/TransportClient.java

Lines changed: 14 additions & 2 deletions

@@ -18,7 +18,9 @@
 package org.apache.spark.network.client;

 import java.io.Closeable;
+import java.io.IOException;
 import java.util.UUID;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;

 import com.google.common.base.Objects;
@@ -116,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
             serverAddr, future.cause());
           logger.error(errorMsg, future.cause());
           handler.removeFetchRequest(streamChunkId);
-          callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
           channel.close();
+          try {
+            callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+          } catch (Exception e) {
+            logger.error("Uncaught exception in RPC response callback handler!", e);
+          }
         }
       }
     });
@@ -147,8 +153,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
             serverAddr, future.cause());
           logger.error(errorMsg, future.cause());
           handler.removeRpcRequest(requestId);
-          callback.onFailure(new RuntimeException(errorMsg, future.cause()));
           channel.close();
+          try {
+            callback.onFailure(new IOException(errorMsg, future.cause()));
+          } catch (Exception e) {
+            logger.error("Uncaught exception in RPC response callback handler!", e);
+          }
         }
       }
     });
@@ -175,6 +185,8 @@ public void onFailure(Throwable e) {

     try {
       return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+    } catch (ExecutionException e) {
+      throw Throwables.propagate(e.getCause());
+    } catch (Exception e) {
       throw Throwables.propagate(e);
     }
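
The new ExecutionException branch matters because Future.get wraps whatever the task failed with; without unwrapping, callers would see the wrapper instead of the IOException that the retry logic keys on. A self-contained illustration using only java.util.concurrent (no Spark classes; the "connection reset" failure is simulated):

import java.io.IOException;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

public class UnwrapDemo {
  public static void main(String[] args) throws Exception {
    ExecutorService pool = Executors.newSingleThreadExecutor();
    Callable<byte[]> task = () -> {
      throw new IOException("connection reset");  // simulated network failure
    };
    Future<byte[]> result = pool.submit(task);
    try {
      result.get(1, TimeUnit.SECONDS);
    } catch (ExecutionException e) {
      // e.getCause() is the original IOException; rethrowing the cause (as the
      // diff does via Throwables.propagate) preserves its type for callers.
      System.out.println(e.getCause());  // java.io.IOException: connection reset
    } finally {
      pool.shutdown();
    }
  }
}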

network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java

Lines changed: 7 additions & 6 deletions

@@ -18,12 +18,12 @@
 package org.apache.spark.network.client;

 import java.io.Closeable;
+import java.io.IOException;
 import java.lang.reflect.Field;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
 import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicReference;

 import com.google.common.base.Preconditions;
@@ -44,7 +44,6 @@
 import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.server.TransportChannelHandler;
 import org.apache.spark.network.util.IOMode;
-import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;

@@ -93,15 +92,17 @@ public TransportClientFactory(
    *
    * Concurrency: This method is safe to call from multiple threads.
    */
-  public TransportClient createClient(String remoteHost, int remotePort) {
+  public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
     // Get connection from the connection pool first.
     // If it is not found or not active, create a new one.
     final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
     TransportClient cachedClient = connectionPool.get(address);
     if (cachedClient != null) {
       if (cachedClient.isActive()) {
+        logger.trace("Returning cached connection to {}: {}", address, cachedClient);
         return cachedClient;
       } else {
+        logger.info("Found inactive connection to {}, closing it.", address);
         connectionPool.remove(address, cachedClient); // Remove inactive clients.
       }
     }
@@ -133,10 +134,10 @@ public void initChannel(SocketChannel ch) {
     long preConnect = System.currentTimeMillis();
     ChannelFuture cf = bootstrap.connect(address);
     if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
-      throw new RuntimeException(
+      throw new IOException(
         String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
     } else if (cf.cause() != null) {
-      throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
     }

     TransportClient client = clientRef.get();
@@ -198,7 +199,7 @@ public void close() {
    */
   private PooledByteBufAllocator createPooledByteBufAllocator() {
     return new PooledByteBufAllocator(
-      PlatformDependent.directBufferPreferred(),
+      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(),
      getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
      getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
      getPrivateStaticField("DEFAULT_PAGE_SIZE"),

network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java

Lines changed: 2 additions & 1 deletion

@@ -17,6 +17,7 @@

 package org.apache.spark.network.client;

+import java.io.IOException;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;

@@ -94,7 +95,7 @@ public void channelUnregistered() {
       String remoteAddress = NettyUtils.getRemoteAddress(channel);
       logger.error("Still have {} requests outstanding when connection from {} is closed",
         numOutstandingRequests(), remoteAddress);
-      failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+      failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
     }
   }

network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
     // All messages have the frame length, message type, and message itself.
     int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
     long frameLength = headerLength + bodyLength;
-    ByteBuf header = ctx.alloc().buffer(headerLength);
+    ByteBuf header = ctx.alloc().heapBuffer(headerLength);
     header.writeLong(frameLength);
     msgType.encode(header);
     in.encode(header);
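
A plausible reading of this one-line change, given the preferDirectBufs work elsewhere in the commit: alloc().buffer() may now hand back a direct buffer, while heapBuffer() pins this small, short-lived header on the Java heap regardless of allocator preference. A minimal Netty illustration of the distinction (the rationale above is an assumption; the API calls are standard Netty):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;

public class HeaderAlloc {
  public static void main(String[] args) {
    ByteBuf header = PooledByteBufAllocator.DEFAULT.heapBuffer(16);
    header.writeLong(42L);                  // frame length, as in encode()
    System.out.println(header.isDirect());  // false: guaranteed heap-backed
    header.release();
  }
}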

network/common/src/main/java/org/apache/spark/network/server/TransportServer.java

Lines changed: 6 additions & 2 deletions

@@ -28,6 +28,7 @@
 import io.netty.channel.ChannelOption;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -71,11 +72,14 @@ private void init(int portToBind) {
       NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
     EventLoopGroup workerGroup = bossGroup;

+    PooledByteBufAllocator allocator = new PooledByteBufAllocator(
+      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred());
+
     bootstrap = new ServerBootstrap()
       .group(bossGroup, workerGroup)
       .channel(NettyUtils.getServerChannelClass(ioMode))
-      .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-      .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
+      .option(ChannelOption.ALLOCATOR, allocator)
+      .childOption(ChannelOption.ALLOCATOR, allocator);

     if (conf.backLog() > 0) {
       bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());

network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java

Lines changed: 9 additions & 5 deletions

@@ -37,13 +37,17 @@
  * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
  */
 public class NettyUtils {
-  /** Creates a Netty EventLoopGroup based on the IOMode. */
-  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
-
-    ThreadFactory threadFactory = new ThreadFactoryBuilder()
+  /** Creates a new ThreadFactory which prefixes each thread with the given name. */
+  public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
+    return new ThreadFactoryBuilder()
       .setDaemon(true)
-      .setNameFormat(threadPrefix + "-%d")
+      .setNameFormat(threadPoolPrefix + "-%d")
       .build();
+  }
+
+  /** Creates a Netty EventLoopGroup based on the IOMode. */
+  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+    ThreadFactory threadFactory = createThreadFactory(threadPrefix);

     switch (mode) {
       case NIO:
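
Extracting createThreadFactory lets components outside the event loop (such as the new retrying fetcher) name their daemon thread pools consistently. A usage sketch; the pool name "shuffle-retry" is illustrative, and ThreadFactoryBuilder is Guava's, as in the diff above:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.spark.network.util.NettyUtils;

public class RetryPoolSketch {
  public static void main(String[] args) {
    // Daemon threads named "shuffle-retry-0", "shuffle-retry-1", ...
    ExecutorService pool =
      Executors.newCachedThreadPool(NettyUtils.createThreadFactory("shuffle-retry"));
    pool.execute(() -> System.out.println(Thread.currentThread().getName()));
    pool.shutdown();
  }
}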

network/common/src/main/java/org/apache/spark/network/util/TransportConf.java

Lines changed: 17 additions & 0 deletions

@@ -30,6 +30,11 @@ public TransportConf(ConfigProvider conf) {
   /** IO mode: nio or epoll */
   public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }

+  /** If true, we will prefer allocating off-heap byte buffers within Netty. */
+  public boolean preferDirectBufs() {
+    return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true);
+  }
+
   /** Connect timeout in secs. Default 120 secs. */
   public int connectionTimeoutMs() {
     return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
@@ -58,4 +63,16 @@ public int connectionTimeoutMs() {

   /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
   public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }
+
+  /**
+   * Max number of times we will try IO exceptions (such as connection timeouts) per request.
+   * If set to 0, we will not do any retries.
+   */
+  public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); }
+
+  /**
+   * Time (in milliseconds) that we will wait in order to perform a retry after an IOException.
+   * Only relevant if maxIORetries > 0.
+   */
+  public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); }
 }
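
With the defaults above (3 retries, 5000 ms wait), a fetch that keeps failing sits through at most three backoff periods, roughly 15 seconds, before the failure reaches the listener. A tiny sketch of that worst-case arithmetic; the config keys match the diff, and the linear (non-exponential) wait per attempt is assumed from ioRetryWaitTime's javadoc:

public class RetryBudget {
  public static void main(String[] args) {
    int maxRetries = 3;      // spark.shuffle.io.maxRetries
    int retryWaitMs = 5000;  // spark.shuffle.io.retryWaitMs
    // Backoff alone, excluding per-attempt connect/IO time:
    System.out.println(maxRetries * retryWaitMs + " ms"); // 15000 ms
  }
}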

network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java

Lines changed: 4 additions & 3 deletions

@@ -17,6 +17,7 @@

 package org.apache.spark.network;

+import java.io.IOException;
 import java.util.concurrent.TimeoutException;

 import org.junit.After;
@@ -57,7 +58,7 @@ public void tearDown() {
   }

   @Test
-  public void createAndReuseBlockClients() throws TimeoutException {
+  public void createAndReuseBlockClients() throws IOException {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
@@ -70,7 +71,7 @@ public void createAndReuseBlockClients() throws TimeoutException {
   }

   @Test
-  public void neverReturnInactiveClients() throws Exception {
+  public void neverReturnInactiveClients() throws IOException, InterruptedException {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     c1.close();
@@ -88,7 +89,7 @@ public void neverReturnInactiveClients() {
   }

   @Test
-  public void closeBlockClientsWithFactory() throws TimeoutException {
+  public void closeBlockClientsWithFactory() throws IOException {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());

network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java

Lines changed: 24 additions & 7 deletions

@@ -17,6 +17,7 @@

 package org.apache.spark.network.shuffle;

+import java.io.IOException;
 import java.util.List;

 import com.google.common.collect.Lists;
@@ -76,17 +77,33 @@ public void init(String appId) {

   @Override
   public void fetchBlocks(
-      String host,
-      int port,
-      String execId,
+      final String host,
+      final int port,
+      final String execId,
       String[] blockIds,
       BlockFetchingListener listener) {
     assert appId != null : "Called before init()";
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
-      TransportClient client = clientFactory.createClient(host, port);
-      new OneForOneBlockFetcher(client, blockIds, listener)
-        .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+      RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
+        new RetryingBlockFetcher.BlockFetchStarter() {
+          @Override
+          public void createAndStart(String[] blockIds, BlockFetchingListener listener)
+              throws IOException {
+            TransportClient client = clientFactory.createClient(host, port);
+            new OneForOneBlockFetcher(client, blockIds, listener)
+              .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+          }
+        };
+
+      int maxRetries = conf.maxIORetries();
+      if (maxRetries > 0) {
+        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+        // a bug in this code. We should remove the if statement once we're sure of the stability.
+        new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
+      } else {
+        blockFetchStarter.createAndStart(blockIds, listener);
+      }
     } catch (Exception e) {
       logger.error("Exception while beginning fetchBlocks", e);
       for (String blockId : blockIds) {
@@ -108,7 +125,7 @@ public void registerWithShuffleServer(
       String host,
       int port,
       String execId,
-      ExecutorShuffleInfo executorInfo) {
+      ExecutorShuffleInfo executorInfo) throws IOException {
     assert appId != null : "Called before init()";
     TransportClient client = clientFactory.createClient(host, port);
     byte[] registerExecutorMessage =
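
For callers, the visible change is when the failure callback fires: with retries enabled, onBlockFetchFailure is only invoked once the retry budget for a block is exhausted. A hedged listener sketch; the two callback names follow BlockFetchingListener as used throughout this commit, with parameter types assumed from the surrounding code:

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.shuffle.BlockFetchingListener;

public class LoggingListener implements BlockFetchingListener {
  @Override
  public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
    System.out.println("Fetched " + blockId + " (" + data.size() + " bytes)");
  }

  @Override
  public void onBlockFetchFailure(String blockId, Throwable t) {
    // With maxIORetries > 0, this fires only after all retries have failed.
    System.err.println("Gave up on " + blockId + ": " + t);
  }
}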
