From 6b130c6e389a397546efd0dd6c6b9ef38633cb09 Mon Sep 17 00:00:00 2001
From: Aravind Patnam
Date: Sat, 14 Jan 2023 23:58:56 -0600
Subject: [PATCH] [SPARK-41415] SASL Request Retries

### What changes were proposed in this pull request?

Add the ability to retry SASL requests. A follow-up will also add a metric to track SASL retries.

### Why are the changes needed?

We are seeing increased SASL timeouts internally, and this change would mitigate the issue. We already have this feature enabled for our 2.3 jobs, and we have seen failures decrease significantly.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit tests, and tested on cluster to ensure the retries are being triggered correctly.

Closes #38959 from akpatnam25/SPARK-41415.

Authored-by: Aravind Patnam
Signed-off-by: Mridul Muralidharan gmail.com>
---
 .../network/sasl/SaslClientBootstrap.java     |  14 ++-
 .../network/sasl/SaslTimeoutException.java    |  35 ++++++
 .../spark/network/util/TransportConf.java     |   7 ++
 .../shuffle/RetryingBlockTransferor.java      |  33 ++++++-
 .../shuffle/RetryingBlockTransferorSuite.java | 119 +++++++++++++++++-
 5 files changed, 200 insertions(+), 8 deletions(-)
 create mode 100644 common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java

diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 647813772294..69baaca8a261 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -19,6 +19,7 @@
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.concurrent.TimeoutException;
 
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
@@ -65,9 +66,18 @@ public void doBootstrap(TransportClient client, Channel channel) {
         SaslMessage msg = new SaslMessage(appId, payload);
         ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
         msg.encode(buf);
+        ByteBuffer response;
         buf.writeBytes(msg.body().nioByteBuffer());
-
-        ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
+        try {
+          response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
+        } catch (RuntimeException ex) {
+          // We know it is a Sasl timeout here if it is a TimeoutException.
+          if (ex.getCause() instanceof TimeoutException) {
+            throw new SaslTimeoutException(ex.getCause());
+          } else {
+            throw ex;
+          }
+        }
 
         payload = saslClient.response(JavaUtils.bufferToArray(response));
       }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java
new file mode 100644
index 000000000000..2533ae93f8de
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * An exception thrown if there is a SASL timeout.
+ */
+public class SaslTimeoutException extends RuntimeException {
+  public SaslTimeoutException(Throwable cause) {
+    super(cause);
+  }
+
+  public SaslTimeoutException(String message) {
+    super(message);
+  }
+
+  public SaslTimeoutException(String message, Throwable cause) {
+    super(message, cause);
+  }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index f73e3ce2e0aa..9dedd5d9849c 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -374,6 +374,13 @@ public boolean useOldFetchProtocol() {
     return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
   }
 
+  /** Whether to enable sasl retries or not. The number of retries is dictated by the config
+   * `spark.shuffle.io.maxRetries`.
+   */
+  public boolean enableSaslRetries() {
+    return conf.getBoolean("spark.shuffle.sasl.enableRetries", false);
+  }
+
   /**
    * Class name of the implementation of MergedShuffleFileManager that merges the blocks
    * pushed to it when push-based shuffle is enabled. By default, push-based shuffle is disabled at
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
index 463edc770d28..4515e3a5c282 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
@@ -24,12 +24,14 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -85,6 +87,8 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
   /** Number of times we've attempted to retry so far. */
   private int retryCount = 0;
 
+  private boolean saslTimeoutSeen;
+
   /**
    * Set of all block ids which have not been transferred successfully or with a non-IO Exception.
    * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
@@ -99,6 +103,9 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
    */
   private RetryingBlockTransferListener currentListener;
 
+  /** Whether sasl retries are enabled. */
+  private final boolean enableSaslRetries;
+
   private final ErrorHandler errorHandler;
 
   public RetryingBlockTransferor(
@@ -115,6 +122,8 @@ public RetryingBlockTransferor(
     Collections.addAll(outstandingBlocksIds, blockIds);
     this.currentListener = new RetryingBlockTransferListener();
     this.errorHandler = errorHandler;
+    this.enableSaslRetries = conf.enableSaslRetries();
+    this.saslTimeoutSeen = false;
   }
 
   public RetryingBlockTransferor(
@@ -187,13 +196,29 @@ private synchronized void initiateRetry() {
   /**
    * Returns true if we should retry due a block transfer failure. We will retry if and only if
-   * the exception was an IOException and we haven't retried 'maxRetries' times already.
+   * the exception was an IOException or SaslTimeoutException and we haven't retried
+   * 'maxRetries' times already.
    */
   private synchronized boolean shouldRetry(Throwable e) {
     boolean isIOException = e instanceof IOException
       || e.getCause() instanceof IOException;
+    boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
+    if (!isSaslTimeout && saslTimeoutSeen) {
+      retryCount = 0;
+      saslTimeoutSeen = false;
+    }
     boolean hasRemainingRetries = retryCount < maxRetries;
-    return isIOException && hasRemainingRetries && errorHandler.shouldRetryError(e);
+    boolean shouldRetry = (isSaslTimeout || isIOException) &&
+      hasRemainingRetries && errorHandler.shouldRetryError(e);
+    if (shouldRetry && isSaslTimeout) {
+      this.saslTimeoutSeen = true;
+    }
+    return shouldRetry;
+  }
+
+  @VisibleForTesting
+  public int getRetryCount() {
+    return retryCount;
   }
 
   /**
@@ -211,6 +236,10 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) {
       if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
         outstandingBlocksIds.remove(blockId);
         shouldForwardSuccess = true;
+        if (saslTimeoutSeen) {
+          retryCount = 0;
+          saslTimeoutSeen = false;
+        }
       }
     }
 
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
index 985a7a364282..a33a471fb7af 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
@@ -20,13 +20,18 @@
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeoutException;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Sets;
+
+import org.junit.Before;
 import org.junit.Test;
 import org.mockito.stubbing.Answer;
 import org.mockito.stubbing.Stubber;
@@ -38,6 +43,7 @@
 import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 
 import static org.apache.spark.network.shuffle.RetryingBlockTransferor.BlockTransferStarter;
 
 /**
  */
 public class RetryingBlockTransferorSuite {
   private final ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
   private final ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
   private final ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
+  private static Map<String, String> configMap;
+  private static RetryingBlockTransferor _retryingBlockTransferor;
+
+  @Before
+  public void initMap() {
+    configMap = new HashMap<String, String>() {{
+      put("spark.shuffle.io.maxRetries", "2");
+      put("spark.shuffle.io.retryWait", "0");
+    }};
+  }
 
   @Test
   public void testNoFailures() throws IOException, InterruptedException {
@@ -230,6 +246,101 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException
     verifyNoMoreInteractions(listener);
   }
 
+  @Test
+  public void testSaslTimeoutFailure() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+      new SaslTimeoutException(timeoutException);
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      ImmutableMap.<String, Object>builder()
+        .put("b0", saslTimeoutException)
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .build()
+    );
+
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
+    verify(listener).getTransferType();
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testRetryOnSaslTimeout() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      // SaslTimeout will cause a retry. Since b0 fails, we will retry both.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new SaslTimeoutException(new TimeoutException()))
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 0);
+  }
+
+  @Test
+  public void testRepeatedSaslRetryFailures() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+      new SaslTimeoutException(timeoutException);
+    List<ImmutableMap<String, Object>> interactions = new ArrayList<>();
+    for (int i = 0; i < 3; i++) {
+      interactions.add(
+        ImmutableMap.<String, Object>builder()
+          .put("b0", saslTimeoutException)
+          .build()
+      );
+    }
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
+    verify(listener, times(3)).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 2);
+  }
+
+  @Test
+  public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new SaslTimeoutException(new TimeoutException()))
+        .put("b1", new IOException())
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", new IOException())
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b1", block1)
+        .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b1", block1);
+    verify(listener, atLeastOnce()).getTransferType();
+    verifyNoMoreInteractions(listener);
+    // This should be equal to 1 because after the SASL exception is retried,
+    // retryCount should be set back to 0. Then after that, b1 encounters an
+    // exception that is retried.
+    assert(_retryingBlockTransferor.getRetryCount() == 1);
+  }
+
   /**
    * Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
    * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction
@@ -244,9 +355,7 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
       BlockFetchingListener listener)
       throws IOException, InterruptedException {
 
-    MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of(
-      "spark.shuffle.io.maxRetries", "2",
-      "spark.shuffle.io.retryWait", "0"));
+    MapConfigProvider provider = new MapConfigProvider(configMap);
     TransportConf conf = new TransportConf("shuffle", provider);
 
     BlockTransferStarter fetchStarter = mock(BlockTransferStarter.class);
@@ -298,6 +407,8 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
     assertNotNull(stub);
     stub.when(fetchStarter).createAndStart(any(), any());
     String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
-    new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener).start();
+    _retryingBlockTransferor =
+      new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener);
+    _retryingBlockTransferor.start();
   }
 }
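
### Usage note

A minimal sketch of how a job opts in to this behavior once the patch lands. `spark.shuffle.sasl.enableRetries` is the new flag added above (default `false`); `spark.shuffle.io.maxRetries` and `spark.shuffle.io.retryWait` are the existing shuffle I/O settings that now also bound SASL retries. The class name below is illustrative, not part of the change.

```java
import org.apache.spark.SparkConf;

// Sketch: enabling SASL request retries. The flag defaults to false (see
// TransportConf#enableSaslRetries in the diff); the retry budget and the
// wait between attempts reuse the existing shuffle I/O settings.
public class SaslRetryConfExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
        .set("spark.shuffle.sasl.enableRetries", "true") // new in this patch
        .set("spark.shuffle.io.maxRetries", "3")         // shared retry budget
        .set("spark.shuffle.io.retryWait", "5s");        // wait between attempts
    // Pass `conf` to the SparkContext / SparkSession builder as usual.
  }
}
```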
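### Retry accounting sketch

The subtle part of the change is the retryCount bookkeeping in `shouldRetry`: SASL timeouts draw from the same `maxRetries` budget as IOExceptions, but the count is reset once a non-SASL failure, or a successful transfer, follows a run of SASL timeouts, so SASL retries do not eat into the budget for real transfer errors. Below is a self-contained sketch of just that state machine, with the listener and `ErrorHandler` stripped out and `TimeoutException` standing in for `SaslTimeoutException`; it is not the actual class.

```java
import java.io.IOException;
import java.util.concurrent.TimeoutException;

// Standalone sketch of RetryingBlockTransferor's retry accounting. The
// caller is expected to bump retryCount via onRetry() when it actually
// performs a retry, mirroring initiateRetry() in the diff.
class RetryBudgetSketch {
  private final int maxRetries;
  private final boolean enableSaslRetries;
  private int retryCount = 0;
  private boolean saslTimeoutSeen = false;

  RetryBudgetSketch(int maxRetries, boolean enableSaslRetries) {
    this.maxRetries = maxRetries;
    this.enableSaslRetries = enableSaslRetries;
  }

  // Mirrors shouldRetry(Throwable) in the diff (minus errorHandler).
  synchronized boolean shouldRetry(Throwable e) {
    boolean isIOException = e instanceof IOException || e.getCause() instanceof IOException;
    boolean isSaslTimeout = enableSaslRetries && e instanceof TimeoutException;
    if (!isSaslTimeout && saslTimeoutSeen) {
      // First non-SASL failure after a run of SASL timeouts: restore the full budget.
      retryCount = 0;
      saslTimeoutSeen = false;
    }
    boolean hasRemainingRetries = retryCount < maxRetries;
    boolean shouldRetry = (isSaslTimeout || isIOException) && hasRemainingRetries;
    if (shouldRetry && isSaslTimeout) {
      saslTimeoutSeen = true;
    }
    return shouldRetry;
  }

  synchronized void onRetry() {
    retryCount++;
  }

  // Mirrors the reset added to handleBlockTransferSuccess: a success after
  // SASL timeouts also restores the budget.
  synchronized void onSuccess() {
    if (saslTimeoutSeen) {
      retryCount = 0;
      saslTimeoutSeen = false;
    }
  }
}
```

With `maxRetries = 2`, testBlockTransferFailureAfterSasl above plays out as: SASL timeout on b0 (retry, count 1); b0 succeeds (count reset to 0); IOException on b1 (retry, count 1); b1 succeeds — hence the final `getRetryCount() == 1` assertion.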