From b1ff2575546e4b20261893718b446200649f0de9 Mon Sep 17 00:00:00 2001 From: turboFei Date: Fri, 12 Apr 2019 14:22:14 +0800 Subject: [PATCH 01/16] [SPARK-27562][Shuffle] Complete the verification mechanism for shuffle transmitted data --- .../DigestFileSegmentManagedBuffer.java | 50 ++++++++ .../buffer/FileSegmentManagedBuffer.java | 2 +- .../network/client/ChunkReceivedCallback.java | 5 + .../spark/network/client/StreamCallback.java | 5 + .../network/client/StreamInterceptor.java | 17 ++- .../client/TransportResponseHandler.java | 41 +++++++ .../protocol/DigestChunkFetchSuccess.java | 94 +++++++++++++++ .../protocol/DigestStreamResponse.java | 95 +++++++++++++++ .../spark/network/protocol/Message.java | 5 +- .../network/protocol/MessageDecoder.java | 6 + .../server/ChunkFetchRequestHandler.java | 17 ++- .../server/TransportRequestHandler.java | 14 ++- .../spark/network/util/DigestUtils.java | 64 ++++++++++ .../shuffle/BlockFetchingListener.java | 9 ++ .../shuffle/ExternalShuffleBlockResolver.java | 29 +++-- .../shuffle/OneForOneBlockFetcher.java | 14 +++ .../network/shuffle/RetryingBlockFetcher.java | 11 +- .../shuffle/ShuffleIndexInformation.java | 47 +++++++- .../network/shuffle/ShuffleIndexRecord.java | 10 +- .../spark/internal/config/package.scala | 7 ++ .../shuffle/BlockStoreShuffleReader.scala | 3 +- .../shuffle/IndexShuffleBlockResolver.scala | 112 +++++++++++++++--- .../storage/ShuffleBlockFetcherIterator.scala | 89 +++++++++++--- .../scala/org/apache/spark/ShuffleSuite.scala | 16 +++ .../sort/IndexShuffleBlockResolverSuite.scala | 24 +++- .../ShuffleBlockFetcherIteratorSuite.scala | 21 +++- docs/configuration.md | 7 ++ 27 files changed, 752 insertions(+), 62 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java new file mode 100644 index 0000000000000..d58e3ce2b2c88 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.buffer; + +import java.io.File; + +import com.google.common.base.Objects; + +import org.apache.spark.network.util.TransportConf; + +/** + * A {@link ManagedBuffer} backed by a segment in a file with digest. + */ +public final class DigestFileSegmentManagedBuffer extends FileSegmentManagedBuffer { + + private final long digest; + + public DigestFileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length, + long digest) { + super(conf, file, offset, length); + this.digest = digest; + } + + public long getDigest() { return digest; } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", getFile()) + .add("offset", getOffset()) + .add("length", getLength()) + .add("digest", digest) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 66566b67870f3..ed64867ad6c12 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -38,7 +38,7 @@ /** * A {@link ManagedBuffer} backed by a segment in a file. */ -public final class FileSegmentManagedBuffer extends ManagedBuffer { +public class FileSegmentManagedBuffer extends ManagedBuffer { private final TransportConf conf; private final File file; private final long offset; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java index 519e6cb470d0d..db1682c086f58 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java @@ -36,6 +36,11 @@ public interface ChunkReceivedCallback { */ void onSuccess(int chunkIndex, ManagedBuffer buffer); + /** Called with a extra digest parameter upon receipt of a particular chunk.*/ + default void onSuccess(int chunkIndex, ManagedBuffer buffer, long digest) { + onSuccess(chunkIndex, buffer); + } + /** * Called upon failure to fetch a particular chunk. Note that this may actually be called due * to failure to fetch a prior chunk in this stream. diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java index d322aec28793e..2b9b6cce25622 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -35,6 +35,11 @@ public interface StreamCallback { /** Called when all data from the stream has been received. */ void onComplete(String streamId) throws IOException; + /** Called with a extra digest when all data from the stream has been received. */ + default void onComplete(String streamId, long digest) throws IOException { + onComplete(streamId); + } + /** Called if there's an error reading data from the stream. 
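// Illustrative sketch only (not part of the patch): the digest-aware overloads above are added as
// default methods, so existing ChunkReceivedCallback/StreamCallback implementations keep compiling
// unchanged; only listeners that opt into verification override the three-argument form. The class
// name below is hypothetical; ManagedBuffer and ChunkReceivedCallback are the interfaces from this
// module.
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;

class VerifyingChunkCallback implements ChunkReceivedCallback {
  @Override
  public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
    // Legacy path: the server sent a plain ChunkFetchSuccess, so there is no digest to verify.
  }

  @Override
  public void onSuccess(int chunkIndex, ManagedBuffer buffer, long digest) {
    // Digest path: a DigestChunkFetchSuccess was received; a real implementation would recompute
    // a CRC32 over the buffer and compare it with `digest` before accepting the chunk.
    onSuccess(chunkIndex, buffer);
  }

  @Override
  public void onFailure(int chunkIndex, Throwable e) {
    // Retry or surface the failure as appropriate.
  }
}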
*/ void onFailure(String streamId, Throwable cause) throws IOException; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index f3eb744ff7345..f9d4ca1addcdc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -37,6 +37,7 @@ public class StreamInterceptor implements TransportFrameDecod private final long byteCount; private final StreamCallback callback; private long bytesRead; + private long digest = -1L; public StreamInterceptor( MessageHandler handler, @@ -50,6 +51,16 @@ public StreamInterceptor( this.bytesRead = 0; } + public StreamInterceptor( + MessageHandler handler, + String streamId, + long byteCount, + StreamCallback callback, + long digest) { + this(handler, streamId, byteCount, callback); + this.digest = digest; + } + @Override public void exceptionCaught(Throwable cause) throws Exception { deactivateStream(); @@ -86,7 +97,11 @@ public boolean handle(ByteBuf buf) throws Exception { throw re; } else if (bytesRead == byteCount) { deactivateStream(); - callback.onComplete(streamId); + if (digest < 0) { + callback.onComplete(streamId); + } else { + callback.onComplete(streamId, digest); + } } return bytesRead != byteCount; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 2f143f77fa4ae..c7e648a71c317 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -33,6 +33,8 @@ import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.DigestChunkFetchSuccess; +import org.apache.spark.network.protocol.DigestStreamResponse; import org.apache.spark.network.protocol.ResponseMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; @@ -246,6 +248,45 @@ public void handle(ResponseMessage message) throws Exception { } else { logger.warn("Stream failure with unknown callback: {}", resp.error); } + } else if (message instanceof DigestChunkFetchSuccess) { + DigestChunkFetchSuccess resp = (DigestChunkFetchSuccess) message; + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + if (listener == null) { + logger.warn("Ignoring response for block {} from {} since it is not outstanding", + resp.streamChunkId, getRemoteAddress(channel)); + resp.body().release(); + } else { + outstandingFetches.remove(resp.streamChunkId); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body(), resp.digest); + resp.body().release(); + } + } else if (message instanceof DigestStreamResponse) { + DigestStreamResponse resp = (DigestStreamResponse) message; + Pair entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry.getValue(); + if (resp.byteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor( + this, resp.streamId, resp.byteCount, callback, resp.digest); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + 
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + callback.onComplete(resp.streamId, resp.digest); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } + } + } else { + logger.error("Could not find callback for StreamResponse."); + } } else { throw new IllegalStateException("Unknown response type: " + message.type()); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java new file mode 100644 index 0000000000000..b7e326d9457cf --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Response to {@link ChunkFetchRequest} when a chunk exists with a digest and has been + * successfully fetched. + * + * Note that the server-side encoding of this messages does NOT include the buffer itself, as this + * may be written by Netty in a more efficient manner (i.e., zero-copy write). + * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. + */ +public final class DigestChunkFetchSuccess extends AbstractResponseMessage { + public final StreamChunkId streamChunkId; + public final long digest; + + public DigestChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer, long digest) { + super(buffer, true); + this.streamChunkId = streamChunkId; + this.digest = digest; + } + + @Override + public Message.Type type() { return Type.DigestChunkFetchSuccess; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength() + 8; + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + buf.writeLong(digest); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new ChunkFetchFailure(streamChunkId, error); + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. 
*/ + public static DigestChunkFetchSuccess decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + long digest = buf.readLong(); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new DigestChunkFetchSuccess(streamChunkId, managedBuf, digest); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamChunkId, body(), digest); + } + + @Override + public boolean equals(Object other) { + if (other instanceof DigestChunkFetchSuccess) { + DigestChunkFetchSuccess o = (DigestChunkFetchSuccess) other; + return streamChunkId.equals(o.streamChunkId) && super.equals(o) && digest == o.digest; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("digest", digest) + .add("buffer", body()) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java new file mode 100644 index 0000000000000..a184cba6654a5 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Response to {@link StreamRequest} with digest when the stream has been successfully opened. + *

+ * Note the message itself does not contain the stream data. That is written separately by the + * sender. The receiver is expected to set a temporary channel handler that will consume the + * number of bytes this message says the stream has. + */ +public final class DigestStreamResponse extends AbstractResponseMessage { + public final String streamId; + public final long byteCount; + public final long digest; + + public DigestStreamResponse(String streamId, long byteCount, ManagedBuffer buffer, long digest) { + super(buffer, false); + this.streamId = streamId; + this.byteCount = byteCount; + this.digest = digest; + } + + @Override + public Message.Type type() { return Type.DigestStreamResponse; } + + @Override + public int encodedLength() { + return 8 + Encoders.Strings.encodedLength(streamId) + 8; + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, streamId); + buf.writeLong(byteCount); + buf.writeLong(digest); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new StreamFailure(streamId, error); + } + + public static DigestStreamResponse decode(ByteBuf buf) { + String streamId = Encoders.Strings.decode(buf); + long byteCount = buf.readLong(); + long digest = buf.readLong(); + return new DigestStreamResponse(streamId, byteCount, null, digest); + } + + @Override + public int hashCode() { + return Objects.hashCode(byteCount, streamId, body(), digest); + } + + @Override + public boolean equals(Object other) { + if (other instanceof DigestStreamResponse) { + DigestStreamResponse o = (DigestStreamResponse) other; + return byteCount == o.byteCount && streamId.equals(o.streamId) && digest == o.digest; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("byteCount", byteCount) + .add("digest", digest) + .add("body", body()) + .toString(); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 0ccd70c03aba8..cd3efdc59d380 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,8 @@ enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), UploadStream(10), User(-1); + OneWayMessage(9), UploadStream(10), DigestChunkFetchSuccess(11), + DigestStreamResponse(12), User(-1); private final byte id; @@ -66,6 +67,8 @@ public static Type decode(ByteBuf buf) { case 8: return StreamFailure; case 9: return OneWayMessage; case 10: return UploadStream; + case 11: return DigestChunkFetchSuccess; + case 12: return DigestStreamResponse; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index bf80aed0afe10..0d98f0161015c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ 
b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -50,6 +50,12 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { private Message decode(Message.Type msgType, ByteBuf in) { switch (msgType) { + case DigestChunkFetchSuccess: + return DigestChunkFetchSuccess.decode(in); + + case DigestStreamResponse: + return DigestStreamResponse.decode(in); + case ChunkFetchRequest: return ChunkFetchRequest.decode(in); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java index 82810dacdad84..f648a0fe6d0b5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -25,15 +25,13 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer; +import org.apache.spark.network.protocol.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Encodable; import static org.apache.spark.network.util.NettyUtils.*; @@ -111,8 +109,15 @@ public void processFetchRequest( } streamManager.chunkBeingSent(msg.streamChunkId.streamId); - respond(channel, new ChunkFetchSuccess(msg.streamChunkId, buf)).addListener( - (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); + if (buf instanceof DigestFileSegmentManagedBuffer) { + respond(channel, new DigestChunkFetchSuccess(msg.streamChunkId, buf, + ((DigestFileSegmentManagedBuffer)buf).getDigest())).addListener( + (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); + } else { + respond(channel, new ChunkFetchSuccess(msg.streamChunkId, buf)).addListener( + (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); + } + } /** diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index f178928006902..db7a43875cdcf 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -28,6 +28,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.*; @@ -143,9 +144,16 @@ private void processStreamRequest(final StreamRequest req) { if (buf != null) { streamManager.streamBeingSent(req.streamId); - respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> { - streamManager.streamSent(req.streamId); - }); + if (buf instanceof DigestFileSegmentManagedBuffer) { + respond(new 
DigestStreamResponse(req.streamId, buf.size(), buf, + ((DigestFileSegmentManagedBuffer) buf).getDigest())).addListener(future -> { + streamManager.streamSent(req.streamId); + }); + } else { + respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> { + streamManager.streamSent(req.streamId); + }); + } } else { // org.apache.spark.repl.ExecutorClassLoader.STREAM_NOT_FOUND_REGEX should also be updated // when the following error message is changed. diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java new file mode 100644 index 0000000000000..a20d33d66a2f8 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.zip.CRC32; + +public class DigestUtils { + private static final int STREAM_BUFFER_LENGTH = 8192; + private static final int DIGEST_LENGTH = 8; + + public static int getDigestLength() { + return DIGEST_LENGTH; + } + + public static long getDigest(InputStream data) throws IOException { + return updateCRC32(getCRC32(), data); + } + + public static long getDigest(File file, long offset, long length) { + if (length <= 0) { + return -1; + } + try { + LimitedInputStream inputStream = new LimitedInputStream(new FileInputStream(file), + offset + length, true); + inputStream.skip(offset); + return getDigest(inputStream); + } catch (IOException e) { + return -1; + } + } + + public static CRC32 getCRC32() { + return new CRC32(); + } + + public static long updateCRC32(CRC32 crc32, InputStream data) throws IOException { + byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; + int len; + while ((len = data.read(buffer)) >= 0) { + crc32.update(buffer, 0, len); + } + return crc32.getValue(); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java index 138fd5389c20a..b5f76aab11d3f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java @@ -29,6 +29,15 @@ public interface BlockFetchingListener extends EventListener { */ void onBlockFetchSuccess(String blockId, ManagedBuffer data); + /** + * Called once per successfully fetch block during shuffle, which has a parameter present the + * checkSum of shuffle block. 
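// Illustrative sketch only (not part of the patch): how the DigestUtils helper introduced above
// would be used. The digest is an 8-byte CRC32 over the block bytes, so the file-segment overload
// and the stream overload agree for the same byte range. File and values are hypothetical.
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.util.Arrays;

import org.apache.spark.network.util.DigestUtils;

class DigestUtilsExample {
  public static void main(String[] args) throws Exception {
    byte[] data = new byte[64];
    Arrays.fill(data, (byte) 7);
    File f = File.createTempFile("shuffle-digest", ".data");
    try (FileOutputStream out = new FileOutputStream(f)) {
      out.write(data);
    }
    // Digest of the file segment [16, 48) ...
    long fromFile = DigestUtils.getDigest(f, 16, 32);
    // ... matches the digest of the same bytes read as a stream.
    long fromStream = DigestUtils.getDigest(new ByteArrayInputStream(data, 16, 32));
    if (fromFile != fromStream) {
      throw new IllegalStateException("CRC32 of the same bytes should be identical");
    }
    f.delete();
  }
}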
Here provide a default method body for that not every + * blockFetchingListener need to implement one onBlockFetchSuccess method. + */ + default void onBlockFetchSuccess(String blockId, ManagedBuffer data, long digest) { + onBlockFetchSuccess(blockId, data); + } + /** * Called at least once per block upon failures. */ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index ba1a17bf7e5ea..2740373548d5a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -39,19 +39,17 @@ import com.google.common.cache.LoadingCache; import com.google.common.cache.Weigher; import com.google.common.collect.Maps; +import org.apache.spark.network.util.*; import org.iq80.leveldb.DB; import org.iq80.leveldb.DBIterator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.LevelDBProvider; import org.apache.spark.network.util.LevelDBProvider.StoreVersion; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.TransportConf; /** * Manages converting shuffle BlockIds into physical segments of local files, from a process outside @@ -320,12 +318,25 @@ private ManagedBuffer getSortBasedShuffleBlockData( ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile); ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex( startReduceId, endReduceId); - return new FileSegmentManagedBuffer( - conf, - ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, + if (shuffleIndexInformation.isHasDigest()) { + File dataFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.data"); + return new DigestFileSegmentManagedBuffer( + conf, + dataFile, + shuffleIndexRecord.getOffset(), + shuffleIndexRecord.getLength(), + shuffleIndexRecord.getDigest().orElse(DigestUtils.getDigest( + dataFile, shuffleIndexRecord.getOffset(), shuffleIndexRecord.getLength()))); + + } else { + return new FileSegmentManagedBuffer( + conf, + ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, "shuffle_" + shuffleId + "_" + mapId + "_0.data"), - shuffleIndexRecord.getOffset(), - shuffleIndexRecord.getLength()); + shuffleIndexRecord.getOffset(), + shuffleIndexRecord.getLength()); + } } catch (ExecutionException e) { throw new RuntimeException("Failed to open file: " + indexFile, e); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index ec2e3dce661d9..d7e1e096c423f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -165,6 +165,12 @@ public void onSuccess(int chunkIndex, 
ManagedBuffer buffer) { listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); } + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer, long digest) { + // On receipt of a chunk, pass it upwards as a block. + listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer, digest); + } + @Override public void onFailure(int chunkIndex, Throwable e) { // On receipt of a failure, fail every block from chunkIndex onwards. @@ -248,6 +254,14 @@ public void onComplete(String streamId) throws IOException { } } + @Override + public void onComplete(String streamId, long digest) throws IOException { + listener.onBlockFetchSuccess(blockIds[chunkIndex], channel.closeAndRead(), digest); + if (!downloadFileManager.registerTempFileToClean(targetFile)) { + targetFile.delete(); + } + } + @Override public void onFailure(String streamId, Throwable cause) throws IOException { channel.close(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java index 6bf3da94030d4..dba15e0076dd5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -189,6 +189,11 @@ private synchronized boolean shouldRetry(Throwable e) { private class RetryingBlockFetchListener implements BlockFetchingListener { @Override public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + onBlockFetchSuccess(blockId, data, -1L); + } + + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data, long digest) { // We will only forward this success message to our parent listener if this block request is // outstanding and we are still the active listener. boolean shouldForwardSuccess = false; @@ -201,7 +206,11 @@ public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { // Now actually invoke the parent listener, outside of the synchronized block. 
if (shouldForwardSuccess) { - listener.onBlockFetchSuccess(blockId, data); + if (digest < 0) { + listener.onBlockFetchSuccess(blockId, data); + } else { + listener.onBlockFetchSuccess(blockId, data, digest); + } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index b65aacfcc4b9e..8f2852ea02785 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -23,6 +23,9 @@ import java.nio.ByteBuffer; import java.nio.LongBuffer; import java.nio.file.Files; +import java.util.Optional; + +import org.apache.spark.network.util.DigestUtils; /** * Keeps the index information for a particular map output @@ -31,17 +34,44 @@ public class ShuffleIndexInformation { /** offsets as long buffer */ private final LongBuffer offsets; + private final boolean hasDigest; + /** digests as long buffer */ + private final LongBuffer digests; private int size; public ShuffleIndexInformation(File indexFile) throws IOException { + ByteBuffer offsetsBuffer, digestsBuffer; size = (int)indexFile.length(); - ByteBuffer buffer = ByteBuffer.allocate(size); - offsets = buffer.asLongBuffer(); + int offsetsSize, digestsSize; + if (size % 8 == 0) { + hasDigest = false; + offsetsSize = size; + digestsSize = 0; + } else { + hasDigest = true; + offsetsSize = ((size - 8 - 1) / (8 + DigestUtils.getDigestLength()) + 1) * 8; + digestsSize = size - offsetsSize -1; + } + offsetsBuffer = ByteBuffer.allocate(offsetsSize); + digestsBuffer = ByteBuffer.allocate(digestsSize); + offsets = offsetsBuffer.asLongBuffer(); + digests = digestsBuffer.asLongBuffer(); try (DataInputStream dis = new DataInputStream(Files.newInputStream(indexFile.toPath()))) { - dis.readFully(buffer.array()); + dis.readFully(offsetsBuffer.array()); + if (hasDigest) { + dis.readByte(); + } + dis.readFully(digestsBuffer.array()); } } + /** + * If this indexFile has digest + */ + public boolean isHasDigest() { + return hasDigest; + } + /** * Size of the index file * @return size @@ -63,6 +93,15 @@ public ShuffleIndexRecord getIndex(int reduceId) { public ShuffleIndexRecord getIndex(int startReduceId, int endReduceId) { long offset = offsets.get(startReduceId); long nextOffset = offsets.get(endReduceId); - return new ShuffleIndexRecord(offset, nextOffset - offset); + /** Default digest is -1L.*/ + Optional digest = Optional.of(-1L); + if (hasDigest) { + if (endReduceId - startReduceId == 1) { + digest = Optional.of(digests.get(startReduceId)); + } else { + digest = Optional.empty(); + } + } + return new ShuffleIndexRecord(offset, nextOffset - offset, digest); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java index 6a4fac150a6bd..f64871c2bea3c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java @@ -17,16 +17,20 @@ package org.apache.spark.network.shuffle; +import java.util.Optional; + /** * Contains offset and length of the shuffle block data. 
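// Worked example (illustrative, not part of the patch) of the on-disk index layout that
// ShuffleIndexInformation above distinguishes purely by file size. The writer emits
// (blocks + 1) offset longs, then a single flag byte, then one 8-byte CRC32 digest per block,
// so with 3 reduce partitions: plain index = 4 * 8 = 32 bytes (a multiple of 8), digest index =
// 32 + 1 + 3 * 8 = 57 bytes (not a multiple of 8). The arithmetic mirrors the constructor's split;
// the values are hypothetical.
class IndexLayoutExample {
  public static void main(String[] args) {
    int blocks = 3;
    int size = (blocks + 1) * 8 + 1 + blocks * 8;          // 57 bytes on disk
    boolean hasDigest = size % 8 != 0;                     // true: offsets alone are 8-byte aligned
    int offsetsSize = ((size - 8 - 1) / (8 + 8) + 1) * 8;  // ((57 - 9) / 16 + 1) * 8 = 32
    int digestsSize = size - offsetsSize - 1;              // 57 - 32 - 1 = 24, i.e. 3 digest longs
    System.out.println(hasDigest + " " + offsetsSize + " " + digestsSize);
  }
}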
*/ public class ShuffleIndexRecord { private final long offset; private final long length; + private final Optional digest; - public ShuffleIndexRecord(long offset, long length) { + public ShuffleIndexRecord(long offset, long length, Optional digest) { this.offset = offset; this.length = length; + this.digest = digest; } public long getOffset() { @@ -36,5 +40,9 @@ public long getOffset() { public long getLength() { return length; } + + public Optional getDigest() { + return digest; + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4cda4b180d97d..61b52f797c2a6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1836,4 +1836,11 @@ package object config { .version("3.1.0") .booleanConf .createWithDefault(false) + + private[spark] val SHUFFLE_DIGEST_ENABLED = + ConfigBuilder("spark.shuffle.digest.enabled") + .internal() + .doc("The parameter to control whether check the transmitted data during shuffle.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index bc2a0fbc36d5b..dd6bea1a1a7ad 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -80,7 +80,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), readMetrics, - fetchContinuousBlocksInBatch).toCompletionIterator + fetchContinuousBlocksInBatch, + SparkEnv.get.conf.get(config.SHUFFLE_DIGEST_ENABLED)).toCompletionIterator val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index af2c82e771970..84e2f0bef8672 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -22,11 +22,12 @@ import java.nio.channels.Channels import java.nio.file.Files import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.NioBufferedFileInputStream -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{DigestFileSegmentManagedBuffer, FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExecutorDiskUtils +import org.apache.spark.network.util.{DigestUtils, LimitedInputStream} import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -52,6 +53,9 @@ private[spark] class IndexShuffleBlockResolver( private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + // The digest conf for shuffle block check + private final val digestEnable = conf.getBoolean(config.SHUFFLE_DIGEST_ENABLED.key, false); + private final val digestLength = DigestUtils.getDigestLength() def getDataFile(shuffleId: Int, mapId: Long): File = 
getDataFile(shuffleId, mapId, None) @@ -107,12 +111,16 @@ private[spark] class IndexShuffleBlockResolver( * Check whether the given index and data files match each other. * If so, return the partition lengths in the data file. Otherwise return null. */ - private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { - // the index file should have `block + 1` longs as offset. - if (index.length() != (blocks + 1) * 8L) { + private def checkIndexAndDataFile(index: File, data: File, blocks: Int, digests: Array[Long]): + (Array[Long], Array[Long]) = { + // Id digestEnable is false, the index file should have `blocks + 1` longs as offset. + // Otherwise, it should have a byte as flag, `blocks + 1` longs as offset and `blocks` digests + if ((!digestEnable && index.length() != (blocks + 1) * 8L) || + (digestEnable && index.length() != blocks * (8L + digestLength) + 8L + 1L)) { return null } val lengths = new Array[Long](blocks) + val digestArr = new Array[Long](blocks) // Read the lengths of blocks val in = try { new DataInputStream(new NioBufferedFileInputStream(index)) @@ -133,6 +141,18 @@ private[spark] class IndexShuffleBlockResolver( offset = off i += 1 } + if (digestEnable) { + val flag = in.readByte() + // the flag for digestEnable should be 1 + if (flag != 1) { + return null + } + i = 0 + while (i < blocks) { + digestArr(i) = in.readLong() + i += 1 + } + } } catch { case e: IOException => return null @@ -141,8 +161,8 @@ private[spark] class IndexShuffleBlockResolver( } // the size of data file should match with index file - if (data.length() == lengths.sum) { - lengths + if (data.length() == lengths.sum && !(0 until blocks).exists(i => digests(i) != digestArr(i))) { + (lengths, digestArr) } else { null } @@ -170,11 +190,38 @@ private[spark] class IndexShuffleBlockResolver( // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure // the following check and rename are atomic. synchronized { - val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) - if (existingLengths != null) { + val digests = new Array[Long](lengths.length) + val dateIn = if (dataTmp != null && dataTmp.exists()) { + new FileInputStream(dataTmp) + } else { + null + } + Utils.tryWithSafeFinally { + if (digestEnable && dateIn != null) { + for (i <- (0 until lengths.length)) { + val length = lengths(i) + if (length == 0) { + digests(i) = -1L + } else { + digests(i) = DigestUtils.getDigest(new LimitedInputStream(dateIn, length)) + } + } + } + } { + if (dateIn != null) { + dateIn.close() + } + } + + val existingLengthsDigests = + checkIndexAndDataFile(indexFile, dataFile, lengths.length, digests) + if (existingLengthsDigests != null) { + val existingLengths = existingLengthsDigests._1 + val existingDigests = existingLengthsDigests._2 // Another attempt for the same task has already written our map outputs successfully, // so just use the existing partition lengths and delete our temporary map outputs. 
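// Illustrative sketch only (not part of the patch), rendered in Java for brevity: how the digest
// loop a few lines above pairs partition lengths with digests. Consecutive slices of the committed
// data file are digested in order, and empty partitions get the sentinel -1L. Values are
// hypothetical; CRC32 matches what DigestUtils computes.
import java.util.zip.CRC32;

class PartitionDigestExample {
  public static void main(String[] args) {
    byte[] dataFile = new byte[30];            // concatenated partition bytes, as in dataTmp
    long[] lengths = {10L, 0L, 20L};
    long[] digests = new long[lengths.length];
    int pos = 0;
    for (int i = 0; i < lengths.length; i++) {
      int len = (int) lengths[i];
      if (len == 0) {
        digests[i] = -1L;                      // empty partition, nothing to verify
      } else {
        CRC32 crc = new CRC32();
        crc.update(dataFile, pos, len);        // digest covers this partition's slice only
        digests[i] = crc.getValue();
      }
      pos += len;
    }
  }
}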
System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + System.arraycopy(existingDigests, 0, digests, 0, digests.length) if (dataTmp != null && dataTmp.exists()) { dataTmp.delete() } @@ -190,6 +237,13 @@ private[spark] class IndexShuffleBlockResolver( offset += length out.writeLong(offset) } + if (digestEnable) { + // we write a byte present digest enable + out.writeByte(1) + for (digest <- digests) { + out.writeLong(digest) + } + } } { out.close() } @@ -237,6 +291,10 @@ private[spark] class IndexShuffleBlockResolver( // class of issue from re-occurring in the future which is why they are left here even though // SPARK-22982 is fixed. val channel = Files.newByteChannel(indexFile.toPath) + var blocks = (indexFile.length() - 8) / 8 + if (digestEnable) { + blocks = (indexFile.length() - 8 - 1) / (8 + digestLength) + } channel.position(startReduceId * 8L) val in = new DataInputStream(Channels.newInputStream(channel)) try { @@ -249,11 +307,37 @@ private[spark] class IndexShuffleBlockResolver( throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + s"expected $expectedPosition but actual position was $actualPosition.") } - new FileSegmentManagedBuffer( - transportConf, - getDataFile(shuffleId, mapId, dirs), - startOffset, - endOffset - startOffset) + + if (digestEnable) { + val digestValue = if (endReduceId - startReduceId == 1) { + channel.position(1 + (blocks + 1) * 8L + startReduceId * digestLength) + val digest = in.readLong() + val actualDigestPosition = channel.position() + val expectedDigestLength = 1 + (blocks + 1) * 8L + (startReduceId + 1) * digestLength + if (actualDigestPosition != expectedDigestLength) { + throw new Exception(s"SPARK-22982: Incorrect channel position after index file " + + s"reads: expected $expectedDigestLength but actual position was " + + s" $actualDigestPosition.") + } + digest + } else { + DigestUtils.getDigest(getDataFile(shuffleId, mapId, dirs), startOffset, + endOffset - startOffset) + } + + new DigestFileSegmentManagedBuffer( + transportConf, + getDataFile(shuffleId, mapId, dirs), + startOffset, + endOffset - startOffset, + digestValue) + } else { + new FileSegmentManagedBuffer( + transportConf, + getDataFile(shuffleId, mapId, dirs), + startOffset, + endOffset - startOffset) + } } finally { in.close() } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 5efbc0703f729..27bc414b608dc 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -32,7 +32,7 @@ import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ -import org.apache.spark.network.util.TransportConf +import org.apache.spark.network.util.{DigestUtils, TransportConf} import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} @@ -80,7 +80,8 @@ final class ShuffleBlockFetcherIterator( detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, shuffleMetrics: ShuffleReadMetricsReporter, - doBatchFetch: Boolean) + doBatchFetch: Boolean, + digestEnabled: Boolean) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import 
ShuffleBlockFetcherIterator._ @@ -210,7 +211,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, _, address, _, buf, _) => + case SuccessFetchResult(_, _, address, _, buf, _, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) if (buf.isInstanceOf[FileSegmentManagedBuffer]) { @@ -572,7 +573,8 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(fetchWaitTime) result match { - case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => + case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone, + digest) => if (address != blockManager.blockManagerId) { if (hostLocalBlocks.contains(blockId -> mapIndex)) { shuffleMetrics.incLocalBlocksFetched(1) @@ -611,7 +613,7 @@ final class ShuffleBlockFetcherIterator( throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) } - val in = try { + var in = try { buf.createInputStream() } catch { // The exception could only be throwed by local shuffle block @@ -626,16 +628,60 @@ final class ShuffleBlockFetcherIterator( buf.release() throwFetchFailedException(blockId, mapIndex, address, e) } + + // If shuffle digest enabled is true, check the block with checkSum. + var failedOnDigestCheck = false + if (digestEnabled) { + if (digest >= 0) { + val digestToCheck = try { + DigestUtils.getDigest(in) + } catch { + case e: IOException => + logError("Error occurs when checking digest", e) + buf.release() + throwFetchFailedException(blockId, mapIndex, address, e) + } + failedOnDigestCheck = digest != digestToCheck + if (!failedOnDigestCheck) { + buf.release() + val e = new CheckDigestFailedException(s"The digest to check $digestToCheck " + + s"of $blockId is not equal with origin $digest") + if (corruptedBlocks.contains(blockId)) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else { + logError("The digest of read data is not correct and fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null + in.close() + } + } + // If digest check passed, reset or recreate the inputStream + if (!failedOnDigestCheck) { + if (in.markSupported()) { + in.reset() + } else { + in = buf.createInputStream() + } + } + } else { + logDebug(s"The digest for address: ${address.host} and blockID:" + + s"$blockId is null, local address is ${blockManager.blockManagerId.host}") + } + } try { - input = streamWrapper(blockId, in) - // If the stream is compressed or wrapped, then we optionally decompress/unwrap the - // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion - // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if - // the corruption is later, we'll still detect the corruption later in the stream. - streamCompressedOrEncrypted = !input.eq(in) - if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { - // TODO: manage the memory used here, and spill it into disk in case of OOM. - input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + if (!failedOnDigestCheck) { + input = streamWrapper(blockId, in) + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. 
But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (!digestEnabled && streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { + // TODO: manage the memory used here, and spill it into disk in case of OOM. + input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + } } } catch { case e: IOException => @@ -962,13 +1008,15 @@ object ShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block successfully. + * * @param blockId block id * @param mapIndex the mapIndex for this block, which indicate the index in the map stage. * @param address BlockManager that the block was fetched from. * @param size estimated size of the block. Note that this is NOT the exact bytes. * Size of remote block is used to calculate bytesInFlight. - * @param buf `ManagedBuffer` for the content. + * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. + * @param digest Is the digest of the result, default is -1L. */ private[storage] case class SuccessFetchResult( blockId: BlockId, @@ -976,7 +1024,8 @@ object ShuffleBlockFetcherIterator { address: BlockManagerId, size: Long, buf: ManagedBuffer, - isNetworkReqDone: Boolean) extends FetchResult { + isNetworkReqDone: Boolean, + digest: Long = -1L) extends FetchResult { require(buf != null) require(size >= 0) } @@ -994,4 +1043,12 @@ object ShuffleBlockFetcherIterator { address: BlockManagerId, e: Throwable) extends FetchResult + + /** + * An exception that the origin digest is not equal with the fetchResult's digest. + */ + private case class CheckDigestFailedException( + message: String, + cause: Throwable = null) + extends Exception(message, cause) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 9e39271bdf9ee..37e169c5eef02 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -418,6 +418,22 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC manager.unregisterShuffle(0) } + + test("[SPARK-27562]: test shuffle with shuffle digest enabled is true") { + conf.set(config.SHUFFLE_DIGEST_ENABLED, true) + val sc = new SparkContext("local", "test", conf) + val numRecords = 10000 + + val wordCount = sc.parallelize(1 to numRecords, 4) + .map(key => (key, 1)) + .reduceByKey(_ + _) + .collect() + val count = wordCount.length + val sum = wordCount.map(value => value._1).sum + assert(count == numRecords) + assert(sum == (1 to numRecords).sum) + sc.stop() + } } /** diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 27bb06b4e0636..c5ba21834cc8e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.sort -import java.io.{DataInputStream, File, FileInputStream, FileOutputStream} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileInputStream, FileOutputStream} import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -27,6 +27,9 @@ import 
org.mockito.invocation.InvocationOnMock import org.scalatest.BeforeAndAfterEach import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer +import org.apache.spark.network.util.DigestUtils import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -155,4 +158,23 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa indexIn2.close() } } + + test("[SPARK-27562]: check the digest when shuffle digest enabled is true") { + val confClone = conf.clone + confClone.set(config.SHUFFLE_DIGEST_ENABLED, true) + val resolver = new IndexShuffleBlockResolver(confClone, blockManager) + val lengths = Array[Long](10, 0, 20) + val dataTmp = File.createTempFile("shuffle", null, tempDir) + val out = new FileOutputStream(dataTmp) + Utils.tryWithSafeFinally { + out.write(new Array[Byte](30)) + } { + out.close() + } + val digest = DigestUtils.getDigest(new ByteArrayInputStream(new Array[Byte](10))) + resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) + val managedBuffer = resolver.getBlockData(ShuffleBlockId(1, 2, 0)) + assert(managedBuffer.isInstanceOf[DigestFileSegmentManagedBuffer]) + assert(managedBuffer.asInstanceOf[DigestFileSegmentManagedBuffer].getDigest == digest) + } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 43917a5b83bb0..9c4993fac0c68 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -175,6 +175,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, metrics, + false, false) // 3 local blocks fetched in initialization @@ -251,6 +252,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, metrics, + false, false) intercept[FetchFailedException] { iterator.next() } } @@ -285,6 +287,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, metrics, + false, false) // After initialize() we'll have 2 FetchRequests and each is 1000 bytes. So only the // first FetchRequests can be sent, and the second one will hit maxBytesInFlight so @@ -330,6 +333,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, metrics, + false, false) // After initialize(), we'll have 2 FetchRequests that one has 2 blocks inside and another one // has only one block. So only the first FetchRequest can be sent. The second FetchRequest will @@ -412,7 +416,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, metrics, - true) + true, + false) // 3 local blocks batch fetched in initialization verify(blockManager, times(1)).getLocalBlockData(any()) @@ -472,7 +477,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, metrics, - true) + true, + false) var numResults = 0 // After initialize(), there will be 6 FetchRequests. And each of the first 5 requests @@ -529,7 +535,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, metrics, - true) + true, + false) var numResults = 0 // After initialize(), there will be 2 FetchRequests. 
First one has 2 merged blocks and each // of them is merged from 2 shuffle blocks, second one has 1 merged block which is merged from @@ -596,6 +603,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, taskContext.taskMetrics.createTempShuffleReadMetrics(), + false, false) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() @@ -666,6 +674,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, taskContext.taskMetrics.createTempShuffleReadMetrics(), + false, false) // Continue only after the mock calls onBlockFetchFailure @@ -756,6 +765,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, true, taskContext.taskMetrics.createTempShuffleReadMetrics(), + false, false) // Continue only after the mock calls onBlockFetchFailure @@ -827,6 +837,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, true, taskContext.taskMetrics.createTempShuffleReadMetrics(), + false, false) // We'll get back the block which has corruption after maxBytesInFlight/3 because the other @@ -892,6 +903,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, true, taskContext.taskMetrics.createTempShuffleReadMetrics(), + false, false) val (id, st) = iterator.next() // Check that the test setup is correct -- make sure we have a concatenated stream. @@ -955,6 +967,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, taskContext.taskMetrics.createTempShuffleReadMetrics(), + false, false) // Continue only after the mock calls onBlockFetchFailure @@ -1017,6 +1030,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT detectCorrupt = true, false, taskContext.taskMetrics.createTempShuffleReadMetrics(), + false, false) } @@ -1066,6 +1080,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, false, taskContext.taskMetrics.createTempShuffleReadMetrics(), + false, false) // All blocks fetched return zero length and should trigger a receive-side error: diff --git a/docs/configuration.md b/docs/configuration.md index fce04b940594b..3a91d4f7531ae 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -998,6 +998,13 @@ Apart from these, the following properties are also available, and may be useful 2.3.0 + + spark.shuffle.digest.enabled + false + + The parameter to control whether check the transmitted data during shuffle. 
+ + ### Spark UI From d7bc0879e7fa8c15459ec798e62d69f7947e0c51 Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 14 May 2020 11:44:32 +0800 Subject: [PATCH 02/16] refactor --- .../main/scala/org/apache/spark/internal/config/package.scala | 2 +- docs/configuration.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 61b52f797c2a6..ab90b8228334e 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1840,7 +1840,7 @@ package object config { private[spark] val SHUFFLE_DIGEST_ENABLED = ConfigBuilder("spark.shuffle.digest.enabled") .internal() - .doc("The parameter to control whether check the transmitted data during shuffle.") + .doc("The parameter to control whether to check the digest of transmitted data during shuffle.") .booleanConf .createWithDefault(false) } diff --git a/docs/configuration.md b/docs/configuration.md index 3a91d4f7531ae..e3a50a9cc82a8 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1002,7 +1002,7 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.digest.enabled false - The parameter to control whether check the transmitted data during shuffle. + The parameter to control whether to check the digest of transmitted data during shuffle. From be44d94b6feecab114cc039cf6a4336fabca2139 Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 14 May 2020 11:58:26 +0800 Subject: [PATCH 03/16] save --- .../org/apache/spark/network/util/DigestUtils.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java index a20d33d66a2f8..f69a1bf42653a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java @@ -23,7 +23,11 @@ import java.io.InputStream; import java.util.zip.CRC32; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + public class DigestUtils { + private static final Logger LOG = LoggerFactory.getLogger(DigestUtils.class); private static final int STREAM_BUFFER_LENGTH = 8192; private static final int DIGEST_LENGTH = 8; @@ -45,15 +49,17 @@ public static long getDigest(File file, long offset, long length) { inputStream.skip(offset); return getDigest(inputStream); } catch (IOException e) { + LOG.error(String.format("Exception while computing digest for file segment: " + + "%s(offset:%d, length:%d)", file.getName(), offset, length )); return -1; } } - public static CRC32 getCRC32() { + private static CRC32 getCRC32() { return new CRC32(); } - public static long updateCRC32(CRC32 crc32, InputStream data) throws IOException { + private static long updateCRC32(CRC32 crc32, InputStream data) throws IOException { byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; int len; while ((len = data.read(buffer)) >= 0) { From 9a6c2f41f81497f4bc59c765a6d5d2a8644d5d7b Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 14 May 2020 14:50:19 +0800 Subject: [PATCH 04/16] refactor --- .../spark/network/util/DigestUtils.java | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git
a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java index f69a1bf42653a..99634a4725af6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java @@ -17,10 +17,10 @@ package org.apache.spark.network.util; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; +import java.io.*; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; import java.util.zip.CRC32; import org.slf4j.Logger; @@ -41,13 +41,18 @@ public static long getDigest(InputStream data) throws IOException { public static long getDigest(File file, long offset, long length) { if (length <= 0) { - return -1; + return -1L; } - try { - LimitedInputStream inputStream = new LimitedInputStream(new FileInputStream(file), - offset + length, true); - inputStream.skip(offset); - return getDigest(inputStream); + try (RandomAccessFile rf = new RandomAccessFile(file, "r")) { + MappedByteBuffer data = rf.getChannel().map(FileChannel.MapMode.READ_ONLY, offset, length); + CRC32 crc32 = getCRC32(); + byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; + int len; + while ((len = Math.min(STREAM_BUFFER_LENGTH, data.remaining())) > 0) { + data.get(buffer, 0, len); + crc32.update(buffer, 0, len); + } + return crc32.getValue(); } catch (IOException e) { LOG.error(String.format("Exception while computing digest for file segment: " + "%s(offset:%d, length:%d)", file.getName(), offset, length )); From ae1c70969b29cf2b6918a37b65c88250126dc591 Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 14 May 2020 16:11:15 +0800 Subject: [PATCH 05/16] add ut --- .../spark/network/util/DigestUtilsSuite.java | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 common/network-common/src/test/java/org/apache/spark/network/util/DigestUtilsSuite.java diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/DigestUtilsSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/DigestUtilsSuite.java new file mode 100644 index 0000000000000..ed4b9ead72038 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/util/DigestUtilsSuite.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.util; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.zip.CRC32; + +import org.junit.Test; + +public class DigestUtilsSuite { + + @Test + public void testGetDigest() throws IOException { + File testFile = File.createTempFile("test-digest", ".txt"); + try { + String testStr = "org.apache.spark.network.util.DigestUtilsSuite.testGetDigest"; + FileOutputStream out = new FileOutputStream(testFile); + out.write(testStr.getBytes()); + out.close(); + + CRC32 crc32 = new CRC32(); + crc32.update(testStr.getBytes()); + long digest = crc32.getValue(); + + assert(digest == DigestUtils.getDigest(new FileInputStream(testFile))); + assert(digest == DigestUtils.getDigest(testFile, 0, testFile.length())); + } finally { + testFile.delete(); + } + } +} From 30354aed806cd521d00edf5e32fc425043a47923 Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 14 May 2020 16:19:49 +0800 Subject: [PATCH 06/16] save --- core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 2 +- .../spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 37e169c5eef02..d9392361780a4 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -419,7 +419,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC manager.unregisterShuffle(0) } - test("[SPARK-27562]: test shuffle with shuffle digest enabled is true") { + test("SPARK-27562: Test shuffle with checking digest of transmitted data") { conf.set(config.SHUFFLE_DIGEST_ENABLED, true) val sc = new SparkContext("local", "test", conf) val numRecords = 10000 diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index c5ba21834cc8e..1af1e02d62cc4 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -159,7 +159,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } } - test("[SPARK-27562]: check the digest when shuffle digest enabled is true") { + test("SPARK-27562: check the digest when shuffle digest enabled is true") { val confClone = conf.clone confClone.set(config.SHUFFLE_DIGEST_ENABLED, true) val resolver = new IndexShuffleBlockResolver(confClone, blockManager) From df226c43002ade3f4e8e6cac67da99eda3ee0cb0 Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 14 May 2020 16:24:12 +0800 Subject: [PATCH 07/16] style --- .../org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 27bc414b608dc..e8d9c520855f0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -1008,13 +1008,12 @@ object ShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block successfully. 
- * * @param blockId block id * @param mapIndex the mapIndex for this block, which indicate the index in the map stage. * @param address BlockManager that the block was fetched from. * @param size estimated size of the block. Note that this is NOT the exact bytes. * Size of remote block is used to calculate bytesInFlight. - * @param buf `ManagedBuffer` for the content. + * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. * @param digest Is the digest of the result, default is -1L. */ From bb15a4dab44fdba03cf8fcab06c591f16ca5c4e6 Mon Sep 17 00:00:00 2001 From: turbofei Date: Fri, 15 May 2020 10:14:14 +0800 Subject: [PATCH 08/16] close before recreating --- .../org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index e8d9c520855f0..4e28634931ee4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -662,6 +662,7 @@ final class ShuffleBlockFetcherIterator( if (in.markSupported()) { in.reset() } else { + in.close() in = buf.createInputStream() } } From b0fdea8273ea736ab35eb38aaf526cdfe349b9c1 Mon Sep 17 00:00:00 2001 From: turbofei Date: Fri, 15 May 2020 16:13:54 +0800 Subject: [PATCH 09/16] retest this please --- .../main/java/org/apache/spark/network/util/DigestUtils.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java index 99634a4725af6..3b72770d6e710 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java @@ -56,7 +56,7 @@ public static long getDigest(File file, long offset, long length) { } catch (IOException e) { LOG.error(String.format("Exception while computing digest for file segment: " + "%s(offset:%d, length:%d)", file.getName(), offset, length )); - return -1; + return -1L; } } From 7ed97ec0547ddf530d3f323062041b4c5d907779 Mon Sep 17 00:00:00 2001 From: turbofei Date: Fri, 15 May 2020 17:23:02 +0800 Subject: [PATCH 10/16] retest this please --- .../main/java/org/apache/spark/network/util/DigestUtils.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java index 3b72770d6e710..a2271e9df900f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java @@ -44,7 +44,8 @@ public static long getDigest(File file, long offset, long length) { return -1L; } try (RandomAccessFile rf = new RandomAccessFile(file, "r")) { - MappedByteBuffer data = rf.getChannel().map(FileChannel.MapMode.READ_ONLY, offset, length); + MappedByteBuffer data = rf.getChannel().map(FileChannel.MapMode.READ_ONLY, offset, + length); CRC32 crc32 = getCRC32(); byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; int len; From c893dc8ac28381847271d62cbff402b94f5221c5 Mon Sep 17 00:00:00 
2001 From: turbofei Date: Fri, 15 May 2020 17:56:33 +0800 Subject: [PATCH 11/16] retest this please --- .../spark/network/util/DigestUtils.java | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java index a2271e9df900f..3dcca960feeaf 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java @@ -36,7 +36,13 @@ public static int getDigestLength() { } public static long getDigest(InputStream data) throws IOException { - return updateCRC32(getCRC32(), data); + CRC32 crc32 = new CRC32(); + byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; + int len; + while ((len = data.read(buffer)) >= 0) { + crc32.update(buffer, 0, len); + } + return crc32.getValue(); } public static long getDigest(File file, long offset, long length) { @@ -46,7 +52,7 @@ public static long getDigest(File file, long offset, long length) { try (RandomAccessFile rf = new RandomAccessFile(file, "r")) { MappedByteBuffer data = rf.getChannel().map(FileChannel.MapMode.READ_ONLY, offset, length); - CRC32 crc32 = getCRC32(); + CRC32 crc32 = new CRC32(); byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; int len; while ((len = Math.min(STREAM_BUFFER_LENGTH, data.remaining())) > 0) { @@ -60,17 +66,4 @@ public static long getDigest(File file, long offset, long length) { return -1L; } } - - private static CRC32 getCRC32() { - return new CRC32(); - } - - private static long updateCRC32(CRC32 crc32, InputStream data) throws IOException { - byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; - int len; - while ((len = data.read(buffer)) >= 0) { - crc32.update(buffer, 0, len); - } - return crc32.getValue(); - } } From c59cf83d943c25857d0a30a03c00a81e71ef508b Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 21 May 2020 15:42:21 +0800 Subject: [PATCH 12/16] address comments --- .../spark/network/util/DigestUtils.java | 74 ++++++++++--------- .../storage/ShuffleBlockFetcherIterator.scala | 2 +- 2 files changed, 39 insertions(+), 37 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java index 3dcca960feeaf..7f19c95d64d74 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java @@ -17,8 +17,10 @@ package org.apache.spark.network.util; -import java.io.*; -import java.nio.ByteBuffer; +import java.io.File; +import java.io.InputStream; +import java.io.IOException; +import java.io.RandomAccessFile; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.util.zip.CRC32; @@ -27,43 +29,43 @@ import org.slf4j.LoggerFactory; public class DigestUtils { - private static final Logger LOG = LoggerFactory.getLogger(DigestUtils.class); - private static final int STREAM_BUFFER_LENGTH = 8192; - private static final int DIGEST_LENGTH = 8; + private static final Logger LOG = LoggerFactory.getLogger(DigestUtils.class); + private static final int STREAM_BUFFER_LENGTH = 8192; + private static final int DIGEST_LENGTH = 8; - public static int getDigestLength() { - return DIGEST_LENGTH; - } + public static int getDigestLength() { + return DIGEST_LENGTH; + } - public 
static long getDigest(InputStream data) throws IOException { - CRC32 crc32 = new CRC32(); - byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; - int len; - while ((len = data.read(buffer)) >= 0) { - crc32.update(buffer, 0, len); - } - return crc32.getValue(); + public static long getDigest(InputStream data) throws IOException { + CRC32 crc32 = new CRC32(); + byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; + int len; + while ((len = data.read(buffer)) >= 0) { + crc32.update(buffer, 0, len); } + return crc32.getValue(); + } - public static long getDigest(File file, long offset, long length) { - if (length <= 0) { - return -1L; - } - try (RandomAccessFile rf = new RandomAccessFile(file, "r")) { - MappedByteBuffer data = rf.getChannel().map(FileChannel.MapMode.READ_ONLY, offset, - length); - CRC32 crc32 = new CRC32(); - byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; - int len; - while ((len = Math.min(STREAM_BUFFER_LENGTH, data.remaining())) > 0) { - data.get(buffer, 0, len); - crc32.update(buffer, 0, len); - } - return crc32.getValue(); - } catch (IOException e) { - LOG.error(String.format("Exception while computing digest for file segment: " + - "%s(offset:%d, length:%d)", file.getName(), offset, length )); - return -1L; - } + public static long getDigest(File file, long offset, long length) { + if (length <= 0) { + return -1L; + } + try (RandomAccessFile rf = new RandomAccessFile(file, "r")) { + MappedByteBuffer data = rf.getChannel().map(FileChannel.MapMode.READ_ONLY, offset, + length); + CRC32 crc32 = new CRC32(); + byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; + int len; + while ((len = Math.min(STREAM_BUFFER_LENGTH, data.remaining())) > 0) { + data.get(buffer, 0, len); + crc32.update(buffer, 0, len); + } + return crc32.getValue(); + } catch (IOException e) { + LOG.error(String.format("Exception while computing digest for file segment: " + + "%s(offset:%d, length:%d)", file.getName(), offset, length )); + return -1L; } + } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 4e28634931ee4..a5ee93f503b0f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -642,7 +642,7 @@ final class ShuffleBlockFetcherIterator( throwFetchFailedException(blockId, mapIndex, address, e) } failedOnDigestCheck = digest != digestToCheck - if (!failedOnDigestCheck) { + if (failedOnDigestCheck) { buf.release() val e = new CheckDigestFailedException(s"The digest to check $digestToCheck " + s"of $blockId is not equal with origin $digest") From 96fd0f78c88a2bdd8ddda067b59eb67023abdf63 Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 21 May 2020 17:43:15 +0800 Subject: [PATCH 13/16] split index and digests --- .../shuffle/ShuffleIndexInformation.java | 52 ++++---- .../io/LocalDiskShuffleMapOutputWriter.java | 2 +- .../LocalDiskSingleSpillMapOutputWriter.java | 2 +- .../shuffle/IndexShuffleBlockResolver.scala | 114 ++++++++++-------- .../org/apache/spark/storage/BlockId.scala | 8 ++ .../sort/UnsafeShuffleWriterSuite.java | 4 +- .../apache/spark/ContextCleanerSuite.scala | 1 + .../scala/org/apache/spark/ShuffleSuite.scala | 2 +- .../BypassMergeSortShuffleWriterSuite.scala | 2 +- .../sort/IndexShuffleBlockResolverSuite.scala | 8 +- ...LocalDiskShuffleMapOutputWriterSuite.scala | 2 +- 11 files changed, 108 insertions(+), 89 deletions(-) diff --git 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index 8f2852ea02785..fbd06512d0bce 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -25,8 +25,6 @@ import java.nio.file.Files; import java.util.Optional; -import org.apache.spark.network.util.DigestUtils; - /** * Keeps the index information for a particular map output * as an in-memory LongBuffer. @@ -40,28 +38,28 @@ public class ShuffleIndexInformation { private int size; public ShuffleIndexInformation(File indexFile) throws IOException { - ByteBuffer offsetsBuffer, digestsBuffer; size = (int)indexFile.length(); - int offsetsSize, digestsSize; - if (size % 8 == 0) { - hasDigest = false; - offsetsSize = size; - digestsSize = 0; - } else { - hasDigest = true; - offsetsSize = ((size - 8 - 1) / (8 + DigestUtils.getDigestLength()) + 1) * 8; - digestsSize = size - offsetsSize -1; - } - offsetsBuffer = ByteBuffer.allocate(offsetsSize); - digestsBuffer = ByteBuffer.allocate(digestsSize); + ByteBuffer offsetsBuffer = ByteBuffer.allocate(size); offsets = offsetsBuffer.asLongBuffer(); - digests = digestsBuffer.asLongBuffer(); try (DataInputStream dis = new DataInputStream(Files.newInputStream(indexFile.toPath()))) { dis.readFully(offsetsBuffer.array()); - if (hasDigest) { - dis.readByte(); + } + /** + * This logic is from IndexShuffleBlockResolver, and the block id format is from + * ShuffleIndexDigestBlockId. + */ + File digestFile = new File(indexFile.getAbsolutePath() + ".digest"); + if (digestFile.exists()) { + hasDigest = true; + size += digestFile.length(); + ByteBuffer digestsBuffer = ByteBuffer.allocate((int)digestFile.length()); + digests = digestsBuffer.asLongBuffer(); + try (DataInputStream digIs = new DataInputStream(Files.newInputStream(digestFile.toPath()))) { + digIs.readFully(digestsBuffer.array()); } - dis.readFully(digestsBuffer.array()); + } else { + hasDigest = false; + digests = null; } } @@ -93,15 +91,13 @@ public ShuffleIndexRecord getIndex(int reduceId) { public ShuffleIndexRecord getIndex(int startReduceId, int endReduceId) { long offset = offsets.get(startReduceId); long nextOffset = offsets.get(endReduceId); - /** Default digest is -1L.*/ - Optional digest = Optional.of(-1L); - if (hasDigest) { - if (endReduceId - startReduceId == 1) { - digest = Optional.of(digests.get(startReduceId)); - } else { - digest = Optional.empty(); - } + if (hasDigest && endReduceId - startReduceId == 1) { + return new ShuffleIndexRecord(offset, nextOffset - offset, + Optional.of(digests.get(startReduceId))); + } else if (!hasDigest) { + return new ShuffleIndexRecord(offset, nextOffset - offset, Optional.of(-1L)); + } else { + return new ShuffleIndexRecord(offset, nextOffset - offset, Optional.empty()); } - return new ShuffleIndexRecord(offset, nextOffset - offset, digest); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index a6529fd76188a..0c760770d1cd3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -112,7 +112,7 @@ public long[] 
commitAllPartitions() throws IOException { } cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null; - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); + blockResolver.writeIndexDigestFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); return partitionLengths; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java index c8b41992a8919..09b9fabc5a5e5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java @@ -50,6 +50,6 @@ public void transferMapSpillFile( File outputFile = blockResolver.getDataFile(shuffleId, mapId); File tempFile = Utils.tempFileWith(outputFile); Files.move(mapSpillFile.toPath(), tempFile.toPath()); - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); + blockResolver.writeIndexDigestFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 84e2f0bef8672..e3cf79081c5da 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -88,6 +88,22 @@ private[spark] class IndexShuffleBlockResolver( .getOrElse(blockManager.diskBlockManager.getFile(blockId)) } + /** + * Get the shuffle digest file. + * + * When the dirs parameter is None then use the disk manager's local directories. Otherwise, + * read from the specified directories. + */ + private def getDigestFile( + shuffleId: Int, + mapId: Long, + dirs: Option[Array[String]] = None): File = { + val blockId = ShuffleIndexDigestBlockId(shuffleId, mapId, NOOP_REDUCE_ID) + dirs + .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) + .getOrElse(blockManager.diskBlockManager.getFile(blockId)) + } + /** * Remove data file and index file that contain the output data from one map. */ @@ -105,22 +121,25 @@ private[spark] class IndexShuffleBlockResolver( logWarning(s"Error deleting index ${file.getPath()}") } } + + file = getDigestFile(shuffleId, mapId) + if (file.exists()) { + if (!file.delete()) { + logWarning(s"Error deleting digest ${file.getPath()}") + } + } } /** * Check whether the given index and data files match each other. * If so, return the partition lengths in the data file. Otherwise return null. */ - private def checkIndexAndDataFile(index: File, data: File, blocks: Int, digests: Array[Long]): - (Array[Long], Array[Long]) = { - // Id digestEnable is false, the index file should have `blocks + 1` longs as offset. - // Otherwise, it should have a byte as flag, `blocks + 1` longs as offset and `blocks` digests - if ((!digestEnable && index.length() != (blocks + 1) * 8L) || - (digestEnable && index.length() != blocks * (8L + digestLength) + 8L + 1L)) { + private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { + // the index file should have `block + 1` longs as offset. 
+ if (index.length() != (blocks + 1) * 8L) { return null } val lengths = new Array[Long](blocks) - val digestArr = new Array[Long](blocks) // Read the lengths of blocks val in = try { new DataInputStream(new NioBufferedFileInputStream(index)) @@ -141,18 +160,6 @@ private[spark] class IndexShuffleBlockResolver( offset = off i += 1 } - if (digestEnable) { - val flag = in.readByte() - // the flag for digestEnable should be 1 - if (flag != 1) { - return null - } - i = 0 - while (i < blocks) { - digestArr(i) = in.readLong() - i += 1 - } - } } catch { case e: IOException => return null @@ -161,8 +168,8 @@ private[spark] class IndexShuffleBlockResolver( } // the size of data file should match with index file - if (data.length() == lengths.sum && !(0 until blocks).exists(i => digests(i) != digestArr(i))) { - (lengths, digestArr) + if (data.length() == lengths.sum) { + lengths } else { null } @@ -178,13 +185,15 @@ private[spark] class IndexShuffleBlockResolver( * * Note: the `lengths` will be updated to match the existing index file if use the existing ones. */ - def writeIndexFileAndCommit( + def writeIndexDigestFileAndCommit( shuffleId: Int, mapId: Long, lengths: Array[Long], dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) + var digestFile: File = null + var digestTmp: File = null try { val dataFile = getDataFile(shuffleId, mapId) // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure @@ -213,22 +222,18 @@ private[spark] class IndexShuffleBlockResolver( } } - val existingLengthsDigests = - checkIndexAndDataFile(indexFile, dataFile, lengths.length, digests) - if (existingLengthsDigests != null) { - val existingLengths = existingLengthsDigests._1 - val existingDigests = existingLengthsDigests._2 + val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) + if (existingLengths != null) { // Another attempt for the same task has already written our map outputs successfully, // so just use the existing partition lengths and delete our temporary map outputs. System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) - System.arraycopy(existingDigests, 0, digests, 0, digests.length) if (dataTmp != null && dataTmp.exists()) { dataTmp.delete() } } else { // This is the first successful attempt in writing the map outputs for this task, // so override any existing index and data files with the ones we wrote. - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + var out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. 
var offset = 0L @@ -237,13 +242,6 @@ private[spark] class IndexShuffleBlockResolver( offset += length out.writeLong(offset) } - if (digestEnable) { - // we write a byte present digest enable - out.writeByte(1) - for (digest <- digests) { - out.writeLong(digest) - } - } } { out.close() } @@ -260,12 +258,34 @@ private[spark] class IndexShuffleBlockResolver( if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) } + + if (digestEnable) { + digestFile = getDigestFile(shuffleId, mapId) + digestTmp = Utils.tempFileWith(digestFile) + out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(digestTmp))) + Utils.tryWithSafeFinally { + digests.foreach { digest => + out.writeLong(digest) + } + } { + out.close() + } + if (digestFile.exists()) { + digestFile.delete() + } + if (!digestTmp.renameTo(digestFile)) { + throw new IOException("fail to rename file " + digestTmp + " to " + digestFile) + } + } } } } finally { if (indexTmp.exists() && !indexTmp.delete()) { logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}") } + if (digestEnable && digestTmp != null && digestTmp.exists() && !digestTmp.delete()) { + logError(s"Failed to delete temporary digest file at ${digestTmp.getAbsolutePath}") + } } } @@ -291,10 +311,6 @@ private[spark] class IndexShuffleBlockResolver( // class of issue from re-occurring in the future which is why they are left here even though // SPARK-22982 is fixed. val channel = Files.newByteChannel(indexFile.toPath) - var blocks = (indexFile.length() - 8) / 8 - if (digestEnable) { - blocks = (indexFile.length() - 8 - 1) / (8 + digestLength) - } channel.position(startReduceId * 8L) val in = new DataInputStream(Channels.newInputStream(channel)) try { @@ -310,21 +326,19 @@ private[spark] class IndexShuffleBlockResolver( if (digestEnable) { val digestValue = if (endReduceId - startReduceId == 1) { - channel.position(1 + (blocks + 1) * 8L + startReduceId * digestLength) - val digest = in.readLong() - val actualDigestPosition = channel.position() - val expectedDigestLength = 1 + (blocks + 1) * 8L + (startReduceId + 1) * digestLength - if (actualDigestPosition != expectedDigestLength) { - throw new Exception(s"SPARK-22982: Incorrect channel position after index file " + - s"reads: expected $expectedDigestLength but actual position was " + - s" $actualDigestPosition.") + val digestFile = getDigestFile(shuffleId, mapId, dirs) + val digestChannel = Files.newByteChannel(digestFile.toPath) + digestChannel.position(startReduceId * 8L) + val digestIn = new DataInputStream(Channels.newInputStream(digestChannel)) + try { + digestIn.readLong() + } finally { + digestIn.close() } - digest } else { DigestUtils.getDigest(getDataFile(shuffleId, mapId, dirs), startOffset, endOffset - startOffset) } - new DigestFileSegmentManagedBuffer( transportConf, getDataFile(shuffleId, mapId, dirs), diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 68ed3aa5b062f..453dc5373ad49 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -78,6 +78,11 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } +@DeveloperApi +case class ShuffleIndexDigestBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
override def name: String = ShuffleIndexBlockId(shuffleId, mapId, reduceId).name + ".digest" +} + @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) @@ -119,6 +124,7 @@ object BlockId { val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r + val SHUFFLE_DIGEST = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).digest".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r @@ -137,6 +143,8 @@ object BlockId { ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) + case SHUFFLE_DIGEST(shuffleId, mapId, reduceId) => + ShuffleIndexDigestBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index ee8e38c24b47f..cf64ce82e7417 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -152,11 +152,11 @@ public void setUp() throws IOException { doAnswer(renameTempAnswer) .when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class)); + .writeIndexDigestFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class)); doAnswer(renameTempAnswer) .when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null)); + .writeIndexDigestFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 92ed24408384f..24e6b65783731 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -535,6 +535,7 @@ class CleanerTester( blockManager.master.getMatchingBlockIds( _ match { case ShuffleBlockId(`shuffleId`, _, _) => true case ShuffleIndexBlockId(`shuffleId`, _, _) => true + case ShuffleIndexDigestBlockId(`shuffleId`, _, _) => true case _ => false }, askSlaves = true) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index d9392361780a4..31fce86a672ff 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter -import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId} +import 
org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexDigestBlockId, ShuffleIndexBlockId} import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index f8474022867f4..98c3e59c8e92f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -76,7 +76,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) - when(blockResolver.writeIndexFileAndCommit( + when(blockResolver.writeIndexDigestFileAndCommit( anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))) .thenAnswer { invocationOnMock => val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 1af1e02d62cc4..07c867a920e8c 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -74,7 +74,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out.close() } - resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) + resolver.writeIndexDigestFileAndCommit(shuffleId, mapId, lengths, dataTmp) val indexFile = new File(tempDir.getAbsolutePath, idxName) val dataFile = resolver.getDataFile(shuffleId, mapId) @@ -94,7 +94,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out2.close() } - resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2) + resolver.writeIndexDigestFileAndCommit(shuffleId, mapId, lengths2, dataTmp2) assert(indexFile.length() === (lengths.length + 1) * 8) assert(lengths2.toSeq === lengths.toSeq) @@ -133,7 +133,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out3.close() } - resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3) + resolver.writeIndexDigestFileAndCommit(shuffleId, mapId, lengths3, dataTmp3) assert(indexFile.length() === (lengths3.length + 1) * 8) assert(lengths3.toSeq != lengths.toSeq) assert(dataFile.exists()) @@ -172,7 +172,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa out.close() } val digest = DigestUtils.getDigest(new ByteArrayInputStream(new Array[Byte](10))) - resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) + resolver.writeIndexDigestFileAndCommit(1, 2, lengths, dataTmp) val managedBuffer = resolver.getBlockData(ShuffleBlockId(1, 2, 0)) assert(managedBuffer.isInstanceOf[DigestFileSegmentManagedBuffer]) assert(managedBuffer.asInstanceOf[DigestFileSegmentManagedBuffer].getDigest == digest) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala index f92455912f510..d55c8abbf4311 100644 --- 
a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -74,7 +74,7 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA .set("spark.app.id", "example.spark.app") .set("spark.shuffle.unsafe.file.output.buffer", "16k") when(blockResolver.getDataFile(anyInt, anyLong)).thenReturn(mergedOutputFile) - when(blockResolver.writeIndexFileAndCommit( + when(blockResolver.writeIndexDigestFileAndCommit( anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))) .thenAnswer { invocationOnMock => partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] From 09c498cbc3dcae8956f7150b23e206a7443000c5 Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 21 May 2020 17:44:01 +0800 Subject: [PATCH 14/16] save --- .../network/shuffle/ExternalShuffleBlockResolver.java | 6 +++++- .../spark/network/shuffle/ShuffleIndexInformation.java | 7 +++---- .../apache/spark/shuffle/IndexShuffleBlockResolver.scala | 1 + 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 2740373548d5a..5deb7b47e8a42 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -39,7 +39,6 @@ import com.google.common.cache.LoadingCache; import com.google.common.cache.Weigher; import com.google.common.collect.Maps; -import org.apache.spark.network.util.*; import org.iq80.leveldb.DB; import org.iq80.leveldb.DBIterator; import org.slf4j.Logger; @@ -49,7 +48,12 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.DigestUtils; +import org.apache.spark.network.util.LevelDBProvider; import org.apache.spark.network.util.LevelDBProvider.StoreVersion; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; /** * Manages converting shuffle BlockIds into physical segments of local files, from a process outside diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index fbd06512d0bce..1badf5d41012a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -44,10 +44,9 @@ public ShuffleIndexInformation(File indexFile) throws IOException { try (DataInputStream dis = new DataInputStream(Files.newInputStream(indexFile.toPath()))) { dis.readFully(offsetsBuffer.array()); } - /** - * This logic is from IndexShuffleBlockResolver, and the block id format is from - * ShuffleIndexDigestBlockId. - */ + + // This logic is from IndexShuffleBlockResolver, and the digest file name is from + // ShuffleIndexDigestBlockId. 
File digestFile = new File(indexFile.getAbsolutePath() + ".digest"); if (digestFile.exists()) { hasDigest = true; diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index e3cf79081c5da..df9bc689cb530 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -339,6 +339,7 @@ private[spark] class IndexShuffleBlockResolver( DigestUtils.getDigest(getDataFile(shuffleId, mapId, dirs), startOffset, endOffset - startOffset) } + new DigestFileSegmentManagedBuffer( transportConf, getDataFile(shuffleId, mapId, dirs), From 2ccbbd4e585d960c7ce3f3055d4f17a16ce15af4 Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 21 May 2020 19:36:21 +0800 Subject: [PATCH 15/16] fix style --- core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 31fce86a672ff..d9392361780a4 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter -import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexDigestBlockId, ShuffleIndexBlockId} +import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId} import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { From c59fcd627f5132bf8ebe493a8f1282139b31f1cf Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 21 May 2020 19:49:39 +0800 Subject: [PATCH 16/16] refactor --- .../shuffle/IndexShuffleBlockResolver.scala | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index df9bc689cb530..c805a986bafe9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -246,19 +246,6 @@ private[spark] class IndexShuffleBlockResolver( out.close() } - if (indexFile.exists()) { - indexFile.delete() - } - if (dataFile.exists()) { - dataFile.delete() - } - if (!indexTmp.renameTo(indexFile)) { - throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) - } - if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { - throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) - } - if (digestEnable) { digestFile = getDigestFile(shuffleId, mapId) digestTmp = Utils.tempFileWith(digestFile) @@ -270,12 +257,25 @@ private[spark] class IndexShuffleBlockResolver( } { out.close() } - if (digestFile.exists()) { - digestFile.delete() - } - if (!digestTmp.renameTo(digestFile)) { - throw new IOException("fail to rename file " + digestTmp + " to " + digestFile) - } + } + + if (indexFile.exists()) { + indexFile.delete() + } + if (digestEnable && digestFile.exists()) { + 
digestFile.delete() + } + if (dataFile.exists()) { + dataFile.delete() + } + if (!indexTmp.renameTo(indexFile)) { + throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) + } + if (digestEnable && !digestTmp.renameTo(digestFile)) { + throw new IOException("fail to rename file " + digestTmp + " to " + digestFile) + } + if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { + throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) } } }
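
The whole series hinges on one invariant: the CRC32 computed over a map-output file segment when the shuffle files are written must equal the CRC32 recomputed over the bytes the reducer actually receives, and a mismatch is surfaced as a digest-check failure rather than silent corruption. The standalone sketch below only illustrates that round trip; it is not part of the patch, the class DigestRoundTrip and its method names are made up for illustration, and the only API it assumes is java.util.zip.CRC32.

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.util.zip.CRC32;

// Illustrative sketch, not the patch's DigestUtils API: shows the sender-side digest of a
// file segment and the receiver-side digest of the bytes that actually arrived.
public class DigestRoundTrip {
  private static final int BUF_SIZE = 8192;

  // Sender side: CRC32 of the [offset, offset + length) slice of a shuffle data file.
  static long segmentDigest(RandomAccessFile file, long offset, long length) throws IOException {
    CRC32 crc = new CRC32();
    byte[] buf = new byte[BUF_SIZE];
    file.seek(offset);
    long remaining = length;
    while (remaining > 0) {
      int read = file.read(buf, 0, (int) Math.min(buf.length, remaining));
      if (read < 0) {
        break; // truncated file: the digest simply fails to match on the reducer side
      }
      crc.update(buf, 0, read);
      remaining -= read;
    }
    return crc.getValue();
  }

  // Receiver side: CRC32 of whatever bytes were actually read from the wire.
  static long streamDigest(InputStream in) throws IOException {
    CRC32 crc = new CRC32();
    byte[] buf = new byte[BUF_SIZE];
    int read;
    while ((read = in.read(buf)) >= 0) {
      crc.update(buf, 0, read);
    }
    return crc.getValue();
  }

  public static void main(String[] args) throws IOException {
    byte[] payload = "shuffle block bytes".getBytes();
    CRC32 crc = new CRC32();
    crc.update(payload);
    long announced = crc.getValue(); // what the server would send alongside the chunk
    long recomputed = streamDigest(new ByteArrayInputStream(payload));
    System.out.println(announced == recomputed ? "digest ok" : "digest mismatch");
  }
}

CRC32 is presumably preferred here over a cryptographic hash because the goal is detecting corruption in transit or on disk rather than resisting tampering, and its value fits the single 8-byte slot per partition that the separate .digest file introduced by the later patches stores.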