
Commit 799e131

jinxing authored and cloud-fan committed
[SPARK-21175] Reject OpenBlocks when memory shortage on shuffle service.
## What changes were proposed in this pull request?

A shuffle service can serve blocks from multiple apps/tasks. Thus the shuffle service can suffer high memory usage when lots of shuffle-reads happen at the same time. In my cluster, OOM always happens on the shuffle service. Analyzing the heap dump, the memory cost by Netty (ChannelOutboundBuffer$Entry) can be up to 2~3G. It might make sense to reject "open blocks" requests when memory usage is high on the shuffle service.

93dd0c5 and 85c6ce6 tried to alleviate the memory pressure on the shuffle service but could not solve the root cause. This PR proposes to control the concurrency of shuffle reads.

## How was this patch tested?

Added unit test.

Author: jinxing <[email protected]>

Closes #18388 from jinxing64/SPARK-21175.
1 parent 996a809 commit 799e131
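Note: the gate added by this patch is controlled by the new `spark.shuffle.maxChunksBeingTransferred` property (default Long.MAX_VALUE, i.e. effectively unlimited). A minimal sketch of setting it, assuming the property is placed in whatever configuration the shuffle server's TransportConf is built from; the value 4096 is illustrative, not a recommendation from this patch:

    import org.apache.spark.SparkConf;

    // Illustrative only: once 4096 chunks are in flight, further
    // ChunkFetchRequest/StreamRequest messages cause the connection to close.
    SparkConf conf = new SparkConf()
        .set("spark.shuffle.maxChunksBeingTransferred", "4096");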

File tree

7 files changed: +265 -13 lines changed


common/network-common/src/main/java/org/apache/spark/network/TransportContext.java

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler
     TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
     TransportClient client = new TransportClient(channel, responseHandler);
     TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
-      rpcHandler);
+      rpcHandler, conf.maxChunksBeingTransferred());
     return new TransportChannelHandler(client, responseHandler, requestHandler,
       conf.connectionTimeoutMs(), closeIdleConnections);
   }

common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java

Lines changed: 54 additions & 6 deletions
@@ -25,6 +25,8 @@

 import com.google.common.base.Preconditions;
 import io.netty.channel.Channel;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -53,6 +55,9 @@ private static class StreamState {
     // that the caller only requests each chunk one at a time, in order.
     int curChunk = 0;

+    // Used to keep track of the number of chunks being transferred and not finished yet.
+    volatile long chunksBeingTransferred = 0L;
+
     StreamState(String appId, Iterator<ManagedBuffer> buffers) {
       this.appId = appId;
       this.buffers = Preconditions.checkNotNull(buffers);

@@ -96,18 +101,25 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {

   @Override
   public ManagedBuffer openStream(String streamChunkId) {
-    String[] array = streamChunkId.split("_");
-    assert array.length == 2:
-      "Stream id and chunk index should be specified when open stream for fetching block.";
-    long streamId = Long.valueOf(array[0]);
-    int chunkIndex = Integer.valueOf(array[1]);
-    return getChunk(streamId, chunkIndex);
+    Pair<Long, Integer> streamChunkIdPair = parseStreamChunkId(streamChunkId);
+    return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight());
   }

   public static String genStreamChunkId(long streamId, int chunkId) {
     return String.format("%d_%d", streamId, chunkId);
   }

+  // Parse streamChunkId into a stream id and a chunk id. This is used when fetching a remote
+  // chunk as a stream.
+  public static Pair<Long, Integer> parseStreamChunkId(String streamChunkId) {
+    String[] array = streamChunkId.split("_");
+    assert array.length == 2:
+      "Stream id and chunk index should be specified.";
+    long streamId = Long.valueOf(array[0]);
+    int chunkIndex = Integer.valueOf(array[1]);
+    return ImmutablePair.of(streamId, chunkIndex);
+  }
+
   @Override
   public void connectionTerminated(Channel channel) {
     // Close all streams which have been associated with the channel.

@@ -139,6 +151,42 @@ public void checkAuthorization(TransportClient client, long streamId) {
     }
   }

+  @Override
+  public void chunkBeingSent(long streamId) {
+    StreamState streamState = streams.get(streamId);
+    if (streamState != null) {
+      streamState.chunksBeingTransferred++;
+    }
+
+  }
+
+  @Override
+  public void streamBeingSent(String streamId) {
+    chunkBeingSent(parseStreamChunkId(streamId).getLeft());
+  }
+
+  @Override
+  public void chunkSent(long streamId) {
+    StreamState streamState = streams.get(streamId);
+    if (streamState != null) {
+      streamState.chunksBeingTransferred--;
+    }
+  }
+
+  @Override
+  public void streamSent(String streamId) {
+    chunkSent(OneForOneStreamManager.parseStreamChunkId(streamId).getLeft());
+  }
+
+  @Override
+  public long chunksBeingTransferred() {
+    long sum = 0L;
+    for (StreamState streamState: streams.values()) {
+      sum += streamState.chunksBeingTransferred;
+    }
+    return sum;
+  }
+
   /**
    * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
    * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
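Note: genStreamChunkId and parseStreamChunkId above are inverses. A quick round-trip sketch (the assertions are illustrative, not part of the patch):

    import org.apache.commons.lang3.tuple.Pair;
    import org.apache.spark.network.server.OneForOneStreamManager;

    // "1001_7" encodes stream id 1001 and chunk index 7.
    String id = OneForOneStreamManager.genStreamChunkId(1001L, 7);
    Pair<Long, Integer> parsed = OneForOneStreamManager.parseStreamChunkId(id);
    assert parsed.getLeft() == 1001L;
    assert parsed.getRight() == 7;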

common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java

Lines changed: 27 additions & 0 deletions
@@ -83,4 +83,31 @@ public void connectionTerminated(Channel channel) { }
    */
   public void checkAuthorization(TransportClient client, long streamId) { }

+  /**
+   * Return the number of chunks being transferred and not finished yet in this StreamManager.
+   */
+  public long chunksBeingTransferred() {
+    return 0;
+  }
+
+  /**
+   * Called when starting to send a chunk.
+   */
+  public void chunkBeingSent(long streamId) { }
+
+  /**
+   * Called when starting to send a stream.
+   */
+  public void streamBeingSent(String streamId) { }
+
+  /**
+   * Called when a chunk is successfully sent.
+   */
+  public void chunkSent(long streamId) { }
+
+  /**
+   * Called when a stream is successfully sent.
+   */
+  public void streamSent(String streamId) { }
+
 }
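Note: the five hooks above default to no-ops so existing StreamManager subclasses keep compiling; OneForOneStreamManager overrides them all. A hypothetical minimal subclass (not part of the patch) showing the intended contract: count up when a send starts, count down when it finishes.

    import java.util.concurrent.atomic.AtomicLong;

    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.server.StreamManager;

    // Illustrative sketch: track one global in-flight counter instead of
    // per-stream counts.
    public class CountingStreamManager extends StreamManager {
      private final AtomicLong inFlight = new AtomicLong(0L);

      @Override
      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
        throw new UnsupportedOperationException("sketch only");
      }

      @Override
      public long chunksBeingTransferred() { return inFlight.get(); }

      @Override
      public void chunkBeingSent(long streamId) { inFlight.incrementAndGet(); }

      @Override
      public void chunkSent(long streamId) { inFlight.decrementAndGet(); }

      @Override
      public void streamBeingSent(String streamId) { inFlight.incrementAndGet(); }

      @Override
      public void streamSent(String streamId) { inFlight.decrementAndGet(); }
    }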

common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java

Lines changed: 36 additions & 6 deletions
@@ -22,6 +22,7 @@

 import com.google.common.base.Throwables;
 import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -65,14 +66,19 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
   /** Returns each chunk part of a stream. */
   private final StreamManager streamManager;

+  /** The max number of chunks being transferred and not finished yet. */
+  private final long maxChunksBeingTransferred;
+
   public TransportRequestHandler(
       Channel channel,
       TransportClient reverseClient,
-      RpcHandler rpcHandler) {
+      RpcHandler rpcHandler,
+      Long maxChunksBeingTransferred) {
     this.channel = channel;
     this.reverseClient = reverseClient;
     this.rpcHandler = rpcHandler;
     this.streamManager = rpcHandler.getStreamManager();
+    this.maxChunksBeingTransferred = maxChunksBeingTransferred;
   }

   @Override
@@ -117,7 +123,13 @@ private void processFetchRequest(final ChunkFetchRequest req) {
       logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel),
         req.streamChunkId);
     }
-
+    long chunksBeingTransferred = streamManager.chunksBeingTransferred();
+    if (chunksBeingTransferred >= maxChunksBeingTransferred) {
+      logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
+        chunksBeingTransferred, maxChunksBeingTransferred);
+      channel.close();
+      return;
+    }
     ManagedBuffer buf;
     try {
       streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
@@ -130,10 +142,25 @@ private void processFetchRequest(final ChunkFetchRequest req) {
       return;
     }

-    respond(new ChunkFetchSuccess(req.streamChunkId, buf));
+    streamManager.chunkBeingSent(req.streamChunkId.streamId);
+    respond(new ChunkFetchSuccess(req.streamChunkId, buf)).addListener(future -> {
+      streamManager.chunkSent(req.streamChunkId.streamId);
+    });
   }

   private void processStreamRequest(final StreamRequest req) {
+    if (logger.isTraceEnabled()) {
+      logger.trace("Received req from {} to fetch stream {}", getRemoteAddress(channel),
+        req.streamId);
+    }
+
+    long chunksBeingTransferred = streamManager.chunksBeingTransferred();
+    if (chunksBeingTransferred >= maxChunksBeingTransferred) {
+      logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
+        chunksBeingTransferred, maxChunksBeingTransferred);
+      channel.close();
+      return;
+    }
     ManagedBuffer buf;
     try {
       buf = streamManager.openStream(req.streamId);
@@ -145,7 +172,10 @@ private void processStreamRequest(final StreamRequest req) {
     }

     if (buf != null) {
-      respond(new StreamResponse(req.streamId, buf.size(), buf));
+      streamManager.streamBeingSent(req.streamId);
+      respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> {
+        streamManager.streamSent(req.streamId);
+      });
     } else {
       respond(new StreamFailure(req.streamId, String.format(
         "Stream '%s' was not found.", req.streamId)));
@@ -187,9 +217,9 @@ private void processOneWayMessage(OneWayMessage req) {
    * Responds to a single message with some Encodable object. If a failure occurs while sending,
    * it will be logged and the channel closed.
    */
-  private void respond(Encodable result) {
+  private ChannelFuture respond(Encodable result) {
     SocketAddress remoteAddress = channel.remoteAddress();
-    channel.writeAndFlush(result).addListener(future -> {
+    return channel.writeAndFlush(result).addListener(future -> {
       if (future.isSuccess()) {
         logger.trace("Sent result {} to client {}", result, remoteAddress);
       } else {
common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java

Lines changed: 6 additions & 0 deletions
@@ -257,4 +257,10 @@ public Properties cryptoConf() {
     return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll());
   }

+  /**
+   * The max number of chunks allowed to be transferred at the same time on shuffle service.
+   */
+  public long maxChunksBeingTransferred() {
+    return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE);
+  }
 }
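Note: a quick sketch of the new getter's behavior, using the MapConfigProvider helper from network-common (assumed available here; the map contents are illustrative):

    import java.util.Collections;

    import org.apache.spark.network.util.MapConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    TransportConf conf = new TransportConf("shuffle",
        new MapConfigProvider(Collections.singletonMap(
            "spark.shuffle.maxChunksBeingTransferred", "4096")));
    // Explicitly set: the limit is what was configured.
    assert conf.maxChunksBeingTransferred() == 4096L;
    // Absent the key, it falls back to Long.MAX_VALUE (effectively no limit).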
common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java

Lines changed: 134 additions & 0 deletions

@@ -0,0 +1,134 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network;

import java.util.ArrayList;
import java.util.List;

import io.netty.channel.Channel;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.junit.Test;

import static org.mockito.Mockito.*;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.*;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportRequestHandler;

public class TransportRequestHandlerSuite {

  @Test
  public void handleFetchRequestAndStreamRequest() throws Exception {
    RpcHandler rpcHandler = new NoOpRpcHandler();
    OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager());
    Channel channel = mock(Channel.class);
    List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs =
        new ArrayList<>();
    when(channel.writeAndFlush(any()))
        .thenAnswer(invocationOnMock0 -> {
          Object response = invocationOnMock0.getArguments()[0];
          ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel);
          responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture));
          return channelFuture;
        });

    // Prepare the stream.
    List<ManagedBuffer> managedBuffers = new ArrayList<>();
    managedBuffers.add(new TestManagedBuffer(10));
    managedBuffers.add(new TestManagedBuffer(20));
    managedBuffers.add(new TestManagedBuffer(30));
    managedBuffers.add(new TestManagedBuffer(40));
    long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
    streamManager.registerChannel(channel, streamId);
    TransportClient reverseClient = mock(TransportClient.class);
    TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
      rpcHandler, 2L);

    RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0));
    requestHandler.handle(request0);
    assert responseAndPromisePairs.size() == 1;
    assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess;
    assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() ==
      managedBuffers.get(0);

    RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1));
    requestHandler.handle(request1);
    assert responseAndPromisePairs.size() == 2;
    assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess;
    assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() ==
      managedBuffers.get(1);

    // Finish flushing the response for request0.
    responseAndPromisePairs.get(0).getRight().finish(true);

    RequestMessage request2 = new StreamRequest(String.format("%d_%d", streamId, 2));
    requestHandler.handle(request2);
    assert responseAndPromisePairs.size() == 3;
    assert responseAndPromisePairs.get(2).getLeft() instanceof StreamResponse;
    assert ((StreamResponse) (responseAndPromisePairs.get(2).getLeft())).body() ==
      managedBuffers.get(2);

    // Request3 will trigger closing the channel, because the max number of chunks being
    // transferred is 2.
    RequestMessage request3 = new StreamRequest(String.format("%d_%d", streamId, 3));
    requestHandler.handle(request3);
    verify(channel, times(1)).close();
    assert responseAndPromisePairs.size() == 3;
  }

  private class ExtendedChannelPromise extends DefaultChannelPromise {

    private List<GenericFutureListener> listeners = new ArrayList<>();
    private boolean success;

    public ExtendedChannelPromise(Channel channel) {
      super(channel);
      success = false;
    }

    @Override
    public ChannelPromise addListener(
        GenericFutureListener<? extends Future<? super Void>> listener) {
      listeners.add(listener);
      return super.addListener(listener);
    }

    @Override
    public boolean isSuccess() {
      return success;
    }

    public void finish(boolean success) {
      this.success = success;
      listeners.forEach(listener -> {
        try {
          listener.operationComplete(this);
        } catch (Exception e) { }
      });
    }
  }
}

docs/configuration.md

Lines changed: 7 additions & 0 deletions
@@ -631,6 +631,13 @@ Apart from these, the following properties are also available, and may be useful
     Max number of entries to keep in the index cache of the shuffle service.
   </td>
 </tr>
+<tr>
+  <td><code>spark.shuffle.maxChunksBeingTransferred</code></td>
+  <td>Long.MAX_VALUE</td>
+  <td>
+    The max number of chunks allowed to be transferred at the same time on shuffle service.
+  </td>
+</tr>
 <tr>
   <td><code>spark.shuffle.sort.bypassMergeThreshold</code></td>
   <td>200</td>
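Note: in spark-defaults.conf the new property would be set like the following (the value is illustrative; the default of Long.MAX_VALUE effectively disables the limit):

    spark.shuffle.maxChunksBeingTransferred  4096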
