From 4a48e4954da4ba8e5199bddad5fe4f0758c8de99 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Tue, 6 Dec 2022 21:52:32 -0800 Subject: [PATCH 01/28] SPARK-41415: SASL Request Retries --- .../spark/network/client/TransportClient.java | 23 +++++- .../network/sasl/SaslClientBootstrap.java | 10 ++- .../spark/network/sasl/SaslInitMessage.java | 50 +++++++++++++ .../spark/network/sasl/SaslMessage.java | 9 ++- .../spark/network/sasl/SaslRpcHandler.java | 35 +++++++-- .../spark/network/util/TransportConf.java | 9 +++ .../spark/network/sasl/SparkSaslSuite.java | 72 +++++++++++++++++++ .../shuffle/RetryingBlockTransferor.java | 10 ++- .../shuffle/RetryingBlockTransferorSuite.java | 27 ++++++- 9 files changed, 232 insertions(+), 13 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index dd2fdb08ee5b..2035bd819bc3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -24,6 +24,7 @@ import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; @@ -115,6 +116,14 @@ public void setClientId(String id) { this.clientId = id; } + /** + * This is needed when sasl server is reset. Sasl authentication + * will be re-attempted. + */ + public void unsetClientId() { + this.clientId = null; + } + /** * Requests a single chunk from the remote side, from the pre-negotiated streamId. * @@ -264,7 +273,7 @@ public long uploadStream( public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { final SettableFuture result = SettableFuture.create(); - sendRpc(message, new RpcResponseCallback() { + long rpcId = sendRpc(message, new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { try { @@ -287,6 +296,9 @@ public void onFailure(Throwable e) { try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + logger.warn("RPC {} timed-out", rpcId); + throw Throwables.propagate(new SaslTimeoutException(e)); } catch (ExecutionException e) { throw Throwables.propagate(e.getCause()); } catch (Exception e) { @@ -338,6 +350,15 @@ public String toString() { .toString(); } + /** + * Exception thrown when sasl request times out. 
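Reviewer aside, not part of the patch: a minimal sketch of how a caller can react to the new unchecked exception type. The wrapper method and its one-retry policy are hypothetical and assume an already-connected `TransportClient`.

```java
import java.nio.ByteBuffer;
import org.apache.spark.network.client.TransportClient;

final class SaslRpcRetryExample {
  // Hypothetical helper: retry once on a SASL handshake timeout, which is
  // unchecked and wraps the underlying java.util.concurrent.TimeoutException.
  static ByteBuffer sendWithOneRetry(TransportClient client, ByteBuffer msg, long timeoutMs) {
    ByteBuffer retryCopy = msg.duplicate(); // independent cursor for a second attempt
    try {
      return client.sendRpcSync(msg, timeoutMs);
    } catch (TransportClient.SaslTimeoutException e) {
      return client.sendRpcSync(retryCopy, timeoutMs);
    }
  }
}
```

In the patch itself this classification is done centrally in `RetryingBlockTransferor.shouldRetry` (later in this commit) rather than at individual call sites.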
+ */ + public static class SaslTimeoutException extends RuntimeException { + public SaslTimeoutException(Throwable cause) { + super((cause)); + } + } + private static long requestId() { return Math.abs(UUID.randomUUID().getLeastSignificantBits()); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 647813772294..b21f74b2f0ad 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -60,9 +60,15 @@ public void doBootstrap(TransportClient client, Channel channel) { SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption()); try { byte[] payload = saslClient.firstToken(); - + boolean firstToken = true; while (!saslClient.isComplete()) { - SaslMessage msg = new SaslMessage(appId, payload); + SaslMessage msg; + if (conf.enableSaslRetries() && firstToken) { + msg = new SaslInitMessage(appId, payload); + } else { + msg = new SaslMessage(appId, payload); + } + firstToken = false; ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); buf.writeBytes(msg.body().nioByteBuffer()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java new file mode 100644 index 000000000000..eade4401717c --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** + * Encodes the first message in Sasl exchange. When it is retried, the + * SaslServer needs to be reset. {@link SaslRpcHandler} uses the type of this message + * to reset the SaslServer. + */ +public final class SaslInitMessage extends SaslMessage { + + /** Serialization tag used to catch incorrect payloads. */ + static final byte TAG_BYTE = (byte) 0xEB; + + SaslInitMessage(String appId, byte[] message) { + super(appId, message); + } + + SaslInitMessage(String appId, ByteBuf message) { + super(appId, message); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.Strings.encode(buf, appId); + // See comment in encodedLength(). 
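For readability, the bootstrap loop from the `SaslClientBootstrap` hunk above, condensed into plain Java (a sketch: `client`, `conf`, `appId`, and `saslClient` come from the surrounding `doBootstrap`, and the `IOException` from `nioByteBuffer()` is handled by the enclosing try block):

```java
byte[] payload = saslClient.firstToken();
boolean firstToken = true;
while (!saslClient.isComplete()) {
  // Only the very first token is framed as a SaslInitMessage, and only when
  // retries are enabled, so the server knows when to reset handshake state.
  SaslMessage msg = (conf.enableSaslRetries() && firstToken)
      ? new SaslInitMessage(appId, payload)
      : new SaslMessage(appId, payload);
  firstToken = false;
  ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
  msg.encode(buf);
  buf.writeBytes(msg.body().nioByteBuffer());
  ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
  payload = saslClient.response(JavaUtils.bufferToArray(response));
}
```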
+ buf.writeInt((int) body().size()); + } + +} \ No newline at end of file diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 1b03300d948e..df3b58772b77 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -33,7 +33,7 @@ class SaslMessage extends AbstractMessage { /** Serialization tag used to catch incorrect payloads. */ - private static final byte TAG_BYTE = (byte) 0xEA; + static final byte TAG_BYTE = (byte) 0xEA; public final String appId; @@ -76,4 +76,11 @@ public static SaslMessage decode(ByteBuf buf) { buf.readInt(); return new SaslMessage(appId, buf.retain()); } + + public static SaslMessage decodeWithoutTag(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + // See comment in encodedLength(). + buf.readInt(); + return new SaslMessage(appId, buf.retain()); + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index cc9e88fcf98e..b329939e3258 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -74,10 +74,33 @@ public boolean doAuthChallenge( ByteBuffer message, RpcResponseCallback callback) { if (saslServer == null || !saslServer.isComplete()) { + // save the position and limit, before reading a byte + int position = message.position(); + int limit = message.limit(); + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + byte tagByte = nettyBuf.readByte(); + if (super.isAuthenticated() && tagByte != SaslMessage.TAG_BYTE + && tagByte != SaslInitMessage.TAG_BYTE) { + // not a sasl or sasl reset, so sasl completed at client as well. + message.position(position); + message.limit(limit); + return true; + } + if (tagByte != SaslMessage.TAG_BYTE && tagByte != SaslInitMessage.TAG_BYTE) { + throw new IllegalStateException("Expected SaslMessage, received something else" + + " (maybe your client does not have SASL enabled?)"); + } + if (tagByte == SaslInitMessage.TAG_BYTE) { + logger.debug("Received an init message for channel {}", client); + if (saslServer != null) { + resetSaslServer(true); + } + client.unsetClientId(); + } SaslMessage saslMessage; try { - saslMessage = SaslMessage.decode(nettyBuf); + saslMessage = SaslMessage.decodeWithoutTag(nettyBuf); } finally { nettyBuf.release(); } @@ -86,13 +109,13 @@ public boolean doAuthChallenge( // First message in the handshake, setup the necessary state. 
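The position/limit bookkeeping in the `doAuthChallenge` hunk above is a general NIO pattern worth isolating; a standalone sketch (helper name hypothetical) of peeking a tag byte while leaving the buffer intact for whoever decodes it next:

```java
import java.nio.ByteBuffer;

final class BufferPeekExample {
  // Read the leading tag byte, then restore the cursor so downstream
  // consumers (SaslMessage.decode, or the delegate RpcHandler) see the
  // complete, untouched frame.
  static byte peekTag(ByteBuffer message) {
    int position = message.position();
    int limit = message.limit();
    byte tag = message.get(); // advances position by one byte
    message.position(position);
    message.limit(limit);
    return tag;
  }
}
```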
client.setClientId(saslMessage.appId); saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, - conf.saslServerAlwaysEncrypt()); + conf.saslServerAlwaysEncrypt()); } byte[] response; try { response = saslServer.response(JavaUtils.bufferToArray( - saslMessage.body().nioByteBuffer())); + saslMessage.body().nioByteBuffer())); } catch (IOException ioe) { throw new RuntimeException(ioe); } @@ -107,13 +130,13 @@ public boolean doAuthChallenge( if (saslServer.isComplete()) { if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { logger.debug("SASL authentication successful for channel {}", client); - complete(true); + resetSaslServer(true); return true; } logger.debug("Enabling encryption for channel {}", client); SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); - complete(false); + resetSaslServer(false); return true; } return false; @@ -130,7 +153,7 @@ public void channelInactive(TransportClient client) { } } - private void complete(boolean dispose) { + private void resetSaslServer(boolean dispose) { if (dispose) { try { saslServer.dispose(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index f2848c2d4c9a..57721df61145 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -333,6 +333,15 @@ public boolean useOldFetchProtocol() { return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false); } + /** Whether to enable sasl retries. Sasl retries will be enabled, once the shuffle + * server is upgraded. The updated SaslHandler can handle older clients that don't + * send any SaslInitMessage. However, the older SaslHandler will not be able to handle + * SaslInitMessage. + */ + public boolean enableSaslRetries() { + return conf.getBoolean("spark.shuffle.sasl.enableRetries", false); + } + /** * Class name of the implementation of MergedShuffleFileManager that merges the blocks * pushed to it when push-based shuffle is enabled. 
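A minimal sketch of flipping the new flag, mirroring how the test suite below builds its `TransportConf` (the map contents are illustrative; `enableSaslRetries()` defaults to false when the key is unset):

```java
import com.google.common.collect.ImmutableMap;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

final class SaslRetryConfExample {
  static TransportConf withSaslRetries() {
    MapConfigProvider provider = new MapConfigProvider(
        ImmutableMap.of("spark.shuffle.sasl.enableRetries", "true"));
    return new TransportConf("shuffle", provider);
  }
}
```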
By default, push-based shuffle is disabled at diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 6096cd32f3d0..bc790af06cb8 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -345,6 +345,78 @@ public void testDelegates() throws Exception { } } + @Test + public void testSaslWithRetriesEnabled() throws Exception { + testSaslResetHandling(1); + } + + @Test + public void testMultipleSaslRetries() throws Exception { + testSaslResetHandling(2); + } + + private void testSaslResetHandling(final int maxRetries) { + Map testConf = ImmutableMap.of("spark.authenticate.enableSaslEncryption", + String.valueOf(false), + "spark.shuffle.sasl.enableRetries", String.valueOf(true)); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); + + SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); + when(keyHolder.getSaslUser(anyString())).thenReturn("user"); + when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); + + Channel channel = mock(Channel.class); + RpcHandler delegate = mock(RpcHandler.class); + doAnswer(invocation -> { + ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; + RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; + assertEquals("Ping", JavaUtils.bytesToString(message)); + cb.onSuccess(JavaUtils.stringToBytes("Pong")); + return null; + }).when(delegate) + .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); + + RpcHandler saslHandler = new SaslRpcHandler(conf, channel, delegate, keyHolder); + + TransportClient client = mock(TransportClient.class); + final ByteBuffer[] handlerResponse = new ByteBuffer[1]; + + when(client.sendRpcSync(any(), anyLong())).thenAnswer(invocation -> { + + ByteBuffer msg = (ByteBuffer)(invocation.getArguments()[0]); + saslHandler.receive(client, msg, new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + handlerResponse[0] = response; + } + + @Override + public void onFailure(Throwable e) { + } + }); + return handlerResponse[0]; + }); + + SaslClientBootstrap bootstrapClient = new SaslClientBootstrap(conf, "user", keyHolder); + for (int i = 0; i < maxRetries; i++) { + bootstrapClient.doBootstrap(client, channel); + } + + // Subsequent messages to handler should be forwarded to delegate + saslHandler.receive(client, JavaUtils.stringToBytes("Ping"), new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + handlerResponse[0] = response; + } + + @Override + public void onFailure(Throwable e) { + } + }); + + assertEquals("Pong", JavaUtils.bytesToString(handlerResponse[0])); + } + private static class SaslTestCtx implements AutoCloseable { final TransportClient client; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 463edc770d28..bb5dc33a7c26 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -26,6 +26,7 @@ import com.google.common.collect.Sets; import 
com.google.common.util.concurrent.Uninterruptibles; +import org.apache.spark.network.client.TransportClient; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -99,6 +100,9 @@ void createAndStart(String[] blockIds, BlockTransferListener listener) */ private RetryingBlockTransferListener currentListener; + /** Whether sasl retries are enabled. */ + private final boolean enableSaslRetries; + private final ErrorHandler errorHandler; public RetryingBlockTransferor( @@ -115,6 +119,7 @@ public RetryingBlockTransferor( Collections.addAll(outstandingBlocksIds, blockIds); this.currentListener = new RetryingBlockTransferListener(); this.errorHandler = errorHandler; + this.enableSaslRetries = conf.enableSaslRetries(); } public RetryingBlockTransferor( @@ -192,8 +197,11 @@ private synchronized void initiateRetry() { private synchronized boolean shouldRetry(Throwable e) { boolean isIOException = e instanceof IOException || e.getCause() instanceof IOException; + boolean isSaslTimeout = enableSaslRetries && + (e instanceof TransportClient.SaslTimeoutException || + (e.getCause() != null && e.getCause() instanceof TransportClient.SaslTimeoutException)); boolean hasRemainingRetries = retryCount < maxRetries; - return isIOException && hasRemainingRetries && errorHandler.shouldRetryError(e); + return (isSaslTimeout || isIOException) && hasRemainingRetries && errorHandler.shouldRetryError(e); } /** diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index 985a7a364282..9da655d23f45 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -27,6 +27,8 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; +import java.util.concurrent.TimeoutException; +import org.apache.spark.network.client.TransportClient; import org.junit.Test; import org.mockito.stubbing.Answer; import org.mockito.stubbing.Stubber; @@ -230,6 +232,26 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException verifyNoMoreInteractions(listener); } + @Test + public void testRetryOnSaslTimeout() throws IOException, InterruptedException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + List> interactions = Arrays.asList( + // SaslTimeout will cause a retry. Since b0 fails, we will retry both. + ImmutableMap.builder() + .put("b0", new TransportClient.SaslTimeoutException(new TimeoutException())) + .build(), + ImmutableMap.builder() + .put("b0", block0) + .build() + ); + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verifyNoMoreInteractions(listener); + } + /** * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. 
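The classification added to `shouldRetry` above reads naturally as a pure predicate; a sketch (helper name hypothetical), relying on `instanceof` being null-safe so the `getCause()` checks need no separate null guard:

```java
import java.io.IOException;
import org.apache.spark.network.client.TransportClient.SaslTimeoutException;

final class RetryClassifierExample {
  // An error is worth retrying when it is IO-related, or a SASL handshake
  // timeout while spark.shuffle.sasl.enableRetries is on; the remaining-retry
  // count and ErrorHandler checks still apply in the real shouldRetry.
  static boolean isRetriable(Throwable e, boolean saslRetriesEnabled) {
    boolean isIO = e instanceof IOException || e.getCause() instanceof IOException;
    boolean isSaslTimeout = saslRetriesEnabled &&
        (e instanceof SaslTimeoutException || e.getCause() instanceof SaslTimeoutException);
    return isIO || isSaslTimeout;
  }
}
```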
This interaction @@ -245,8 +267,9 @@ private static void performInteractions(List> inte throws IOException, InterruptedException { MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of( - "spark.shuffle.io.maxRetries", "2", - "spark.shuffle.io.retryWait", "0")); + "spark.shuffle.io.maxRetries", "2", + "spark.shuffle.io.retryWait", "0", + "spark.shuffle.sasl.enableRetries", "true")); TransportConf conf = new TransportConf("shuffle", provider); BlockTransferStarter fetchStarter = mock(BlockTransferStarter.class); From f3e2ea94fa07fdf47c10f8f92fc6a7842f42b21f Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Wed, 7 Dec 2022 01:34:21 -0800 Subject: [PATCH 02/28] fixed unit test and added metric definition --- .../spark/network/server/AbstractAuthRpcHandler.java | 7 ++++++- .../scala/org/apache/spark/InternalAccumulator.scala | 1 + .../scala/org/apache/spark/executor/TaskMetrics.scala | 9 +++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java index 95fde677624f..19dedfecb3f2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -33,6 +33,7 @@ public abstract class AbstractAuthRpcHandler extends RpcHandler { private final RpcHandler delegate; private boolean isAuthenticated; + private boolean isComplete; protected AbstractAuthRpcHandler(RpcHandler delegate) { this.delegate = delegate; @@ -53,10 +54,14 @@ public final void receive( TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - if (isAuthenticated) { + if (isAuthenticated && isComplete) { delegate.receive(client, message, callback); } else { + if (isAuthenticated) { + isComplete = true; + } isAuthenticated = doAuthChallenge(client, message, callback); + } } diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 18b10d23da94..6047fbfeb67a 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -41,6 +41,7 @@ private[spark] object InternalAccumulator { val DISK_BYTES_SPILLED = METRICS_PREFIX + "diskBytesSpilled" val PEAK_EXECUTION_MEMORY = METRICS_PREFIX + "peakExecutionMemory" val UPDATED_BLOCK_STATUSES = METRICS_PREFIX + "updatedBlockStatuses" + val SASL_REQUEST_RETRIES = METRICS_PREFIX + "saslRequestRetries" val TEST_ACCUM = METRICS_PREFIX + "testAccumulator" // scalastyle:off diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 43742a4d46cb..04073847ea4a 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -56,6 +56,7 @@ class TaskMetrics private[spark] () extends Serializable { private val _diskBytesSpilled = new LongAccumulator private val _peakExecutionMemory = new LongAccumulator private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)] + private val _saslRequestRetries = new LongAccumulator /** * Time taken on the executor to deserialize this task. 
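As a sanity check on the accumulator semantics the new `_saslRequestRetries` metric relies on, a small local sketch (no SparkContext registration needed; `LongAccumulator` works standalone):

```java
import org.apache.spark.util.LongAccumulator;

final class SaslRetryCounterExample {
  public static void main(String[] args) {
    LongAccumulator saslRequestRetries = new LongAccumulator();
    saslRequestRetries.add(1L); // one retry observed by a fetch listener
    saslRequestRetries.add(1L); // another observed by a push listener
    System.out.println(saslRequestRetries.sum()); // prints 2
  }
}
```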
@@ -111,6 +112,11 @@ class TaskMetrics private[spark] () extends Serializable { */ def peakExecutionMemory: Long = _peakExecutionMemory.sum + /** + * The number of SASL requests retried by this task. + */ + def saslRequestRetries: Long = _saslRequestRetries.sum + /** * Storage statuses of any blocks that have been updated as a result of this task. * @@ -126,6 +132,7 @@ class TaskMetrics private[spark] () extends Serializable { _updatedBlockStatuses.value.asScala.toSeq } + // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = _executorDeserializeTime.setValue(v) @@ -147,6 +154,7 @@ class TaskMetrics private[spark] () extends Serializable { _updatedBlockStatuses.setValue(v) private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = _updatedBlockStatuses.setValue(v.asJava) + private[spark] def incSaslRequestRetries(v: Long): Unit = _saslRequestRetries.add(v) /** * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted @@ -220,6 +228,7 @@ class TaskMetrics private[spark] () extends Serializable { DISK_BYTES_SPILLED -> _diskBytesSpilled, PEAK_EXECUTION_MEMORY -> _peakExecutionMemory, UPDATED_BLOCK_STATUSES -> _updatedBlockStatuses, + SASL_REQUEST_RETRIES -> _saslRequestRetries, shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched, shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched, shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead, From b87420d8eecbbd452722c0369632185ec38268ec Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Thu, 8 Dec 2022 22:18:06 -0800 Subject: [PATCH 03/28] fixed more unit tests --- .../network/sasl/SaslClientBootstrap.java | 2 ++ .../spark/network/sasl/SaslRpcHandler.java | 10 ------- .../server/AbstractAuthRpcHandler.java | 26 ++++++++++++++----- .../spark/network/sasl/SparkSaslSuite.java | 2 +- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index b21f74b2f0ad..b5736111a821 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Arrays; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; @@ -74,6 +75,7 @@ public void doBootstrap(TransportClient client, Channel channel) { buf.writeBytes(msg.body().nioByteBuffer()); ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs()); + System.out.println("response: " + Arrays.toString(JavaUtils.bufferToArray(response))); payload = saslClient.response(JavaUtils.bufferToArray(response)); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index b329939e3258..c5e04137da6e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -74,19 +74,9 @@ public boolean doAuthChallenge( ByteBuffer message, RpcResponseCallback callback) { if (saslServer == null || !saslServer.isComplete()) { - // save the position and limit, before reading 
a byte - int position = message.position(); - int limit = message.limit(); ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); byte tagByte = nettyBuf.readByte(); - if (super.isAuthenticated() && tagByte != SaslMessage.TAG_BYTE - && tagByte != SaslInitMessage.TAG_BYTE) { - // not a sasl or sasl reset, so sasl completed at client as well. - message.position(position); - message.limit(limit); - return true; - } if (tagByte != SaslMessage.TAG_BYTE && tagByte != SaslInitMessage.TAG_BYTE) { throw new IllegalStateException("Expected SaslMessage, received something else" + " (maybe your client does not have SASL enabled?)"); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java index 19dedfecb3f2..c726252783c1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -17,11 +17,15 @@ package org.apache.spark.network.server; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import java.nio.ByteBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.util.JavaUtils; + /** * RPC Handler which performs authentication, and when it's successful, delegates further @@ -33,7 +37,6 @@ public abstract class AbstractAuthRpcHandler extends RpcHandler { private final RpcHandler delegate; private boolean isAuthenticated; - private boolean isComplete; protected AbstractAuthRpcHandler(RpcHandler delegate) { this.delegate = delegate; @@ -54,14 +57,23 @@ public final void receive( TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - if (isAuthenticated && isComplete) { - delegate.receive(client, message, callback); - } else { - if (isAuthenticated) { - isComplete = true; + System.out.println(JavaUtils.bytesToString(message)); + + int position = message.position(); + int limit = message.limit(); + + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + byte tagByte = nettyBuf.readByte(); + if (isAuthenticated) { + if (tagByte != (byte) 0xEA && tagByte != (byte) 0xEB) { + message.position(position); + message.limit(limit); + delegate.receive(client, message, callback); + } else { + isAuthenticated = doAuthChallenge(client, message, callback); } + } else { isAuthenticated = doAuthChallenge(client, message, callback); - } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index bc790af06cb8..9533bcbab538 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -352,7 +352,7 @@ public void testSaslWithRetriesEnabled() throws Exception { @Test public void testMultipleSaslRetries() throws Exception { - testSaslResetHandling(2); + testSaslResetHandling(5); } private void testSaslResetHandling(final int maxRetries) { From 1816d0d101312737e84c5efab6f0e3271068bec1 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Thu, 8 Dec 2022 22:32:06 -0800 Subject: [PATCH 04/28] fixed some spaces --- .../main/java/org/apache/spark/network/sasl/SaslRpcHandler.java | 
1 - .../org/apache/spark/network/server/AbstractAuthRpcHandler.java | 2 -- .../test/java/org/apache/spark/network/sasl/SparkSaslSuite.java | 2 +- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index c5e04137da6e..d32f8a5d4afe 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -74,7 +74,6 @@ public boolean doAuthChallenge( ByteBuffer message, RpcResponseCallback callback) { if (saslServer == null || !saslServer.isComplete()) { - ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); byte tagByte = nettyBuf.readByte(); if (tagByte != SaslMessage.TAG_BYTE && tagByte != SaslInitMessage.TAG_BYTE) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java index c726252783c1..c6c450907e46 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -57,8 +57,6 @@ public final void receive( TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - System.out.println(JavaUtils.bytesToString(message)); - int position = message.position(); int limit = message.limit(); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 9533bcbab538..bc790af06cb8 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -352,7 +352,7 @@ public void testSaslWithRetriesEnabled() throws Exception { @Test public void testMultipleSaslRetries() throws Exception { - testSaslResetHandling(5); + testSaslResetHandling(2); } private void testSaslResetHandling(final int maxRetries) { From 6302cb15484917dc8745f85f76d691362eebbb59 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Fri, 9 Dec 2022 14:10:40 -0800 Subject: [PATCH 05/28] fixed one last unit test and linter changes --- .../java/org/apache/spark/network/sasl/SaslInitMessage.java | 3 +-- .../apache/spark/network/server/AbstractAuthRpcHandler.java | 1 - .../apache/spark/network/shuffle/RetryingBlockTransferor.java | 3 ++- .../spark/network/shuffle/RetryingBlockTransferorSuite.java | 3 ++- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java index eade4401717c..b300aa69754e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java @@ -46,5 +46,4 @@ public void encode(ByteBuf buf) { // See comment in encodedLength(). 
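// editor's note (not part of the diff): the int written on the next line is the body size; SaslMessage.decode reads this int back and discards it.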
buf.writeInt((int) body().size()); } - -} \ No newline at end of file +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java index c6c450907e46..05333b86349c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -24,7 +24,6 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.util.JavaUtils; /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index bb5dc33a7c26..7e23d9aa8be9 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -201,7 +201,8 @@ private synchronized boolean shouldRetry(Throwable e) { (e instanceof TransportClient.SaslTimeoutException || (e.getCause() != null && e.getCause() instanceof TransportClient.SaslTimeoutException)); boolean hasRemainingRetries = retryCount < maxRetries; - return (isSaslTimeout || isIOException) && hasRemainingRetries && errorHandler.shouldRetryError(e); + return (isSaslTimeout || isIOException) && + hasRemainingRetries && errorHandler.shouldRetryError(e); } /** diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index 9da655d23f45..f74802c219bb 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -248,7 +248,8 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException { performInteractions(interactions, listener); - verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0); + verify(listener).getTransferType(); verifyNoMoreInteractions(listener); } From d0487fd0bcd7f885a93605bc7cc5e218d83813c6 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Fri, 9 Dec 2022 14:22:17 -0800 Subject: [PATCH 06/28] cleaned up some stuff --- .../apache/spark/network/sasl/SaslClientBootstrap.java | 1 - .../org/apache/spark/network/sasl/SaslInitMessage.java | 2 +- .../java/org/apache/spark/network/sasl/SaslMessage.java | 4 ++-- .../org/apache/spark/network/sasl/SaslRpcHandler.java | 8 ++++---- .../spark/network/server/AbstractAuthRpcHandler.java | 4 +++- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index b5736111a821..9cc0dd4c561f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -75,7 
+75,6 @@ public void doBootstrap(TransportClient client, Channel channel) { buf.writeBytes(msg.body().nioByteBuffer()); ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs()); - System.out.println("response: " + Arrays.toString(JavaUtils.bufferToArray(response))); payload = saslClient.response(JavaUtils.bufferToArray(response)); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java index b300aa69754e..6fc387d345fd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java @@ -29,7 +29,7 @@ public final class SaslInitMessage extends SaslMessage { /** Serialization tag used to catch incorrect payloads. */ - static final byte TAG_BYTE = (byte) 0xEB; + public static final byte TAG_BYTE = (byte) 0xEB; SaslInitMessage(String appId, byte[] message) { super(appId, message); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index df3b58772b77..19bce89994a6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -30,10 +30,10 @@ * with the given appId. This appId allows a single SaslRpcHandler to multiplex different * applications which may be using different sets of credentials. */ -class SaslMessage extends AbstractMessage { +public class SaslMessage extends AbstractMessage { /** Serialization tag used to catch incorrect payloads. 
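Widening the tags to public serves the check in `AbstractAuthRpcHandler` (next hunks): frames can be tested against named constants instead of the raw 0xEA/0xEB literals hard-coded in the earlier commit. A sketch of the call-site shape:

```java
import org.apache.spark.network.sasl.SaslInitMessage;
import org.apache.spark.network.sasl.SaslMessage;

final class FrameCheckExample {
  // True when the frame is part of the SASL handshake (regular or init),
  // and should therefore go to doAuthChallenge rather than the delegate.
  static boolean isSaslFrame(byte tagByte) {
    return tagByte == SaslMessage.TAG_BYTE || tagByte == SaslInitMessage.TAG_BYTE;
  }
}
```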
*/ - static final byte TAG_BYTE = (byte) 0xEA; + public static final byte TAG_BYTE = (byte) 0xEA; public final String appId; diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index d32f8a5d4afe..f236caab3ac0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -83,7 +83,7 @@ public boolean doAuthChallenge( if (tagByte == SaslInitMessage.TAG_BYTE) { logger.debug("Received an init message for channel {}", client); if (saslServer != null) { - resetSaslServer(true); + complete(true); } client.unsetClientId(); } @@ -119,13 +119,13 @@ public boolean doAuthChallenge( if (saslServer.isComplete()) { if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { logger.debug("SASL authentication successful for channel {}", client); - resetSaslServer(true); + complete(true); return true; } logger.debug("Enabling encryption for channel {}", client); SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); - resetSaslServer(false); + complete(false); return true; } return false; @@ -142,7 +142,7 @@ public void channelInactive(TransportClient client) { } } - private void resetSaslServer(boolean dispose) { + private void complete(boolean dispose) { if (dispose) { try { saslServer.dispose(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java index 05333b86349c..eccffe2ea255 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -24,6 +24,8 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SaslInitMessage; +import org.apache.spark.network.sasl.SaslMessage; /** @@ -62,7 +64,7 @@ public final void receive( ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); byte tagByte = nettyBuf.readByte(); if (isAuthenticated) { - if (tagByte != (byte) 0xEA && tagByte != (byte) 0xEB) { + if (tagByte != SaslMessage.TAG_BYTE && tagByte != SaslInitMessage.TAG_BYTE) { message.position(position); message.limit(limit); delegate.receive(client, message, callback); From ae502626a80cb60cb8ef7e4c4defa12bb0e1c6a8 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Mon, 12 Dec 2022 15:51:46 -0800 Subject: [PATCH 07/28] added timing code --- .../shuffle/BlockFetchingListener.java | 3 ++ .../shuffle/BlockTransferListener.java | 2 ++ .../shuffle/RetryingBlockTransferor.java | 9 +++++ .../shuffle/RetryingBlockTransferorSuite.java | 1 + .../spark/shuffle/ShuffleBlockPusher.scala | 6 +++- .../storage/ShuffleBlockFetcherIterator.scala | 4 +++ .../apache/spark/util/JsonProtocolSuite.scala | 35 +++++++++++-------- 7 files changed, 45 insertions(+), 15 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java index 0be913e4d8d9..bdaeece685ae 100644 --- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java @@ -46,4 +46,7 @@ default void onBlockTransferFailure(String blockId, Throwable exception) { default String getTransferType() { return "fetch"; } + + @Override + default void onSaslTimeout() {} } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockTransferListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockTransferListener.java index e019dabcba41..418f9bf56bbd 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockTransferListener.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockTransferListener.java @@ -41,4 +41,6 @@ public interface BlockTransferListener extends EventListener { * Return a string indicating the type of the listener such as fetch, push, or something else */ String getTransferType(); + + void onSaslTimeout(); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 7e23d9aa8be9..06c0d63b1c66 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -200,6 +200,9 @@ private synchronized boolean shouldRetry(Throwable e) { boolean isSaslTimeout = enableSaslRetries && (e instanceof TransportClient.SaslTimeoutException || (e.getCause() != null && e.getCause() instanceof TransportClient.SaslTimeoutException)); + if (isSaslTimeout) { + listener.onSaslTimeout(); + } boolean hasRemainingRetries = retryCount < maxRetries; return (isSaslTimeout || isIOException) && hasRemainingRetries && errorHandler.shouldRetryError(e); @@ -298,5 +301,11 @@ public String getTransferType() { throw new RuntimeException( "Invocation on RetryingBlockTransferListener.getTransferType is unexpected."); } + + @Override + public void onSaslTimeout() { + throw new RuntimeException( + "Invocation on RetryingBlockTransferListener.onSaslTimeout is unexpected."); + } } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index f74802c219bb..bb096afff631 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -250,6 +250,7 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException { verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0); verify(listener).getTransferType(); + verify(listener).onSaslTimeout(); verifyNoMoreInteractions(listener); } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala index ac43ba8b56fc..58028d88b83c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala @@ -25,7 +25,7 @@ import java.util.concurrent.ExecutorService import 
scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import scala.util.control.NonFatal -import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv, TaskContext} import org.apache.spark.annotation.Since import org.apache.spark.executor.{CoarseGrainedExecutorBackend, ExecutorBackend} import org.apache.spark.internal.Logging @@ -251,6 +251,10 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { } handleResult(PushResult(blockId, exception)) } + + override def onSaslTimeout(): Unit = { + TaskContext.get().taskMetrics().incSaslRequestRetries(1) + } } // In addition to randomizing the order of the push requests, further randomize the order // of blocks within the push request to further reduce the likelihood of shuffle server side diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index e35144756b59..c3bc3e189723 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -351,6 +351,10 @@ final class ShuffleBlockFetcherIterator( } } } + + override def onSaslTimeout(): Unit = { + context.taskMetrics().incSaslRequestRetries(1) + } } // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 7f93051680ca..b092ce377f32 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -2567,104 +2567,111 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | { | "ID": 11, - | "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}", + | "Name": "$SASL_REQUEST_RETRIES", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 12, - | "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}", + | "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 13, - | "Name": "${shuffleRead.REMOTE_BYTES_READ}", + | "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 14, - | "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}", + | "Name": "${shuffleRead.REMOTE_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 15, - | "Name": "${shuffleRead.LOCAL_BYTES_READ}", + | "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 16, - | "Name": "${shuffleRead.FETCH_WAIT_TIME}", + | "Name": "${shuffleRead.LOCAL_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 17, - | "Name": "${shuffleRead.RECORDS_READ}", + | "Name": "${shuffleRead.FETCH_WAIT_TIME}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 18, - | "Name": "${shuffleWrite.BYTES_WRITTEN}", + | "Name": "${shuffleRead.RECORDS_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 19, - | "Name": "${shuffleWrite.RECORDS_WRITTEN}", + | "Name": "${shuffleWrite.BYTES_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed 
Values": true | }, | { | "ID": 20, - | "Name": "${shuffleWrite.WRITE_TIME}", + | "Name": "${shuffleWrite.RECORDS_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 21, + | "Name": "${shuffleWrite.WRITE_TIME}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 22, | "Name": "${input.BYTES_READ}", | "Update": 2100, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 22, + | "ID": 23, | "Name": "${input.RECORDS_READ}", | "Update": 21, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 23, + | "ID": 24, | "Name": "${output.BYTES_WRITTEN}", | "Update": 1200, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 24, + | "ID": 25, | "Name": "${output.RECORDS_WRITTEN}", | "Update": 12, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 25, + | "ID": 26, | "Name": "$TEST_ACCUM", | "Update": 0, | "Internal": true, From dc163f29aaf8c54d3d1f48777c90d53fe290ba67 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Mon, 12 Dec 2022 23:01:04 -0800 Subject: [PATCH 08/28] fix some more stuff --- .../org/apache/spark/network/sasl/SaslClientBootstrap.java | 1 - .../java/org/apache/spark/network/sasl/SaslRpcHandler.java | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 9cc0dd4c561f..b21f74b2f0ad 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.nio.ByteBuffer; -import java.util.Arrays; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index f236caab3ac0..6d273ed2620c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -98,13 +98,13 @@ public boolean doAuthChallenge( // First message in the handshake, setup the necessary state. 
client.setClientId(saslMessage.appId); saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, - conf.saslServerAlwaysEncrypt()); + conf.saslServerAlwaysEncrypt()); } byte[] response; try { response = saslServer.response(JavaUtils.bufferToArray( - saslMessage.body().nioByteBuffer())); + saslMessage.body().nioByteBuffer())); } catch (IOException ioe) { throw new RuntimeException(ioe); } From d6bf891ff8bd9f058c5374cc87129801dc2b0bee Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Wed, 4 Jan 2023 15:34:09 -0800 Subject: [PATCH 09/28] remove protocol/server side changes --- .../spark/network/client/TransportClient.java | 8 --- .../network/sasl/SaslClientBootstrap.java | 9 +-- .../spark/network/sasl/SaslInitMessage.java | 49 ------------- .../spark/network/sasl/SaslMessage.java | 11 +-- .../spark/network/sasl/SaslRpcHandler.java | 14 +--- .../server/AbstractAuthRpcHandler.java | 17 +---- .../spark/network/sasl/SparkSaslSuite.java | 72 ------------------- 7 files changed, 5 insertions(+), 175 deletions(-) delete mode 100644 common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 2035bd819bc3..c82433a79610 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -116,14 +116,6 @@ public void setClientId(String id) { this.clientId = id; } - /** - * This is needed when sasl server is reset. Sasl authentication - * will be re-attempted. - */ - public void unsetClientId() { - this.clientId = null; - } - /** * Requests a single chunk from the remote side, from the pre-negotiated streamId. * diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index b21f74b2f0ad..92b946c6b13a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -60,15 +60,8 @@ public void doBootstrap(TransportClient client, Channel channel) { SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption()); try { byte[] payload = saslClient.firstToken(); - boolean firstToken = true; while (!saslClient.isComplete()) { - SaslMessage msg; - if (conf.enableSaslRetries() && firstToken) { - msg = new SaslInitMessage(appId, payload); - } else { - msg = new SaslMessage(appId, payload); - } - firstToken = false; + SaslMessage msg = new SaslMessage(appId, payload); ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); buf.writeBytes(msg.body().nioByteBuffer()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java deleted file mode 100644 index 6fc387d345fd..000000000000 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslInitMessage.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encoders; - -/** - * Encodes the first message in Sasl exchange. When it is retried, the - * SaslServer needs to be reset. {@link SaslRpcHandler} uses the type of this message - * to reset the SaslServer. - */ -public final class SaslInitMessage extends SaslMessage { - - /** Serialization tag used to catch incorrect payloads. */ - public static final byte TAG_BYTE = (byte) 0xEB; - - SaslInitMessage(String appId, byte[] message) { - super(appId, message); - } - - SaslInitMessage(String appId, ByteBuf message) { - super(appId, message); - } - - @Override - public void encode(ByteBuf buf) { - buf.writeByte(TAG_BYTE); - Encoders.Strings.encode(buf, appId); - // See comment in encodedLength(). - buf.writeInt((int) body().size()); - } -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 19bce89994a6..1b03300d948e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -30,10 +30,10 @@ * with the given appId. This appId allows a single SaslRpcHandler to multiplex different * applications which may be using different sets of credentials. */ -public class SaslMessage extends AbstractMessage { +class SaslMessage extends AbstractMessage { /** Serialization tag used to catch incorrect payloads. */ - public static final byte TAG_BYTE = (byte) 0xEA; + private static final byte TAG_BYTE = (byte) 0xEA; public final String appId; @@ -76,11 +76,4 @@ public static SaslMessage decode(ByteBuf buf) { buf.readInt(); return new SaslMessage(appId, buf.retain()); } - - public static SaslMessage decodeWithoutTag(ByteBuf buf) { - String appId = Encoders.Strings.decode(buf); - // See comment in encodedLength(). 
- buf.readInt(); - return new SaslMessage(appId, buf.retain()); - } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 6d273ed2620c..cc9e88fcf98e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -75,21 +75,9 @@ public boolean doAuthChallenge( RpcResponseCallback callback) { if (saslServer == null || !saslServer.isComplete()) { ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); - byte tagByte = nettyBuf.readByte(); - if (tagByte != SaslMessage.TAG_BYTE && tagByte != SaslInitMessage.TAG_BYTE) { - throw new IllegalStateException("Expected SaslMessage, received something else" - + " (maybe your client does not have SASL enabled?)"); - } - if (tagByte == SaslInitMessage.TAG_BYTE) { - logger.debug("Received an init message for channel {}", client); - if (saslServer != null) { - complete(true); - } - client.unsetClientId(); - } SaslMessage saslMessage; try { - saslMessage = SaslMessage.decodeWithoutTag(nettyBuf); + saslMessage = SaslMessage.decode(nettyBuf); } finally { nettyBuf.release(); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java index eccffe2ea255..9414db4c550b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -17,15 +17,11 @@ package org.apache.spark.network.server; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; import java.nio.ByteBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.sasl.SaslInitMessage; -import org.apache.spark.network.sasl.SaslMessage; /** @@ -58,19 +54,8 @@ public final void receive( TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - int position = message.position(); - int limit = message.limit(); - - ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); - byte tagByte = nettyBuf.readByte(); if (isAuthenticated) { - if (tagByte != SaslMessage.TAG_BYTE && tagByte != SaslInitMessage.TAG_BYTE) { - message.position(position); - message.limit(limit); - delegate.receive(client, message, callback); - } else { - isAuthenticated = doAuthChallenge(client, message, callback); - } + delegate.receive(client, message, callback); } else { isAuthenticated = doAuthChallenge(client, message, callback); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index bc790af06cb8..6096cd32f3d0 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -345,78 +345,6 @@ public void testDelegates() throws Exception { } } - @Test - public void testSaslWithRetriesEnabled() throws Exception { - testSaslResetHandling(1); - } - - @Test - public void testMultipleSaslRetries() throws Exception { - testSaslResetHandling(2); - 
} - - private void testSaslResetHandling(final int maxRetries) { - Map testConf = ImmutableMap.of("spark.authenticate.enableSaslEncryption", - String.valueOf(false), - "spark.shuffle.sasl.enableRetries", String.valueOf(true)); - TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); - - SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); - when(keyHolder.getSaslUser(anyString())).thenReturn("user"); - when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); - - Channel channel = mock(Channel.class); - RpcHandler delegate = mock(RpcHandler.class); - doAnswer(invocation -> { - ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; - RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; - assertEquals("Ping", JavaUtils.bytesToString(message)); - cb.onSuccess(JavaUtils.stringToBytes("Pong")); - return null; - }).when(delegate) - .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); - - RpcHandler saslHandler = new SaslRpcHandler(conf, channel, delegate, keyHolder); - - TransportClient client = mock(TransportClient.class); - final ByteBuffer[] handlerResponse = new ByteBuffer[1]; - - when(client.sendRpcSync(any(), anyLong())).thenAnswer(invocation -> { - - ByteBuffer msg = (ByteBuffer)(invocation.getArguments()[0]); - saslHandler.receive(client, msg, new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - handlerResponse[0] = response; - } - - @Override - public void onFailure(Throwable e) { - } - }); - return handlerResponse[0]; - }); - - SaslClientBootstrap bootstrapClient = new SaslClientBootstrap(conf, "user", keyHolder); - for (int i = 0; i < maxRetries; i++) { - bootstrapClient.doBootstrap(client, channel); - } - - // Subsequent messages to handler should be forwarded to delegate - saslHandler.receive(client, JavaUtils.stringToBytes("Ping"), new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - handlerResponse[0] = response; - } - - @Override - public void onFailure(Throwable e) { - } - }); - - assertEquals("Pong", JavaUtils.bytesToString(handlerResponse[0])); - } - private static class SaslTestCtx implements AutoCloseable { final TransportClient client; From d5ed4a375dcc5da0aef3161fbdb2543cc4d7157a Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Wed, 4 Jan 2023 15:37:27 -0800 Subject: [PATCH 10/28] fix some minor issues --- .../org/apache/spark/network/sasl/SaslClientBootstrap.java | 1 + .../apache/spark/network/server/AbstractAuthRpcHandler.java | 1 - .../java/org/apache/spark/network/util/TransportConf.java | 4 +--- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 92b946c6b13a..647813772294 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -60,6 +60,7 @@ public void doBootstrap(TransportClient client, Channel channel) { SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption()); try { byte[] payload = saslClient.firstToken(); + while (!saslClient.isComplete()) { SaslMessage msg = new SaslMessage(appId, payload); ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java index 9414db4c550b..95fde677624f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -23,7 +23,6 @@ import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; - /** * RPC Handler which performs authentication, and when it's successful, delegates further * calls to another RPC handler. The authentication handshake itself should be implemented diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 57721df61145..0ebaab97351e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -334,9 +334,7 @@ public boolean useOldFetchProtocol() { } /** Whether to enable sasl retries. Sasl retries will be enabled, once the shuffle - * server is upgraded. The updated SaslHandler can handle older clients that don't - * send any SaslInitMessage. However, the older SaslHandler will not be able to handle - * SaslInitMessage. + * server is upgraded. */ public boolean enableSaslRetries() { return conf.getBoolean("spark.shuffle.sasl.enableRetries", false); From a41bfc9b95491dd331bc6037d3bc890be89dad4e Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Wed, 11 Jan 2023 13:46:31 -0800 Subject: [PATCH 11/28] revert SASL metrics changes --- .../shuffle/BlockFetchingListener.java | 3 -- .../shuffle/BlockTransferListener.java | 2 -- .../shuffle/RetryingBlockTransferor.java | 9 ----- .../shuffle/RetryingBlockTransferorSuite.java | 1 - .../apache/spark/InternalAccumulator.scala | 1 - .../apache/spark/executor/TaskMetrics.scala | 8 ----- .../spark/shuffle/ShuffleBlockPusher.scala | 6 +--- .../storage/ShuffleBlockFetcherIterator.scala | 4 --- .../apache/spark/util/JsonProtocolSuite.scala | 35 ++++++++----------- 9 files changed, 15 insertions(+), 54 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java index bdaeece685ae..0be913e4d8d9 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java @@ -46,7 +46,4 @@ default void onBlockTransferFailure(String blockId, Throwable exception) { default String getTransferType() { return "fetch"; } - - @Override - default void onSaslTimeout() {} } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockTransferListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockTransferListener.java index 418f9bf56bbd..e019dabcba41 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockTransferListener.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockTransferListener.java @@ -41,6 +41,4 @@ public interface BlockTransferListener extends EventListener { * Return a string 
indicating the type of the listener such as fetch, push, or something else */ String getTransferType(); - - void onSaslTimeout(); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 06c0d63b1c66..7e23d9aa8be9 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -200,9 +200,6 @@ private synchronized boolean shouldRetry(Throwable e) { boolean isSaslTimeout = enableSaslRetries && (e instanceof TransportClient.SaslTimeoutException || (e.getCause() != null && e.getCause() instanceof TransportClient.SaslTimeoutException)); - if (isSaslTimeout) { - listener.onSaslTimeout(); - } boolean hasRemainingRetries = retryCount < maxRetries; return (isSaslTimeout || isIOException) && hasRemainingRetries && errorHandler.shouldRetryError(e); @@ -301,11 +298,5 @@ public String getTransferType() { throw new RuntimeException( "Invocation on RetryingBlockTransferListener.getTransferType is unexpected."); } - - @Override - public void onSaslTimeout() { - throw new RuntimeException( - "Invocation on RetryingBlockTransferListener.onSaslTimeout is unexpected."); - } } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index bb096afff631..f74802c219bb 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -250,7 +250,6 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException { verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0); verify(listener).getTransferType(); - verify(listener).onSaslTimeout(); verifyNoMoreInteractions(listener); } diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 6047fbfeb67a..18b10d23da94 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -41,7 +41,6 @@ private[spark] object InternalAccumulator { val DISK_BYTES_SPILLED = METRICS_PREFIX + "diskBytesSpilled" val PEAK_EXECUTION_MEMORY = METRICS_PREFIX + "peakExecutionMemory" val UPDATED_BLOCK_STATUSES = METRICS_PREFIX + "updatedBlockStatuses" - val SASL_REQUEST_RETRIES = METRICS_PREFIX + "saslRequestRetries" val TEST_ACCUM = METRICS_PREFIX + "testAccumulator" // scalastyle:off diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 04073847ea4a..b3e9e715590f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -56,7 +56,6 @@ class TaskMetrics private[spark] () extends Serializable { private val _diskBytesSpilled = new LongAccumulator private val _peakExecutionMemory = new LongAccumulator private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)] - private val _saslRequestRetries = new LongAccumulator /** * Time taken on the 
executor to deserialize this task. @@ -112,11 +111,6 @@ class TaskMetrics private[spark] () extends Serializable { */ def peakExecutionMemory: Long = _peakExecutionMemory.sum - /** - * The number of SASL requests retried by this task. - */ - def saslRequestRetries: Long = _saslRequestRetries.sum - /** * Storage statuses of any blocks that have been updated as a result of this task. * @@ -154,7 +148,6 @@ class TaskMetrics private[spark] () extends Serializable { _updatedBlockStatuses.setValue(v) private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = _updatedBlockStatuses.setValue(v.asJava) - private[spark] def incSaslRequestRetries(v: Long): Unit = _saslRequestRetries.add(v) /** * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted @@ -228,7 +221,6 @@ class TaskMetrics private[spark] () extends Serializable { DISK_BYTES_SPILLED -> _diskBytesSpilled, PEAK_EXECUTION_MEMORY -> _peakExecutionMemory, UPDATED_BLOCK_STATUSES -> _updatedBlockStatuses, - SASL_REQUEST_RETRIES -> _saslRequestRetries, shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched, shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched, shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead, diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala index 58028d88b83c..ac43ba8b56fc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala @@ -25,7 +25,7 @@ import java.util.concurrent.ExecutorService import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import scala.util.control.NonFatal -import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv} import org.apache.spark.annotation.Since import org.apache.spark.executor.{CoarseGrainedExecutorBackend, ExecutorBackend} import org.apache.spark.internal.Logging @@ -251,10 +251,6 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { } handleResult(PushResult(blockId, exception)) } - - override def onSaslTimeout(): Unit = { - TaskContext.get().taskMetrics().incSaslRequestRetries(1) - } } // In addition to randomizing the order of the push requests, further randomize the order // of blocks within the push request to further reduce the likelihood of shuffle server side diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c3bc3e189723..e35144756b59 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -351,10 +351,6 @@ final class ShuffleBlockFetcherIterator( } } } - - override def onSaslTimeout(): Unit = { - context.taskMetrics().incSaslRequestRetries(1) - } } // Fetch remote shuffle blocks to disk when the request is too large. 
Since the shuffle data is diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index b092ce377f32..7f93051680ca 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -2567,111 +2567,104 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | { | "ID": 11, - | "Name": "$SASL_REQUEST_RETRIES", - | "Update": 0, - | "Internal": true, - | "Count Failed Values": true - | }, - | { - | "ID": 12, | "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 13, + | "ID": 12, | "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 14, + | "ID": 13, | "Name": "${shuffleRead.REMOTE_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 15, + | "ID": 14, | "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 16, + | "ID": 15, | "Name": "${shuffleRead.LOCAL_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 17, + | "ID": 16, | "Name": "${shuffleRead.FETCH_WAIT_TIME}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 18, + | "ID": 17, | "Name": "${shuffleRead.RECORDS_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 19, + | "ID": 18, | "Name": "${shuffleWrite.BYTES_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 20, + | "ID": 19, | "Name": "${shuffleWrite.RECORDS_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 21, + | "ID": 20, | "Name": "${shuffleWrite.WRITE_TIME}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 22, + | "ID": 21, | "Name": "${input.BYTES_READ}", | "Update": 2100, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 23, + | "ID": 22, | "Name": "${input.RECORDS_READ}", | "Update": 21, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 24, + | "ID": 23, | "Name": "${output.BYTES_WRITTEN}", | "Update": 1200, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 25, + | "ID": 24, | "Name": "${output.RECORDS_WRITTEN}", | "Update": 12, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 26, + | "ID": 25, | "Name": "$TEST_ACCUM", | "Update": 0, | "Internal": true, From 035c6651785f807ff07afe0b18401ae15accbe86 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Wed, 11 Jan 2023 13:57:40 -0800 Subject: [PATCH 12/28] fix minor issues --- .../java/org/apache/spark/network/client/TransportClient.java | 2 +- .../java/org/apache/spark/network/util/TransportConf.java | 4 ++-- .../main/scala/org/apache/spark/executor/TaskMetrics.scala | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index c82433a79610..5780d8668206 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -289,7 
+289,7 @@ public void onFailure(Throwable e) { try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); } catch (TimeoutException e) { - logger.warn("RPC {} timed-out", rpcId); + logger.trace("RPC {} timed-out", rpcId); throw Throwables.propagate(new SaslTimeoutException(e)); } catch (ExecutionException e) { throw Throwables.propagate(e.getCause()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 0ebaab97351e..2dc36efb3655 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -333,8 +333,8 @@ public boolean useOldFetchProtocol() { return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false); } - /** Whether to enable sasl retries. Sasl retries will be enabled, once the shuffle - * server is upgraded. + /** Whether to enable sasl retries or not. The number of retries is given by the config + * `spark.shuffle.io.maxRetries`. */ public boolean enableSaslRetries() { return conf.getBoolean("spark.shuffle.sasl.enableRetries", false); diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index b3e9e715590f..43742a4d46cb 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -126,7 +126,6 @@ class TaskMetrics private[spark] () extends Serializable { _updatedBlockStatuses.value.asScala.toSeq } - // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = _executorDeserializeTime.setValue(v) From 129375d7ffcc5ee1c5437d6a8e5c3dd2094fa38a Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Thu, 12 Jan 2023 00:20:09 -0800 Subject: [PATCH 13/28] addressed all comments --- .../spark/network/client/TransportClient.java | 12 ------ .../network/sasl/SaslClientBootstrap.java | 14 ++++--- .../shuffle/RetryingBlockTransferor.java | 5 ++- .../shuffle/RetryingBlockTransferorSuite.java | 42 +++++++++++++++---- 4 files changed, 47 insertions(+), 26 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 5780d8668206..15793438d599 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -288,9 +288,6 @@ public void onFailure(Throwable e) { try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); - } catch (TimeoutException e) { - logger.trace("RPC {} timed-out", rpcId); - throw Throwables.propagate(new SaslTimeoutException(e)); } catch (ExecutionException e) { throw Throwables.propagate(e.getCause()); } catch (Exception e) { @@ -342,15 +339,6 @@ public String toString() { .toString(); } - /** - * Exception thrown when sasl request times out. 
- */ - public static class SaslTimeoutException extends RuntimeException { - public SaslTimeoutException(Throwable cause) { - super((cause)); - } - } - private static long requestId() { return Math.abs(UUID.randomUUID().getLeastSignificantBits()); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 647813772294..78ef571e3743 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.concurrent.TimeoutException; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; @@ -66,9 +67,12 @@ public void doBootstrap(TransportClient client, Channel channel) { ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); buf.writeBytes(msg.body().nioByteBuffer()); - - ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs()); - payload = saslClient.response(JavaUtils.bufferToArray(response)); + try { + ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs()); + payload = saslClient.response(JavaUtils.bufferToArray(response)); + } catch (RuntimeException e) { + throw e.getCause(); + } } client.setClientId(appId); @@ -83,8 +87,8 @@ public void doBootstrap(TransportClient client, Channel channel) { saslClient = null; logger.debug("Channel {} configured for encryption.", client); } - } catch (IOException ioe) { - throw new RuntimeException(ioe); + } catch (Throwable e) { + throw new RuntimeException(e); } finally { if (saslClient != null) { try { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 7e23d9aa8be9..9a63180db1c4 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -26,6 +26,7 @@ import com.google.common.collect.Sets; import com.google.common.util.concurrent.Uninterruptibles; +import java.util.concurrent.TimeoutException; import org.apache.spark.network.client.TransportClient; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -198,8 +199,8 @@ private synchronized boolean shouldRetry(Throwable e) { boolean isIOException = e instanceof IOException || e.getCause() instanceof IOException; boolean isSaslTimeout = enableSaslRetries && - (e instanceof TransportClient.SaslTimeoutException || - (e.getCause() != null && e.getCause() instanceof TransportClient.SaslTimeoutException)); + (e instanceof TimeoutException || + (e.getCause() != null && e.getCause() instanceof TimeoutException)); boolean hasRemainingRetries = retryCount < maxRetries; return (isSaslTimeout || isIOException) && hasRemainingRetries && errorHandler.shouldRetryError(e); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index f74802c219bb..ea0ed66deef7 100644 --- 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -28,7 +29,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import java.util.concurrent.TimeoutException; -import org.apache.spark.network.client.TransportClient; +import org.junit.Before; import org.junit.Test; import org.mockito.stubbing.Answer; import org.mockito.stubbing.Stubber; @@ -51,6 +52,15 @@ public class RetryingBlockTransferorSuite { private final ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); private final ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); private final ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); + private static Map configMap; + + @Before + public void initMap() { + configMap = new HashMap() {{ + put("spark.shuffle.io.maxRetries", "2"); + put("spark.shuffle.io.retryWait", "0"); + }}; + } @Test public void testNoFailures() throws IOException, InterruptedException { @@ -232,6 +242,26 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException verifyNoMoreInteractions(listener); } + @Test + public void testSaslTimeoutFailure() throws IOException, InterruptedException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + TimeoutException timeoutException = new TimeoutException(); + List> interactions = Arrays.asList( + ImmutableMap.builder() + .put("b0", timeoutException) + .build(), + ImmutableMap.builder() + .put("b0", block0) + .build() + ); + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockTransferFailure("b0", timeoutException); + verify(listener).getTransferType(); + verifyNoMoreInteractions(listener); + } + @Test public void testRetryOnSaslTimeout() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); @@ -239,13 +269,13 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException { List> interactions = Arrays.asList( // SaslTimeout will cause a retry. Since b0 fails, we will retry both. ImmutableMap.builder() - .put("b0", new TransportClient.SaslTimeoutException(new TimeoutException())) + .put("b0", new TimeoutException()) .build(), ImmutableMap.builder() .put("b0", block0) .build() ); - + configMap.put("spark.shuffle.sasl.enableRetries", "true"); performInteractions(interactions, listener); verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0); @@ -253,6 +283,7 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException { verifyNoMoreInteractions(listener); } + /** * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. 
This interaction @@ -267,10 +298,7 @@ private static void performInteractions(List> inte BlockFetchingListener listener) throws IOException, InterruptedException { - MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of( - "spark.shuffle.io.maxRetries", "2", - "spark.shuffle.io.retryWait", "0", - "spark.shuffle.sasl.enableRetries", "true")); + MapConfigProvider provider = new MapConfigProvider(configMap); TransportConf conf = new TransportConf("shuffle", provider); BlockTransferStarter fetchStarter = mock(BlockTransferStarter.class); From 6b2a3c6d1a8340f2b4d9a5a43ab79bb1b0beceff Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Thu, 12 Jan 2023 00:24:11 -0800 Subject: [PATCH 14/28] remove some unnecessary code --- .../java/org/apache/spark/network/client/TransportClient.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 15793438d599..dd2fdb08ee5b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -24,7 +24,6 @@ import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; @@ -265,7 +264,7 @@ public long uploadStream( public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { final SettableFuture result = SettableFuture.create(); - long rpcId = sendRpc(message, new RpcResponseCallback() { + sendRpc(message, new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { try { From ce71f6912c5806f7b2b532553bbaf75d3b1d14bd Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Thu, 12 Jan 2023 00:25:31 -0800 Subject: [PATCH 15/28] remove some unnecessary code --- .../java/org/apache/spark/network/sasl/SaslClientBootstrap.java | 2 -- .../main/java/org/apache/spark/network/util/TransportConf.java | 2 +- .../apache/spark/network/shuffle/RetryingBlockTransferor.java | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 78ef571e3743..728745829eb8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,9 +17,7 @@ package org.apache.spark.network.sasl; -import java.io.IOException; import java.nio.ByteBuffer; -import java.util.concurrent.TimeoutException; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 2dc36efb3655..bbfb99168da2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -333,7 +333,7 @@ public boolean useOldFetchProtocol() { return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false); } - /** Whether to enable 
sasl retries or not. The number of retries is given by the config + /** Whether to enable sasl retries or not. The number of retries is dictated by the config * `spark.shuffle.io.maxRetries`. */ public boolean enableSaslRetries() { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 9a63180db1c4..0afb5af2629f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -27,7 +27,6 @@ import com.google.common.collect.Sets; import com.google.common.util.concurrent.Uninterruptibles; import java.util.concurrent.TimeoutException; -import org.apache.spark.network.client.TransportClient; import org.slf4j.Logger; import org.slf4j.LoggerFactory; From 273b7ac8a04e321f68af03229f61a8d2418bd4d4 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Thu, 12 Jan 2023 11:40:25 -0800 Subject: [PATCH 16/28] fix SASL exception throwing --- .../spark/network/client/TransportClient.java | 9 ++++++++ .../network/sasl/SaslClientBootstrap.java | 21 +++++++++++++------ .../shuffle/RetryingBlockTransferor.java | 6 +++--- .../shuffle/RetryingBlockTransferorSuite.java | 8 ++++--- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index dd2fdb08ee5b..9462cbac82b3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -294,6 +294,15 @@ public void onFailure(Throwable e) { } } + /** + * Exception thrown when sasl request times out. + */ + public static class SaslTimeoutException extends RuntimeException { + public SaslTimeoutException(Throwable cause) { + super((cause)); + } + } + /** * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the * message, and no delivery guarantees are made. 
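(For orientation: the SaslClientBootstrap hunk below distinguishes a timed-out SASL round trip from other RPC failures by inspecting the cause of the RuntimeException thrown by sendRpcSync. A condensed sketch of that pattern follows; the helper name sendSaslToken is hypothetical, while the Spark types and calls are the ones shown in this patch.)

import java.nio.ByteBuffer;
import java.util.concurrent.TimeoutException;
import org.apache.spark.network.client.TransportClient;

class SaslTimeoutClassificationSketch {
  // Hypothetical helper mirroring the classification logic the patch adds.
  static ByteBuffer sendSaslToken(TransportClient client, ByteBuffer token, long timeoutMs) {
    try {
      return client.sendRpcSync(token, timeoutMs);
    } catch (RuntimeException ex) {
      if (ex.getCause() instanceof TimeoutException) {
        // A TimeoutException cause means the SASL request itself timed out,
        // so surface it as the retryable SaslTimeoutException type.
        throw new TransportClient.SaslTimeoutException(ex.getCause());
      }
      throw ex; // non-timeout failures keep their original classification
    }
  }
}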
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 728745829eb8..334a849e83b9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,7 +17,10 @@ package org.apache.spark.network.sasl; +import com.google.common.base.Throwables; +import java.io.IOException; import java.nio.ByteBuffer; +import java.util.concurrent.TimeoutException; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; @@ -64,13 +67,19 @@ public void doBootstrap(TransportClient client, Channel channel) { SaslMessage msg = new SaslMessage(appId, payload); ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); + ByteBuffer response; buf.writeBytes(msg.body().nioByteBuffer()); try { - ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs()); - payload = saslClient.response(JavaUtils.bufferToArray(response)); - } catch (RuntimeException e) { - throw e.getCause(); + response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs()); + } catch (RuntimeException ex) { + // We know it is a Sasl timeout here if it is a TimeoutException. + if (ex.getCause() instanceof TimeoutException) { + throw Throwables.propagate(new TransportClient.SaslTimeoutException(ex.getCause())); + } else { + throw ex; + } } + payload = saslClient.response(JavaUtils.bufferToArray(response)); } client.setClientId(appId); @@ -85,8 +94,8 @@ public void doBootstrap(TransportClient client, Channel channel) { saslClient = null; logger.debug("Channel {} configured for encryption.", client); } - } catch (Throwable e) { - throw new RuntimeException(e); + } catch (IOException ioe) { + throw new RuntimeException(ioe); } finally { if (saslClient != null) { try { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 0afb5af2629f..7e23d9aa8be9 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -26,7 +26,7 @@ import com.google.common.collect.Sets; import com.google.common.util.concurrent.Uninterruptibles; -import java.util.concurrent.TimeoutException; +import org.apache.spark.network.client.TransportClient; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -198,8 +198,8 @@ private synchronized boolean shouldRetry(Throwable e) { boolean isIOException = e instanceof IOException || e.getCause() instanceof IOException; boolean isSaslTimeout = enableSaslRetries && - (e instanceof TimeoutException || - (e.getCause() != null && e.getCause() instanceof TimeoutException)); + (e instanceof TransportClient.SaslTimeoutException || + (e.getCause() != null && e.getCause() instanceof TransportClient.SaslTimeoutException)); boolean hasRemainingRetries = retryCount < maxRetries; return (isSaslTimeout || isIOException) && hasRemainingRetries && errorHandler.shouldRetryError(e); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index ea0ed66deef7..acc29e65639f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -29,6 +29,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import java.util.concurrent.TimeoutException; +import org.apache.spark.network.client.TransportClient; import org.junit.Before; import org.junit.Test; import org.mockito.stubbing.Answer; @@ -246,9 +247,10 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException public void testSaslTimeoutFailure() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); TimeoutException timeoutException = new TimeoutException(); + TransportClient.SaslTimeoutException saslTimeoutException = new TransportClient.SaslTimeoutException(timeoutException); List> interactions = Arrays.asList( ImmutableMap.builder() - .put("b0", timeoutException) + .put("b0", saslTimeoutException) .build(), ImmutableMap.builder() .put("b0", block0) @@ -257,7 +259,7 @@ public void testSaslTimeoutFailure() throws IOException, InterruptedException { performInteractions(interactions, listener); - verify(listener, timeout(5000)).onBlockTransferFailure("b0", timeoutException); + verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException); verify(listener).getTransferType(); verifyNoMoreInteractions(listener); } @@ -269,7 +271,7 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException { List> interactions = Arrays.asList( // SaslTimeout will cause a retry. Since b0 fails, we will retry both. 
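// (Two interaction maps: the first attempt fails with the wrapped timeout; the retry then succeeds and returns block0.)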
ImmutableMap.builder() - .put("b0", new TimeoutException()) + .put("b0", new TransportClient.SaslTimeoutException(new TimeoutException())) .build(), ImmutableMap.builder() .put("b0", block0) From b91c2d44dfe2be2fc06c3e06c4d53faed4569f81 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Thu, 12 Jan 2023 17:02:42 -0800 Subject: [PATCH 17/28] fix linter error --- .../spark/network/shuffle/RetryingBlockTransferorSuite.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index acc29e65639f..815d06ccad18 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -247,7 +247,8 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException public void testSaslTimeoutFailure() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); TimeoutException timeoutException = new TimeoutException(); - TransportClient.SaslTimeoutException saslTimeoutException = new TransportClient.SaslTimeoutException(timeoutException); + TransportClient.SaslTimeoutException saslTimeoutException = + new TransportClient.SaslTimeoutException(timeoutException); List> interactions = Arrays.asList( ImmutableMap.builder() .put("b0", saslTimeoutException) From 3440f1259e7a04ac0f38ece8b1ce69a9253d43b7 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Fri, 13 Jan 2023 00:32:58 -0800 Subject: [PATCH 18/28] address latest comments --- .../spark/network/client/TransportClient.java | 9 --------- .../network/sasl/SaslClientBootstrap.java | 3 +-- .../network/sasl/SaslTimeoutException.java | 15 ++++++++++++++ .../shuffle/RetryingBlockTransferor.java | 20 +++++++++++++++---- .../shuffle/RetryingBlockTransferorSuite.java | 8 ++++---- 5 files changed, 36 insertions(+), 19 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 9462cbac82b3..dd2fdb08ee5b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -294,15 +294,6 @@ public void onFailure(Throwable e) { } } - /** - * Exception thrown when sasl request times out. - */ - public static class SaslTimeoutException extends RuntimeException { - public SaslTimeoutException(Throwable cause) { - super((cause)); - } - } - /** * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the * message, and no delivery guarantees are made. 
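(Because these retries are opt-in, a minimal configuration sketch may help; the class name SaslRetryConfSketch and the specific values are illustrative, while MapConfigProvider and TransportConf are the same helpers the suite above already uses.)

import com.google.common.collect.ImmutableMap;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class SaslRetryConfSketch {
  public static void main(String[] args) {
    // SASL retries are gated by spark.shuffle.sasl.enableRetries and share
    // the retry budget defined by spark.shuffle.io.maxRetries.
    TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(
        ImmutableMap.of(
            "spark.shuffle.sasl.enableRetries", "true",
            "spark.shuffle.io.maxRetries", "2",
            "spark.shuffle.io.retryWait", "0")));
    System.out.println(conf.enableSaslRetries()); // prints: true
  }
}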
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 334a849e83b9..69baaca8a261 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,7 +17,6 @@ package org.apache.spark.network.sasl; -import com.google.common.base.Throwables; import java.io.IOException; import java.nio.ByteBuffer; import java.util.concurrent.TimeoutException; @@ -74,7 +73,7 @@ public void doBootstrap(TransportClient client, Channel channel) { } catch (RuntimeException ex) { // We know it is a Sasl timeout here if it is a TimeoutException. if (ex.getCause() instanceof TimeoutException) { - throw Throwables.propagate(new TransportClient.SaslTimeoutException(ex.getCause())); + throw new SaslTimeoutException(ex.getCause()); } else { throw ex; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java new file mode 100644 index 000000000000..ecdd764d41af --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java @@ -0,0 +1,15 @@ +package org.apache.spark.network.sasl; + +public class SaslTimeoutException extends RuntimeException { + public SaslTimeoutException(Throwable cause) { + super(cause); + } + + public SaslTimeoutException(String message) { + super(message); + } + + public SaslTimeoutException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 7e23d9aa8be9..d41b7ef1da6b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -19,14 +19,16 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.LinkedHashSet; +import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import com.google.common.collect.Sets; import com.google.common.util.concurrent.Uninterruptibles; -import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SaslTimeoutException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -85,6 +87,12 @@ void createAndStart(String[] blockIds, BlockTransferListener listener) // while inside a synchronized block. /** Number of times we've attempted to retry so far. */ private int retryCount = 0; + /** + * Map to track blockId to exception that the block is being retried for. + * This is mainly used in the case of SASL retries, because we need to set + * `retryCount` back to 0 in those cases. + */ + private Map blockIdToException; /** * Set of all block ids which have not been transferred successfully or with a non-IO Exception. 
@@ -120,6 +128,7 @@ public RetryingBlockTransferor( this.currentListener = new RetryingBlockTransferListener(); this.errorHandler = errorHandler; this.enableSaslRetries = conf.enableSaslRetries(); + this.blockIdToException = new HashMap(); } public RetryingBlockTransferor( @@ -197,9 +206,7 @@ private synchronized void initiateRetry() { private synchronized boolean shouldRetry(Throwable e) { boolean isIOException = e instanceof IOException || e.getCause() instanceof IOException; - boolean isSaslTimeout = enableSaslRetries && - (e instanceof TransportClient.SaslTimeoutException || - (e.getCause() != null && e.getCause() instanceof TransportClient.SaslTimeoutException)); + boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException; boolean hasRemainingRetries = retryCount < maxRetries; return (isSaslTimeout || isIOException) && hasRemainingRetries && errorHandler.shouldRetryError(e); @@ -220,6 +227,10 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) { if (this == currentListener && outstandingBlocksIds.contains(blockId)) { outstandingBlocksIds.remove(blockId); shouldForwardSuccess = true; + if (blockIdToException.containsKey(blockId) && + blockIdToException.get(blockId) instanceof SaslTimeoutException) { + retryCount = 0; + } } } @@ -236,6 +247,7 @@ private void handleBlockTransferFailure(String blockId, Throwable exception) { synchronized (RetryingBlockTransferor.this) { if (this == currentListener && outstandingBlocksIds.contains(blockId)) { if (shouldRetry(exception)) { + blockIdToException.putIfAbsent(blockId, exception); initiateRetry(); } else { if (errorHandler.shouldLogError(exception)) { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index 815d06ccad18..0fea1aef1b8a 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -29,7 +29,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import java.util.concurrent.TimeoutException; -import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SaslTimeoutException; import org.junit.Before; import org.junit.Test; import org.mockito.stubbing.Answer; @@ -247,8 +247,8 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException public void testSaslTimeoutFailure() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); TimeoutException timeoutException = new TimeoutException(); - TransportClient.SaslTimeoutException saslTimeoutException = - new TransportClient.SaslTimeoutException(timeoutException); + SaslTimeoutException saslTimeoutException = + new SaslTimeoutException(timeoutException); List> interactions = Arrays.asList( ImmutableMap.builder() .put("b0", saslTimeoutException) @@ -272,7 +272,7 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException { List> interactions = Arrays.asList( // SaslTimeout will cause a retry. Since b0 fails, we will retry both. 
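// (A success after a SASL-timeout retry also resets retryCount to 0; in this revision that is tracked through the blockIdToException map.)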
ImmutableMap.builder() - .put("b0", new TransportClient.SaslTimeoutException(new TimeoutException())) + .put("b0", new SaslTimeoutException(new TimeoutException())) .build(), ImmutableMap.builder() .put("b0", block0) From ad079d605bc3596033db0a5fa08f05b3939ca447 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Fri, 13 Jan 2023 10:46:45 -0800 Subject: [PATCH 19/28] add flag for current sasl timeout and added javadoc --- .../network/sasl/SaslTimeoutException.java | 20 +++++++++++++++++ .../shuffle/RetryingBlockTransferor.java | 22 +++++++++---------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java index ecdd764d41af..2533ae93f8de 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java @@ -1,5 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.network.sasl; +/** + * An exception thrown if there is a SASL timeout. + */ public class SaslTimeoutException extends RuntimeException { public SaslTimeoutException(Throwable cause) { super(cause); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index d41b7ef1da6b..9f4c74bae7ee 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -21,7 +21,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashSet; -import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -87,12 +86,8 @@ void createAndStart(String[] blockIds, BlockTransferListener listener) // while inside a synchronized block. /** Number of times we've attempted to retry so far. */ private int retryCount = 0; - /** - * Map to track blockId to exception that the block is being retried for. - * This is mainly used in the case of SASL retries, because we need to set - * `retryCount` back to 0 in those cases. - */ - private Map blockIdToException; + + private boolean isCurrentSaslTimeout; /** * Set of all block ids which have not been transferred successfully or with a non-IO Exception. 
@@ -128,7 +123,7 @@ public RetryingBlockTransferor( this.currentListener = new RetryingBlockTransferListener(); this.errorHandler = errorHandler; this.enableSaslRetries = conf.enableSaslRetries(); - this.blockIdToException = new HashMap(); + this.isCurrentSaslTimeout = false; } public RetryingBlockTransferor( @@ -208,8 +203,12 @@ private synchronized boolean shouldRetry(Throwable e) { || e.getCause() instanceof IOException; boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException; boolean hasRemainingRetries = retryCount < maxRetries; - return (isSaslTimeout || isIOException) && + boolean shouldRetry = (isSaslTimeout || isIOException) && hasRemainingRetries && errorHandler.shouldRetryError(e); + if (shouldRetry && isSaslTimeout) { + this.isCurrentSaslTimeout = true; + } + return shouldRetry; } /** @@ -227,9 +226,9 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) { if (this == currentListener && outstandingBlocksIds.contains(blockId)) { outstandingBlocksIds.remove(blockId); shouldForwardSuccess = true; - if (blockIdToException.containsKey(blockId) && - blockIdToException.get(blockId) instanceof SaslTimeoutException) { + if (isCurrentSaslTimeout) { retryCount = 0; + isCurrentSaslTimeout = false; } } } @@ -247,7 +246,6 @@ private void handleBlockTransferFailure(String blockId, Throwable exception) { synchronized (RetryingBlockTransferor.this) { if (this == currentListener && outstandingBlocksIds.contains(blockId)) { if (shouldRetry(exception)) { - blockIdToException.putIfAbsent(blockId, exception); initiateRetry(); } else { if (errorHandler.shouldLogError(exception)) { From b6f04fd95a3859e0163c9950e0a09291159df261 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Fri, 13 Jan 2023 11:05:09 -0800 Subject: [PATCH 20/28] fix linter --- .../apache/spark/network/shuffle/RetryingBlockTransferor.java | 1 - 1 file changed, 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 9f4c74bae7ee..95ff7ca32cbd 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Collections; -import java.util.HashMap; import java.util.LinkedHashSet; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; From 9402e007521a9d9844727c22c096751dad81f4dc Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Fri, 13 Jan 2023 13:19:20 -0800 Subject: [PATCH 21/28] add testRepeatedSaslRetryFailures --- .../shuffle/RetryingBlockTransferorSuite.java | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index 0fea1aef1b8a..6b7f5861af23 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import 
java.util.LinkedHashSet; @@ -286,6 +287,27 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException { verifyNoMoreInteractions(listener); } + @Test + public void testRepeatedSaslRetryFailures() throws IOException, InterruptedException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + TimeoutException timeoutException = new TimeoutException(); + SaslTimeoutException saslTimeoutException = + new SaslTimeoutException(timeoutException); + List> interactions = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + interactions.add( + ImmutableMap.builder() + .put("b0", saslTimeoutException) + .build() + ); + } + configMap.put("spark.shuffle.sasl.enableRetries", "true"); + performInteractions(interactions, listener); + verify(listener, times(3)).getTransferType(); + verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException); + verifyNoMoreInteractions(listener); + } + /** * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. From 8ceb6a3a56c7e3fa198399f29113214078e81736 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Fri, 13 Jan 2023 14:06:07 -0800 Subject: [PATCH 22/28] address comments --- .../spark/network/shuffle/RetryingBlockTransferor.java | 3 ++- .../spark/network/shuffle/RetryingBlockTransferorSuite.java | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 95ff7ca32cbd..61e26b07051f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -195,7 +195,8 @@ private synchronized void initiateRetry() { /** * Returns true if we should retry due to a block transfer failure. We will retry if and only if - * the exception was an IOException and we haven't retried 'maxRetries' times already. + * the exception was an IOException or SaslTimeoutException and we haven't retried + * 'maxRetries' times already. 
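+ * Note that SASL timeouts are considered retryable only when + * 'spark.shuffle.sasl.enableRetries' is enabled.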
From 8ceb6a3a56c7e3fa198399f29113214078e81736 Mon Sep 17 00:00:00 2001
From: Aravind Patnam
Date: Fri, 13 Jan 2023 14:06:07 -0800
Subject: [PATCH 22/28] address comments

---
 .../spark/network/shuffle/RetryingBlockTransferor.java      | 3 ++-
 .../spark/network/shuffle/RetryingBlockTransferorSuite.java | 5 +++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
index 95ff7ca32cbd..61e26b07051f 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
@@ -195,7 +195,8 @@ private synchronized void initiateRetry() {
 
   /**
    * Returns true if we should retry due a block transfer failure. We will retry if and only if
-   * the exception was an IOException and we haven't retried 'maxRetries' times already.
+   * the exception was an IOException or SaslTimeoutException and we haven't retried
+   * 'maxRetries' times already.
    */
   private synchronized boolean shouldRetry(Throwable e) {
     boolean isIOException = e instanceof IOException

diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
index 6b7f5861af23..0ea070aedafd 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
@@ -26,11 +26,11 @@
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeoutException;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Sets;
-import java.util.concurrent.TimeoutException;
-import org.apache.spark.network.sasl.SaslTimeoutException;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.stubbing.Answer;
@@ -43,6 +43,7 @@
 import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 import static org.apache.spark.network.shuffle.RetryingBlockTransferor.BlockTransferStarter;
 
 /**
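None of this behavior is on by default: the retry path is gated on conf.enableSaslRetries(), which the tests flip through configMap. A sketch of the equivalent construction outside the suite, assuming the TransportConf accessor and config key this series introduces:

    import com.google.common.collect.ImmutableMap;

    import org.apache.spark.network.util.MapConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class EnableSaslRetriesSketch {
      public static void main(String[] args) {
        TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(
            ImmutableMap.of("spark.shuffle.sasl.enableRetries", "true")));
        // When enabled, the first SASL token goes out as a SaslInitMessage and a
        // timed-out handshake surfaces as SaslTimeoutException, which is retryable.
        System.out.println("SASL retries enabled: " + conf.enableSaslRetries());
      }
    }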
From 7b8a5690ce4578db36002a0158325b6f237d5596 Mon Sep 17 00:00:00 2001
From: Aravind Patnam
Date: Fri, 13 Jan 2023 17:53:58 -0800
Subject: [PATCH 23/28] checked for block transfer failure case

---
 .../shuffle/RetryingBlockTransferor.java      | 20 +++++++++---
 .../shuffle/RetryingBlockTransferorSuite.java | 31 ++++++++++++++++++-
 2 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
index 61e26b07051f..ebb4c40b4d7f 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.shuffle;
 
+import com.google.common.annotations.VisibleForTesting;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.LinkedHashSet;
@@ -86,7 +87,7 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
   /** Number of times we've attempted to retry so far. */
   private int retryCount = 0;
 
-  private boolean isCurrentSaslTimeout;
+  private boolean saslTimeoutSeen;
 
   /**
    * Set of all block ids which have not been transferred successfully or with a non-IO Exception.
@@ -122,7 +123,7 @@ public RetryingBlockTransferor(
     this.currentListener = new RetryingBlockTransferListener();
     this.errorHandler = errorHandler;
     this.enableSaslRetries = conf.enableSaslRetries();
-    this.isCurrentSaslTimeout = false;
+    this.saslTimeoutSeen = false;
   }
 
   public RetryingBlockTransferor(
@@ -202,15 +203,24 @@ private synchronized boolean shouldRetry(Throwable e) {
     boolean isIOException = e instanceof IOException
       || e.getCause() instanceof IOException;
     boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
+    if (!isSaslTimeout && saslTimeoutSeen) {
+      retryCount = 0;
+      saslTimeoutSeen = false;
+    }
     boolean hasRemainingRetries = retryCount < maxRetries;
     boolean shouldRetry = (isSaslTimeout || isIOException) &&
       hasRemainingRetries && errorHandler.shouldRetryError(e);
     if (shouldRetry && isSaslTimeout) {
-      this.isCurrentSaslTimeout = true;
+      this.saslTimeoutSeen = true;
     }
     return shouldRetry;
   }
 
+  @VisibleForTesting
+  public int getRetryCount() {
+    return retryCount;
+  }
+
   /**
    * Our RetryListener intercepts block transfer responses and forwards them to our parent
    * listener. Note that in the event of a retry, we will immediately replace the 'currentListener'
@@ -226,9 +236,9 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) {
       if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
         outstandingBlocksIds.remove(blockId);
         shouldForwardSuccess = true;
-        if (isCurrentSaslTimeout) {
+        if (saslTimeoutSeen) {
           retryCount = 0;
-          isCurrentSaslTimeout = false;
+          saslTimeoutSeen = false;
         }
       }
     }

diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
index 0ea070aedafd..675eabf046de 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
@@ -56,6 +56,7 @@ public class RetryingBlockTransferorSuite {
   private final ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
   private final ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
   private static Map<String, String> configMap;
+  private static RetryingBlockTransferor _retryingBlockTransferor;
 
   @Before
   public void initMap() {
@@ -286,6 +287,7 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException {
     verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
     verify(listener).getTransferType();
     verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 0);
   }
 
   @Test
@@ -307,6 +309,32 @@ public void testRepeatedSaslRetryFailures() throws IOException, InterruptedExcep
     verify(listener, times(3)).getTransferType();
     verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
     verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 2);
+  }
+
+  @Test
+  public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    IOException ioe = new IOException();
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+        new SaslTimeoutException(timeoutException);
+
+    List<Map<String, Object>> interactions = Arrays.asList(
+      ImmutableMap.<String, Object>builder()
+        .put("b0", saslTimeoutException)
+        .put("b1", ioe)
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", block1)
+        .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener).getTransferType();
+    verify(listener).onBlockTransferSuccess("b1", block1);
+    assert(_retryingBlockTransferor.getRetryCount() == 0);
+  }
 
 
@@ -376,6 +404,7 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
     assertNotNull(stub);
     stub.when(fetchStarter).createAndStart(any(), any());
     String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
-    new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener).start();
+    _retryingBlockTransferor = new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener);
+    _retryingBlockTransferor.start();
   }
 }
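The new reset branch is the heart of this patch: a non-SASL failure arriving while saslTimeoutSeen is set means authentication eventually got through, so the SASL retries are forgiven before the IO failure is judged against maxRetries. A self-contained model of just the counter arithmetic (illustrative only; in the real class retryCount is incremented inside initiateRetry(), not here):

    import java.io.IOException;
    import java.util.concurrent.TimeoutException;

    public class SaslRetryCounterModel {
      static class SaslTimeoutException extends RuntimeException {
        SaslTimeoutException(Throwable cause) { super(cause); }
      }

      static final int MAX_RETRIES = 3;  // assumed value
      int retryCount = 0;
      boolean saslTimeoutSeen = false;

      boolean shouldRetry(Throwable e) {
        boolean isIOException = e instanceof IOException
            || e.getCause() instanceof IOException;
        boolean isSaslTimeout = e instanceof SaslTimeoutException;
        if (!isSaslTimeout && saslTimeoutSeen) {
          retryCount = 0;          // forgive SASL retries once auth has succeeded
          saslTimeoutSeen = false;
        }
        boolean shouldRetry = (isSaslTimeout || isIOException) && retryCount < MAX_RETRIES;
        if (shouldRetry && isSaslTimeout) {
          saslTimeoutSeen = true;
        }
        return shouldRetry;
      }

      public static void main(String[] args) {
        SaslRetryCounterModel m = new SaslRetryCounterModel();
        m.shouldRetry(new SaslTimeoutException(new TimeoutException()));
        m.retryCount++;                    // one retry spent on the SASL timeout
        m.shouldRetry(new IOException());  // counter forgiven before evaluation...
        m.retryCount++;                    // ...so this is again only retry 1
        System.out.println(m.retryCount);  // prints 1, not 2
      }
    }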
From 8f17993310810cc9e2dbef064fc01369f7da9406 Mon Sep 17 00:00:00 2001
From: Aravind Patnam
Date: Fri, 13 Jan 2023 17:56:13 -0800
Subject: [PATCH 24/28] remove extra whitespace

---
 .../spark/network/shuffle/RetryingBlockTransferorSuite.java | 1 -
 1 file changed, 1 deletion(-)

diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
index 675eabf046de..2c0ad1877d77 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
@@ -337,7 +337,6 @@ public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedE
     assert(_retryingBlockTransferor.getRetryCount() == 0);
   }
 
-
   /**
    * Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
    * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction
From 5623b04c337b7d5b9d1baf8e8e0f8394ba45890b Mon Sep 17 00:00:00 2001
From: Aravind Patnam
Date: Fri, 13 Jan 2023 23:17:03 -0800
Subject: [PATCH 25/28] fix linter

---
 .../spark/network/shuffle/RetryingBlockTransferorSuite.java | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
index 2c0ad1877d77..1c7f871bb422 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
@@ -403,7 +403,8 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
     assertNotNull(stub);
     stub.when(fetchStarter).createAndStart(any(), any());
     String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
-    _retryingBlockTransferor = new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener);
+    _retryingBlockTransferor =
+        new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener);
     _retryingBlockTransferor.start();
   }
 }
From 207b3b6aa9dc911db465da64834a961d712e71a5 Mon Sep 17 00:00:00 2001
From: Aravind Patnam
Date: Sat, 14 Jan 2023 14:29:13 -0800
Subject: [PATCH 26/28] fixed the test issues in retrying block transferor

---
 .../shuffle/RetryingBlockTransferorSuite.java | 28 +++++++++++--------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
index 1c7f871bb422..a33a471fb7af 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
@@ -306,8 +306,8 @@ public void testRepeatedSaslRetryFailures() throws IOException, InterruptedExcep
     }
     configMap.put("spark.shuffle.sasl.enableRetries", "true");
     performInteractions(interactions, listener);
-    verify(listener, times(3)).getTransferType();
     verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
+    verify(listener, times(3)).getTransferType();
     verifyNoMoreInteractions(listener);
     assert(_retryingBlockTransferor.getRetryCount() == 2);
   }
@@ -315,26 +315,30 @@ public void testRepeatedSaslRetryFailures() throws IOException, InterruptedExcep
   @Test
   public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedException {
     BlockFetchingListener listener = mock(BlockFetchingListener.class);
-    IOException ioe = new IOException();
-    TimeoutException timeoutException = new TimeoutException();
-    SaslTimeoutException saslTimeoutException =
-        new SaslTimeoutException(timeoutException);
 
     List<Map<String, Object>> interactions = Arrays.asList(
       ImmutableMap.<String, Object>builder()
-        .put("b0", saslTimeoutException)
-        .put("b1", ioe)
+        .put("b0", new SaslTimeoutException(new TimeoutException()))
+        .put("b1", new IOException())
         .build(),
       ImmutableMap.<String, Object>builder()
         .put("b0", block0)
-        .put("b1", block1)
-        .build()
+        .put("b1", new IOException())
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b1", block1)
+        .build()
     );
     configMap.put("spark.shuffle.sasl.enableRetries", "true");
     performInteractions(interactions, listener);
-    verify(listener).getTransferType();
-    verify(listener).onBlockTransferSuccess("b1", block1);
-    assert(_retryingBlockTransferor.getRetryCount() == 0);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b1", block1);
+    verify(listener, atLeastOnce()).getTransferType();
+    verifyNoMoreInteractions(listener);
+    // This should be equal to 1 because after the SASL exception is retried,
+    // retryCount should be set back to 0. Then after that b1 encounters an
+    // exception that is retried.
+    assert(_retryingBlockTransferor.getRetryCount() == 1);
   }
 
   /**
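The two getRetryCount() assertions encode the intended arithmetic. In testRepeatedSaslRetryFailures, three scripted SASL timeouts produce an initial attempt plus two retries (hence times(3)), and the counter still reads 2 when the failure is finally surfaced, which suggests the suite runs with a maxRetries of 2. For testBlockTransferFailureAfterSasl, the timeline behind the == 1 assertion is roughly the following (a hypothetical trace: per-attempt callback ordering follows the insertion order of the interaction maps, and a failure that loses the race to trigger a retry lands on a stale listener and is dropped):

    // Attempt 1: b0 -> SaslTimeoutException  => SASL retry; retryCount = 1,
    //            b1 -> IOException              saslTimeoutSeen = true (b1's
    //                                           failure hits a stale listener)
    // Attempt 2: b0 -> block0 (success)      => success handler zeroes retryCount
    //            b1 -> IOException           => plain IO retry; retryCount = 1
    // Attempt 3: b1 -> block1 (success)      => all blocks done; getRetryCount() == 1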
From 3bebbb64362917b73f1bf98fe41a0e45c2a2d560 Mon Sep 17 00:00:00 2001
From: Aravind Patnam
Date: Sat, 14 Jan 2023 16:55:11 -0800
Subject: [PATCH 27/28] change import order

---
 .../spark/network/shuffle/RetryingBlockTransferor.java | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
index ebb4c40b4d7f..94cf15985c9b 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.network.shuffle;
 
-import com.google.common.annotations.VisibleForTesting;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.LinkedHashSet;
@@ -27,13 +26,15 @@
 
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
-import org.apache.spark.network.sasl.SaslTimeoutException;
+import com.google.common.annotations.VisibleForTesting;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 
 /**
  * Wraps another BlockFetcher or BlockPusher with the ability to automatically retry block

From 6fc2379b68e1b58f62372f6c258d4543831e3206 Mon Sep 17 00:00:00 2001
From: Aravind Patnam
Date: Sat, 14 Jan 2023 17:23:08 -0800
Subject: [PATCH 28/28] change import ordering to be alphabetical

---
 .../spark/network/shuffle/RetryingBlockTransferor.java | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
index 94cf15985c9b..4515e3a5c282 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
@@ -24,17 +24,16 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
-import com.google.common.annotations.VisibleForTesting;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
-import org.apache.spark.network.sasl.SaslTimeoutException;
 
 /**
  * Wraps another BlockFetcher or BlockPusher with the ability to automatically retry block