diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 647813772294..69baaca8a261 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -19,6 +19,7 @@
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.concurrent.TimeoutException;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
 
@@ -65,9 +66,18 @@ public void doBootstrap(TransportClient client, Channel channel) {
         SaslMessage msg = new SaslMessage(appId, payload);
         ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
         msg.encode(buf);
+        ByteBuffer response;
         buf.writeBytes(msg.body().nioByteBuffer());
-
-        ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
+        try {
+          response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
+        } catch (RuntimeException ex) {
+          // We know it is a Sasl timeout here if it is a TimeoutException.
+          if (ex.getCause() instanceof TimeoutException) {
+            throw new SaslTimeoutException(ex.getCause());
+          } else {
+            throw ex;
+          }
+        }
 
         payload = saslClient.response(JavaUtils.bufferToArray(response));
       }
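
The try/catch above narrows an untyped failure into a typed one: sendRpcSync reports RPC failures as a RuntimeException wrapping the real cause, so the bootstrap inspects getCause() to tell a SASL negotiation timeout apart from every other error. A minimal standalone sketch of the same unwrap-and-rethrow pattern, assuming the SaslTimeoutException class introduced below is on the classpath; the Callable-based harness is purely illustrative and not part of the patch:

    import java.util.concurrent.Callable;
    import java.util.concurrent.TimeoutException;

    import org.apache.spark.network.sasl.SaslTimeoutException;

    public class TimeoutUnwrapSketch {
      // Run an RPC-style call and convert wrapped timeouts into a typed exception.
      static <T> T withTypedTimeout(Callable<T> call) {
        try {
          return call.call();
        } catch (RuntimeException ex) {
          if (ex.getCause() instanceof TimeoutException) {
            throw new SaslTimeoutException(ex.getCause()); // typed, retryable
          }
          throw ex; // non-timeout failures keep their original type
        } catch (Exception ex) {
          throw new RuntimeException(ex);
        }
      }

      public static void main(String[] args) {
        try {
          withTypedTimeout(() -> {
            throw new RuntimeException(new TimeoutException("no SASL response"));
          });
        } catch (SaslTimeoutException e) {
          System.out.println("caught typed timeout: " + e.getCause());
        }
      }
    }
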
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java
new file mode 100644
index 000000000000..2533ae93f8de
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * An exception thrown if there is a SASL timeout.
+ */
+public class SaslTimeoutException extends RuntimeException {
+  public SaslTimeoutException(Throwable cause) {
+    super(cause);
+  }
+
+  public SaslTimeoutException(String message) {
+    super(message);
+  }
+
+  public SaslTimeoutException(String message, Throwable cause) {
+    super(message, cause);
+  }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index f2848c2d4c9a..bbfb99168da2 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -333,6 +333,13 @@ public boolean useOldFetchProtocol() {
     return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
   }
 
+  /** Whether to enable sasl retries or not. The number of retries is dictated by the config
+   * `spark.shuffle.io.maxRetries`.
+   */
+  public boolean enableSaslRetries() {
+    return conf.getBoolean("spark.shuffle.sasl.enableRetries", false);
+  }
+
   /**
    * Class name of the implementation of MergedShuffleFileManager that merges the blocks
    * pushed to it when push-based shuffle is enabled. By default, push-based shuffle is disabled at
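
The flag above defaults to false, and deliberately reuses the existing I/O retry budget (`spark.shuffle.io.maxRetries`) rather than adding a second knob. A hedged sketch of how a deployment might opt in; the SparkConf wrapper is illustrative, only the three config keys come from the patch and existing Spark configuration:

    import org.apache.spark.SparkConf;

    public class SaslRetryConfExample {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf()
          .set("spark.shuffle.sasl.enableRetries", "true") // new flag from this patch
          .set("spark.shuffle.io.maxRetries", "3")         // shared retry budget
          .set("spark.shuffle.io.retryWait", "5s");        // pause between attempts
        System.out.println(conf.toDebugString());
      }
    }
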
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
index 463edc770d28..4515e3a5c282 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
@@ -24,12 +24,14 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -85,6 +87,8 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
   /** Number of times we've attempted to retry so far. */
   private int retryCount = 0;
 
+  private boolean saslTimeoutSeen;
+
   /**
    * Set of all block ids which have not been transferred successfully or with a non-IO Exception.
    * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
@@ -99,6 +103,9 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
    */
   private RetryingBlockTransferListener currentListener;
 
+  /** Whether sasl retries are enabled. */
+  private final boolean enableSaslRetries;
+
   private final ErrorHandler errorHandler;
 
   public RetryingBlockTransferor(
@@ -115,6 +122,8 @@ public RetryingBlockTransferor(
     Collections.addAll(outstandingBlocksIds, blockIds);
     this.currentListener = new RetryingBlockTransferListener();
     this.errorHandler = errorHandler;
+    this.enableSaslRetries = conf.enableSaslRetries();
+    this.saslTimeoutSeen = false;
   }
 
   public RetryingBlockTransferor(
@@ -187,13 +196,29 @@ private synchronized void initiateRetry() {
 
   /**
    * Returns true if we should retry due a block transfer failure. We will retry if and only if
-   * the exception was an IOException and we haven't retried 'maxRetries' times already.
+   * the exception was an IOException or SaslTimeoutException and we haven't retried
+   * 'maxRetries' times already.
    */
   private synchronized boolean shouldRetry(Throwable e) {
     boolean isIOException = e instanceof IOException
       || e.getCause() instanceof IOException;
+    boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
+    if (!isSaslTimeout && saslTimeoutSeen) {
+      retryCount = 0;
+      saslTimeoutSeen = false;
+    }
     boolean hasRemainingRetries = retryCount < maxRetries;
-    return isIOException && hasRemainingRetries && errorHandler.shouldRetryError(e);
+    boolean shouldRetry = (isSaslTimeout || isIOException) &&
+      hasRemainingRetries && errorHandler.shouldRetryError(e);
+    if (shouldRetry && isSaslTimeout) {
+      this.saslTimeoutSeen = true;
+    }
+    return shouldRetry;
+  }
+
+  @VisibleForTesting
+  public int getRetryCount() {
+    return retryCount;
   }
 
   /**
@@ -211,6 +236,10 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) {
       if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
         outstandingBlocksIds.remove(blockId);
         shouldForwardSuccess = true;
+        if (saslTimeoutSeen) {
+          retryCount = 0;
+          saslTimeoutSeen = false;
+        }
       }
     }
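
The two resets above give SASL timeouts their own effective budget: a run of SaslTimeoutExceptions consumes retryCount while it lasts, but the count returns to zero as soon as a non-SASL failure arrives (in shouldRetry) or any block succeeds (in handleBlockTransferSuccess), so SASL retries never eat into the IOException budget. A hedged walkthrough of the testBlockTransferFailureAfterSasl scenario in the test diff below, assuming maxRetries = 2 and retries enabled:

    attempt 1: b0 -> SaslTimeoutException   shouldRetry: saslTimeoutSeen = true, retryCount -> 1
    attempt 2: b0 -> block0 (success)       handleBlockTransferSuccess: retryCount -> 0
               b1 -> IOException            shouldRetry: retryCount -> 1
    attempt 3: b1 -> block1 (success)       done; getRetryCount() == 1
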
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
index 985a7a364282..a33a471fb7af 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
@@ -20,13 +20,18 @@
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeoutException;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Sets;
+
+import org.junit.Before;
 import org.junit.Test;
 import org.mockito.stubbing.Answer;
 import org.mockito.stubbing.Stubber;
@@ -38,6 +43,7 @@
 import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 import static org.apache.spark.network.shuffle.RetryingBlockTransferor.BlockTransferStarter;
 
 /**
@@ -49,6 +55,16 @@ public class RetryingBlockTransferorSuite {
   private final ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
   private final ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
   private final ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
+  private static Map<String, String> configMap;
+  private static RetryingBlockTransferor _retryingBlockTransferor;
+
+  @Before
+  public void initMap() {
+    configMap = new HashMap<String, String>() {{
+      put("spark.shuffle.io.maxRetries", "2");
+      put("spark.shuffle.io.retryWait", "0");
+    }};
+  }
 
   @Test
   public void testNoFailures() throws IOException, InterruptedException {
@@ -230,6 +246,101 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException
     verifyNoMoreInteractions(listener);
   }
 
+  @Test
+  public void testSaslTimeoutFailure() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+      new SaslTimeoutException(timeoutException);
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      ImmutableMap.<String, Object>builder()
+        .put("b0", saslTimeoutException)
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .build()
+    );
+
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
+    verify(listener).getTransferType();
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testRetryOnSaslTimeout() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      // SaslTimeout will cause a retry. Since b0 fails, we will retry both.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new SaslTimeoutException(new TimeoutException()))
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 0);
+  }
+
+  @Test
+  public void testRepeatedSaslRetryFailures() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+      new SaslTimeoutException(timeoutException);
+    List<Map<String, Object>> interactions = new ArrayList<>();
+    for (int i = 0; i < 3; i++) {
+      interactions.add(
+        ImmutableMap.<String, Object>builder()
+          .put("b0", saslTimeoutException)
+          .build()
+      );
+    }
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
+    verify(listener, times(3)).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 2);
+  }
+
+  @Test
+  public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new SaslTimeoutException(new TimeoutException()))
+        .put("b1", new IOException())
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", new IOException())
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b1", block1)
+        .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b1", block1);
+    verify(listener, atLeastOnce()).getTransferType();
+    verifyNoMoreInteractions(listener);
+    // This should be equal to 1 because after the SASL exception is retried,
+    // retryCount should be set back to 0. Then after that b1 encounters an
+    // exception that is retried.
+    assert(_retryingBlockTransferor.getRetryCount() == 1);
+  }
+
   /**
    * Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
    * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction
@@ -244,9 +355,7 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
       BlockFetchingListener listener)
     throws IOException, InterruptedException {
 
-    MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of(
-      "spark.shuffle.io.maxRetries", "2",
-      "spark.shuffle.io.retryWait", "0"));
+    MapConfigProvider provider = new MapConfigProvider(configMap);
     TransportConf conf = new TransportConf("shuffle", provider);
     BlockTransferStarter fetchStarter = mock(BlockTransferStarter.class);
 
@@ -298,6 +407,8 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
     assertNotNull(stub);
     stub.when(fetchStarter).createAndStart(any(), any());
     String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
-    new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener).start();
+    _retryingBlockTransferor =
+      new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener);
+    _retryingBlockTransferor.start();
   }
 }
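
For context, the production entry point mirrors the test harness: a TransportConf built from any ConfigProvider drives the same retry logic exercised above. A hedged sketch of that wiring outside the test suite, using only the constructor and config keys visible in this diff; the starter and listener are assumed to be supplied by the surrounding shuffle client, and the block id "b0" is illustrative:

    import com.google.common.collect.ImmutableMap;

    import org.apache.spark.network.shuffle.BlockFetchingListener;
    import org.apache.spark.network.shuffle.RetryingBlockTransferor;
    import org.apache.spark.network.shuffle.RetryingBlockTransferor.BlockTransferStarter;
    import org.apache.spark.network.util.MapConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class SaslRetryWiringSketch {
      // Starts a transfer of "b0" with SASL retries enabled; the starter and
      // listener are assumed to exist in the surrounding shuffle client.
      static void fetchWithSaslRetries(BlockTransferStarter starter,
          BlockFetchingListener listener) {
        TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(
            ImmutableMap.of(
                "spark.shuffle.sasl.enableRetries", "true",
                "spark.shuffle.io.maxRetries", "2",
                "spark.shuffle.io.retryWait", "0")));
        // SASL timeouts now retry up to maxRetries before "b0" is failed.
        new RetryingBlockTransferor(conf, starter, new String[] {"b0"}, listener).start();
      }
    }
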