Skip to content

Commit 1a26c7b

Browse files
Aravind Patnam authored and Mridul Muralidharan committed
[SPARK-41415][3.2] SASL Request Retries
Add the ability to retry SASL requests. Will add it as a metric too soon to track SASL retries. We are seeing increased SASL timeouts internally, and this issue would mitigate the issue. We already have this feature enabled for our 2.3 jobs, and we have seen failures significantly decrease. No Added unit tests, and tested on cluster to ensure the retries are being triggered correctly. Closes apache#38959 from akpatnam25/SPARK-41415. Authored-by: Aravind Patnam <apatnamlinkedin.com> Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com> Closes apache#39645 from akpatnam25/SPARK-41415-backport-3.2. Authored-by: Aravind Patnam <[email protected]> Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
1 parent 68fb5c4 commit 1a26c7b

File tree

5 files changed

+201
-9
lines changed

5 files changed

+201
-9
lines changed

common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import java.io.IOException;
2121
import java.nio.ByteBuffer;
22+
import java.util.concurrent.TimeoutException;
2223
import javax.security.sasl.Sasl;
2324
import javax.security.sasl.SaslException;
2425

@@ -65,9 +66,18 @@ public void doBootstrap(TransportClient client, Channel channel) {
6566
SaslMessage msg = new SaslMessage(appId, payload);
6667
ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
6768
msg.encode(buf);
69+
ByteBuffer response;
6870
buf.writeBytes(msg.body().nioByteBuffer());
69-
70-
ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
71+
try {
72+
response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
73+
} catch (RuntimeException ex) {
74+
// We know it is a Sasl timeout here if it is a TimeoutException.
75+
if (ex.getCause() instanceof TimeoutException) {
76+
throw new SaslTimeoutException(ex.getCause());
77+
} else {
78+
throw ex;
79+
}
80+
}
7181
payload = saslClient.response(JavaUtils.bufferToArray(response));
7282
}
7383

common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.network.sasl;
19+
20+
/**
 * An exception thrown if there is a SASL timeout.
 *
 * Thrown by {@code SaslClientBootstrap} when an RPC sent during the SASL handshake
 * times out, so that callers (e.g. the retrying block transferor) can distinguish
 * SASL timeouts from other failures and retry them when
 * {@code spark.shuffle.sasl.enableRetries} is set.
 */
public class SaslTimeoutException extends RuntimeException {

  // RuntimeException is Serializable; pin the serial form explicitly so the
  // serialized representation is stable across compilers/JVMs.
  private static final long serialVersionUID = 1L;

  /**
   * @param cause the underlying timeout (typically a
   *              {@link java.util.concurrent.TimeoutException})
   */
  public SaslTimeoutException(Throwable cause) {
    super(cause);
  }

  /**
   * @param message detail message describing the timeout
   */
  public SaslTimeoutException(String message) {
    super(message);
  }

  /**
   * @param message detail message describing the timeout
   * @param cause the underlying timeout
   */
  public SaslTimeoutException(String message, Throwable cause) {
    super(message, cause);
  }
}

common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,13 @@ public boolean useOldFetchProtocol() {
374374
return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
375375
}
376376

377+
/**
 * Whether to enable sasl retries or not. The number of retries is dictated by the config
 * `spark.shuffle.io.maxRetries`.
 *
 * Defaults to false, i.e. SASL timeouts are not retried unless explicitly enabled.
 */
public boolean enableSaslRetries() {
  return conf.getBoolean("spark.shuffle.sasl.enableRetries", false);
}
383+
377384
/**
378385
* Class name of the implementation of MergedShuffleFileManager that merges the blocks
379386
* pushed to it when push-based shuffle is enabled. By default, push-based shuffle is disabled at

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@
2424
import java.util.concurrent.Executors;
2525
import java.util.concurrent.TimeUnit;
2626

27+
import com.google.common.annotations.VisibleForTesting;
2728
import com.google.common.collect.Sets;
2829
import com.google.common.util.concurrent.Uninterruptibles;
2930
import org.slf4j.Logger;
3031
import org.slf4j.LoggerFactory;
3132

3233
import org.apache.spark.network.buffer.ManagedBuffer;
34+
import org.apache.spark.network.sasl.SaslTimeoutException;
3335
import org.apache.spark.network.util.NettyUtils;
3436
import org.apache.spark.network.util.TransportConf;
3537

@@ -85,6 +87,8 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
8587
/** Number of times we've attempted to retry so far. */
8688
private int retryCount = 0;
8789

90+
private boolean saslTimeoutSeen;
91+
8892
/**
8993
* Set of all block ids which have not been transferred successfully or with a non-IO Exception.
9094
* A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
@@ -99,6 +103,9 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
99103
*/
100104
private RetryingBlockTransferListener currentListener;
101105

106+
/** Whether sasl retries are enabled. */
107+
private final boolean enableSaslRetries;
108+
102109
private final ErrorHandler errorHandler;
103110

104111
public RetryingBlockTransferor(
@@ -115,6 +122,8 @@ public RetryingBlockTransferor(
115122
Collections.addAll(outstandingBlocksIds, blockIds);
116123
this.currentListener = new RetryingBlockTransferListener();
117124
this.errorHandler = errorHandler;
125+
this.enableSaslRetries = conf.enableSaslRetries();
126+
this.saslTimeoutSeen = false;
118127
}
119128

120129
public RetryingBlockTransferor(
@@ -187,13 +196,29 @@ private synchronized void initiateRetry() {
187196

188197
/**
 * Returns true if we should retry due a block transfer failure. We will retry if and only if
 * the exception was an IOException or SaslTimeoutException and we haven't retried
 * 'maxRetries' times already.
 */
private synchronized boolean shouldRetry(Throwable e) {
  // An IOException either directly or as the immediate cause is retryable.
  boolean isIOException = e instanceof IOException
    || e.getCause() instanceof IOException;
  // SASL timeouts are only retryable when `spark.shuffle.sasl.enableRetries` is set.
  boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
  // If the last failure we saw was a SASL timeout but this failure is not,
  // the SASL handshake must have eventually succeeded, so reset the retry
  // budget: SASL retries should not eat into the block-transfer retries.
  if (!isSaslTimeout && saslTimeoutSeen) {
    retryCount = 0;
    saslTimeoutSeen = false;
  }
  boolean hasRemainingRetries = retryCount < maxRetries;
  boolean shouldRetry = (isSaslTimeout || isIOException) &&
    hasRemainingRetries && errorHandler.shouldRetryError(e);
  // Remember that this retry was triggered by a SASL timeout so a later
  // non-SASL failure (or a success) can reset the counter.
  if (shouldRetry && isSaslTimeout) {
    this.saslTimeoutSeen = true;
  }
  return shouldRetry;
}
218+
219+
// Exposed only so tests can assert how many retries have been attempted.
@VisibleForTesting
public int getRetryCount() {
  return retryCount;
}
198223

199224
/**
@@ -211,6 +236,10 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) {
211236
if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
212237
outstandingBlocksIds.remove(blockId);
213238
shouldForwardSuccess = true;
239+
if (saslTimeoutSeen) {
240+
retryCount = 0;
241+
saslTimeoutSeen = false;
242+
}
214243
}
215244
}
216245

common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,18 @@
2020

2121
import java.io.IOException;
2222
import java.nio.ByteBuffer;
23+
import java.util.ArrayList;
2324
import java.util.Arrays;
25+
import java.util.HashMap;
2426
import java.util.LinkedHashSet;
2527
import java.util.List;
2628
import java.util.Map;
29+
import java.util.concurrent.TimeoutException;
2730

2831
import com.google.common.collect.ImmutableMap;
2932
import com.google.common.collect.Sets;
33+
34+
import org.junit.Before;
3035
import org.junit.Test;
3136
import org.mockito.stubbing.Answer;
3237
import org.mockito.stubbing.Stubber;
@@ -38,6 +43,7 @@
3843
import org.apache.spark.network.buffer.NioManagedBuffer;
3944
import org.apache.spark.network.util.MapConfigProvider;
4045
import org.apache.spark.network.util.TransportConf;
46+
import org.apache.spark.network.sasl.SaslTimeoutException;
4147
import static org.apache.spark.network.shuffle.RetryingBlockTransferor.BlockTransferStarter;
4248

4349
/**
@@ -49,6 +55,16 @@ public class RetryingBlockTransferorSuite {
4955
private final ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
5056
private final ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
5157
private final ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
58+
private static Map<String, String> configMap;
59+
private static RetryingBlockTransferor _retryingBlockTransferor;
60+
61+
// Rebuilt before every test so individual tests can add extra entries
// (e.g. spark.shuffle.sasl.enableRetries) without leaking into other tests.
@Before
public void initMap() {
  configMap = new HashMap<String, String>() {{
    put("spark.shuffle.io.maxRetries", "2");
    put("spark.shuffle.io.retryWait", "0");
  }};
}
5268

5369
@Test
5470
public void testNoFailures() throws IOException, InterruptedException {
@@ -230,6 +246,101 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException
230246
verifyNoMoreInteractions(listener);
231247
}
232248

249+
// With SASL retries disabled (the default config), a SaslTimeoutException is
// not retried: the failure is surfaced directly to the listener even though a
// successful second interaction was available.
@Test
public void testSaslTimeoutFailure() throws IOException, InterruptedException {
  BlockFetchingListener listener = mock(BlockFetchingListener.class);
  TimeoutException timeoutException = new TimeoutException();
  SaslTimeoutException saslTimeoutException =
    new SaslTimeoutException(timeoutException);
  List<? extends Map<String, Object>> interactions = Arrays.asList(
    ImmutableMap.<String, Object>builder()
      .put("b0", saslTimeoutException)
      .build(),
    ImmutableMap.<String, Object>builder()
      .put("b0", block0)
      .build()
  );

  performInteractions(interactions, listener);

  verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
  verify(listener).getTransferType();
  verifyNoMoreInteractions(listener);
}
270+
271+
// With SASL retries enabled, a SaslTimeoutException triggers a retry and the
// fetch succeeds on the second attempt; the retry counter is reset to 0 after
// the successful transfer.
@Test
public void testRetryOnSaslTimeout() throws IOException, InterruptedException {
  BlockFetchingListener listener = mock(BlockFetchingListener.class);

  List<? extends Map<String, Object>> interactions = Arrays.asList(
    // SaslTimeout will cause a retry. Since b0 fails, we will retry both.
    ImmutableMap.<String, Object>builder()
      .put("b0", new SaslTimeoutException(new TimeoutException()))
      .build(),
    ImmutableMap.<String, Object>builder()
      .put("b0", block0)
      .build()
  );
  configMap.put("spark.shuffle.sasl.enableRetries", "true");
  performInteractions(interactions, listener);

  verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
  verify(listener).getTransferType();
  verifyNoMoreInteractions(listener);
  assert(_retryingBlockTransferor.getRetryCount() == 0);
}
292+
293+
// Three consecutive SASL timeouts exhaust maxRetries (2 per initMap), so the
// failure is surfaced to the listener and the retry counter sits at maxRetries.
@Test
public void testRepeatedSaslRetryFailures() throws IOException, InterruptedException {
  BlockFetchingListener listener = mock(BlockFetchingListener.class);
  TimeoutException timeoutException = new TimeoutException();
  SaslTimeoutException saslTimeoutException =
    new SaslTimeoutException(timeoutException);
  List<ImmutableMap<String, Object>> interactions = new ArrayList<>();
  for (int i = 0; i < 3; i++) {
    interactions.add(
      ImmutableMap.<String, Object>builder()
        .put("b0", saslTimeoutException)
        .build()
    );
  }
  configMap.put("spark.shuffle.sasl.enableRetries", "true");
  performInteractions(interactions, listener);
  verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
  // One getTransferType call per attempt: the initial try plus 2 retries.
  verify(listener, times(3)).getTransferType();
  verifyNoMoreInteractions(listener);
  assert(_retryingBlockTransferor.getRetryCount() == 2);
}
314+
315+
// A SASL timeout retry must not consume the IO-failure retry budget: after the
// SASL retry succeeds for b0, the counter resets, and b1's subsequent
// IOException gets its own retry.
@Test
public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedException {
  BlockFetchingListener listener = mock(BlockFetchingListener.class);

  List<? extends Map<String, Object>> interactions = Arrays.asList(
    ImmutableMap.<String, Object>builder()
      .put("b0", new SaslTimeoutException(new TimeoutException()))
      .put("b1", new IOException())
      .build(),
    ImmutableMap.<String, Object>builder()
      .put("b0", block0)
      .put("b1", new IOException())
      .build(),
    ImmutableMap.<String, Object>builder()
      .put("b1", block1)
      .build()
  );
  configMap.put("spark.shuffle.sasl.enableRetries", "true");
  performInteractions(interactions, listener);
  verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
  verify(listener, timeout(5000)).onBlockTransferSuccess("b1", block1);
  verify(listener, atLeastOnce()).getTransferType();
  verifyNoMoreInteractions(listener);
  // This should be equal to 1 because after the SASL exception is retried,
  // retryCount should be set back to 0. Then after that b1 encounters an
  // exception that is retried.
  assert(_retryingBlockTransferor.getRetryCount() == 1);
}
343+
233344
/**
234345
* Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
235346
* Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction
@@ -245,9 +356,7 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
245356
BlockFetchingListener listener)
246357
throws IOException, InterruptedException {
247358

248-
MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of(
249-
"spark.shuffle.io.maxRetries", "2",
250-
"spark.shuffle.io.retryWait", "0"));
359+
MapConfigProvider provider = new MapConfigProvider(configMap);
251360
TransportConf conf = new TransportConf("shuffle", provider);
252361
BlockTransferStarter fetchStarter = mock(BlockTransferStarter.class);
253362

@@ -299,6 +408,8 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
299408
assertNotNull(stub);
300409
stub.when(fetchStarter).createAndStart(any(), any());
301410
String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
302-
new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener).start();
411+
_retryingBlockTransferor =
412+
new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener);
413+
_retryingBlockTransferor.start();
303414
}
304415
}

0 commit comments

Comments
 (0)