diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java new file mode 100644 index 0000000000000..d58e3ce2b2c88 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.File; + +import com.google.common.base.Objects; + +import org.apache.spark.network.util.TransportConf; + +/** + * A {@link ManagedBuffer} backed by a segment in a file with digest. + */ +public final class DigestFileSegmentManagedBuffer extends FileSegmentManagedBuffer { + + private final long digest; + + public DigestFileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length, + long digest) { + super(conf, file, offset, length); + this.digest = digest; + } + + public long getDigest() { return digest; } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", getFile()) + .add("offset", getOffset()) + .add("length", getLength()) + .add("digest", digest) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 66566b67870f3..ed64867ad6c12 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -38,7 +38,7 @@ /** * A {@link ManagedBuffer} backed by a segment in a file. 
*/ -public final class FileSegmentManagedBuffer extends ManagedBuffer { +public class FileSegmentManagedBuffer extends ManagedBuffer { private final TransportConf conf; private final File file; private final long offset; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java index 519e6cb470d0d..db1682c086f58 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java @@ -36,6 +36,11 @@ public interface ChunkReceivedCallback { */ void onSuccess(int chunkIndex, ManagedBuffer buffer); + /** Called with an extra digest parameter upon receipt of a particular chunk. */ + default void onSuccess(int chunkIndex, ManagedBuffer buffer, long digest) { + onSuccess(chunkIndex, buffer); + } + /** * Called upon failure to fetch a particular chunk. Note that this may actually be called due * to failure to fetch a prior chunk in this stream. diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java index d322aec28793e..2b9b6cce25622 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -35,6 +35,11 @@ public interface StreamCallback { /** Called when all data from the stream has been received. */ void onComplete(String streamId) throws IOException; + /** Called with an extra digest when all data from the stream has been received. */ + default void onComplete(String streamId, long digest) throws IOException { + onComplete(streamId); + } + /** Called if there's an error reading data from the stream. 
*/ void onFailure(String streamId, Throwable cause) throws IOException; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index f3eb744ff7345..f9d4ca1addcdc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -37,6 +37,7 @@ public class StreamInterceptor implements TransportFrameDecod private final long byteCount; private final StreamCallback callback; private long bytesRead; + private long digest = -1L; public StreamInterceptor( MessageHandler handler, @@ -50,6 +51,16 @@ public StreamInterceptor( this.bytesRead = 0; } + public StreamInterceptor( + MessageHandler handler, + String streamId, + long byteCount, + StreamCallback callback, + long digest) { + this(handler, streamId, byteCount, callback); + this.digest = digest; + } + @Override public void exceptionCaught(Throwable cause) throws Exception { deactivateStream(); @@ -86,7 +97,11 @@ public boolean handle(ByteBuf buf) throws Exception { throw re; } else if (bytesRead == byteCount) { deactivateStream(); - callback.onComplete(streamId); + if (digest < 0) { + callback.onComplete(streamId); + } else { + callback.onComplete(streamId, digest); + } } return bytesRead != byteCount; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 2f143f77fa4ae..c7e648a71c317 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -33,6 +33,8 @@ import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.DigestChunkFetchSuccess; +import org.apache.spark.network.protocol.DigestStreamResponse; import org.apache.spark.network.protocol.ResponseMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; @@ -246,6 +248,45 @@ public void handle(ResponseMessage message) throws Exception { } else { logger.warn("Stream failure with unknown callback: {}", resp.error); } + } else if (message instanceof DigestChunkFetchSuccess) { + DigestChunkFetchSuccess resp = (DigestChunkFetchSuccess) message; + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + if (listener == null) { + logger.warn("Ignoring response for block {} from {} since it is not outstanding", + resp.streamChunkId, getRemoteAddress(channel)); + resp.body().release(); + } else { + outstandingFetches.remove(resp.streamChunkId); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body(), resp.digest); + resp.body().release(); + } + } else if (message instanceof DigestStreamResponse) { + DigestStreamResponse resp = (DigestStreamResponse) message; + Pair entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry.getValue(); + if (resp.byteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor( + this, resp.streamId, resp.byteCount, callback, resp.digest); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + 
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + callback.onComplete(resp.streamId, resp.digest); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } + } + } else { + logger.error("Could not find callback for StreamResponse."); + } } else { throw new IllegalStateException("Unknown response type: " + message.type()); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java new file mode 100644 index 0000000000000..b7e326d9457cf --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Response to {@link ChunkFetchRequest} when a chunk exists with a digest and has been + * successfully fetched. + * + * Note that the server-side encoding of this messages does NOT include the buffer itself, as this + * may be written by Netty in a more efficient manner (i.e., zero-copy write). + * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. + */ +public final class DigestChunkFetchSuccess extends AbstractResponseMessage { + public final StreamChunkId streamChunkId; + public final long digest; + + public DigestChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer, long digest) { + super(buffer, true); + this.streamChunkId = streamChunkId; + this.digest = digest; + } + + @Override + public Message.Type type() { return Type.DigestChunkFetchSuccess; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength() + 8; + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + buf.writeLong(digest); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new ChunkFetchFailure(streamChunkId, error); + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. 
*/ + public static DigestChunkFetchSuccess decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + long digest = buf.readLong(); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new DigestChunkFetchSuccess(streamChunkId, managedBuf, digest); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamChunkId, body(), digest); + } + + @Override + public boolean equals(Object other) { + if (other instanceof DigestChunkFetchSuccess) { + DigestChunkFetchSuccess o = (DigestChunkFetchSuccess) other; + return streamChunkId.equals(o.streamChunkId) && super.equals(o) && digest == o.digest; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("digest", digest) + .add("buffer", body()) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java new file mode 100644 index 0000000000000..a184cba6654a5 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Response to {@link StreamRequest} with digest when the stream has been successfully opened. + *

+ * Note the message itself does not contain the stream data. That is written separately by the + * sender. The receiver is expected to set a temporary channel handler that will consume the + * number of bytes this message says the stream has. + */ +public final class DigestStreamResponse extends AbstractResponseMessage { + public final String streamId; + public final long byteCount; + public final long digest; + + public DigestStreamResponse(String streamId, long byteCount, ManagedBuffer buffer, long digest) { + super(buffer, false); + this.streamId = streamId; + this.byteCount = byteCount; + this.digest = digest; + } + + @Override + public Message.Type type() { return Type.DigestStreamResponse; } + + @Override + public int encodedLength() { + return 8 + Encoders.Strings.encodedLength(streamId) + 8; + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, streamId); + buf.writeLong(byteCount); + buf.writeLong(digest); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new StreamFailure(streamId, error); + } + + public static DigestStreamResponse decode(ByteBuf buf) { + String streamId = Encoders.Strings.decode(buf); + long byteCount = buf.readLong(); + long digest = buf.readLong(); + return new DigestStreamResponse(streamId, byteCount, null, digest); + } + + @Override + public int hashCode() { + return Objects.hashCode(byteCount, streamId, body(), digest); + } + + @Override + public boolean equals(Object other) { + if (other instanceof DigestStreamResponse) { + DigestStreamResponse o = (DigestStreamResponse) other; + return byteCount == o.byteCount && streamId.equals(o.streamId) && digest == o.digest; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("byteCount", byteCount) + .add("digest", digest) + .add("body", body()) + .toString(); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 0ccd70c03aba8..cd3efdc59d380 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,8 @@ enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), UploadStream(10), User(-1); + OneWayMessage(9), UploadStream(10), DigestChunkFetchSuccess(11), + DigestStreamResponse(12), User(-1); private final byte id; @@ -66,6 +67,8 @@ public static Type decode(ByteBuf buf) { case 8: return StreamFailure; case 9: return OneWayMessage; case 10: return UploadStream; + case 11: return DigestChunkFetchSuccess; + case 12: return DigestStreamResponse; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index bf80aed0afe10..0d98f0161015c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ 
b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -50,6 +50,12 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { private Message decode(Message.Type msgType, ByteBuf in) { switch (msgType) { + case DigestChunkFetchSuccess: + return DigestChunkFetchSuccess.decode(in); + + case DigestStreamResponse: + return DigestStreamResponse.decode(in); + case ChunkFetchRequest: return ChunkFetchRequest.decode(in); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java index 82810dacdad84..f648a0fe6d0b5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -25,15 +25,13 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer; +import org.apache.spark.network.protocol.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Encodable; import static org.apache.spark.network.util.NettyUtils.*; @@ -111,8 +109,15 @@ public void processFetchRequest( } streamManager.chunkBeingSent(msg.streamChunkId.streamId); - respond(channel, new ChunkFetchSuccess(msg.streamChunkId, buf)).addListener( - (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); + if (buf instanceof DigestFileSegmentManagedBuffer) { + respond(channel, new DigestChunkFetchSuccess(msg.streamChunkId, buf, + ((DigestFileSegmentManagedBuffer)buf).getDigest())).addListener( + (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); + } else { + respond(channel, new ChunkFetchSuccess(msg.streamChunkId, buf)).addListener( + (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); + } + } /** diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index f178928006902..db7a43875cdcf 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -28,6 +28,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.*; @@ -143,9 +144,16 @@ private void processStreamRequest(final StreamRequest req) { if (buf != null) { streamManager.streamBeingSent(req.streamId); - respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> { - streamManager.streamSent(req.streamId); - }); + if (buf instanceof DigestFileSegmentManagedBuffer) { + respond(new 
DigestStreamResponse(req.streamId, buf.size(), buf, + ((DigestFileSegmentManagedBuffer) buf).getDigest())).addListener(future -> { + streamManager.streamSent(req.streamId); + }); + } else { + respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> { + streamManager.streamSent(req.streamId); + }); + } } else { // org.apache.spark.repl.ExecutorClassLoader.STREAM_NOT_FOUND_REGEX should also be updated // when the following error message is changed. diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java new file mode 100644 index 0000000000000..7f19c95d64d74 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.File; +import java.io.InputStream; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.zip.CRC32; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class DigestUtils { + private static final Logger LOG = LoggerFactory.getLogger(DigestUtils.class); + private static final int STREAM_BUFFER_LENGTH = 8192; + private static final int DIGEST_LENGTH = 8; + + public static int getDigestLength() { + return DIGEST_LENGTH; + } + + public static long getDigest(InputStream data) throws IOException { + CRC32 crc32 = new CRC32(); + byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; + int len; + while ((len = data.read(buffer)) >= 0) { + crc32.update(buffer, 0, len); + } + return crc32.getValue(); + } + + public static long getDigest(File file, long offset, long length) { + if (length <= 0) { + return -1L; + } + try (RandomAccessFile rf = new RandomAccessFile(file, "r")) { + MappedByteBuffer data = rf.getChannel().map(FileChannel.MapMode.READ_ONLY, offset, + length); + CRC32 crc32 = new CRC32(); + byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; + int len; + while ((len = Math.min(STREAM_BUFFER_LENGTH, data.remaining())) > 0) { + data.get(buffer, 0, len); + crc32.update(buffer, 0, len); + } + return crc32.getValue(); + } catch (IOException e) { + LOG.error(String.format("Exception while computing digest for file segment: " + + "%s(offset:%d, length:%d)", file.getName(), offset, length )); + return -1L; + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/DigestUtilsSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/DigestUtilsSuite.java new file mode 100644 index 0000000000000..ed4b9ead72038 --- /dev/null +++ 
b/common/network-common/src/test/java/org/apache/spark/network/util/DigestUtilsSuite.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.zip.CRC32; + +import org.junit.Test; + +public class DigestUtilsSuite { + + @Test + public void testGetDigest() throws IOException { + File testFile = File.createTempFile("test-digest", ".txt"); + try { + String testStr = "org.apache.spark.network.util.DigestUtilsSuite.testGetDigest"; + FileOutputStream out = new FileOutputStream(testFile); + out.write(testStr.getBytes()); + out.close(); + + CRC32 crc32 = new CRC32(); + crc32.update(testStr.getBytes()); + long digest = crc32.getValue(); + + assert(digest == DigestUtils.getDigest(new FileInputStream(testFile))); + assert(digest == DigestUtils.getDigest(testFile, 0, testFile.length())); + } finally { + testFile.delete(); + } + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java index 138fd5389c20a..b5f76aab11d3f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java @@ -29,6 +29,15 @@ public interface BlockFetchingListener extends EventListener { */ void onBlockFetchSuccess(String blockId, ManagedBuffer data); + /** + * Called once per successfully fetched block during shuffle, with an extra parameter that + * carries the digest (checksum) of the shuffle block. A default implementation is provided + * so that listeners which do not need the digest only have to implement the two-argument + * overload. + */ + default void onBlockFetchSuccess(String blockId, ManagedBuffer data, long digest) { + onBlockFetchSuccess(blockId, data); + } + /** * Called at least once per block upon failures. 
*/ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index ba1a17bf7e5ea..5deb7b47e8a42 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -44,9 +44,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.DigestUtils; import org.apache.spark.network.util.LevelDBProvider; import org.apache.spark.network.util.LevelDBProvider.StoreVersion; import org.apache.spark.network.util.JavaUtils; @@ -320,12 +322,25 @@ private ManagedBuffer getSortBasedShuffleBlockData( ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile); ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex( startReduceId, endReduceId); - return new FileSegmentManagedBuffer( - conf, - ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, + if (shuffleIndexInformation.isHasDigest()) { + File dataFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.data"); + return new DigestFileSegmentManagedBuffer( + conf, + dataFile, + shuffleIndexRecord.getOffset(), + shuffleIndexRecord.getLength(), + shuffleIndexRecord.getDigest().orElse(DigestUtils.getDigest( + dataFile, shuffleIndexRecord.getOffset(), shuffleIndexRecord.getLength()))); + + } else { + return new FileSegmentManagedBuffer( + conf, + ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, "shuffle_" + shuffleId + "_" + mapId + "_0.data"), - shuffleIndexRecord.getOffset(), - shuffleIndexRecord.getLength()); + shuffleIndexRecord.getOffset(), + shuffleIndexRecord.getLength()); + } } catch (ExecutionException e) { throw new RuntimeException("Failed to open file: " + indexFile, e); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index ec2e3dce661d9..d7e1e096c423f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -165,6 +165,12 @@ public void onSuccess(int chunkIndex, ManagedBuffer buffer) { listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); } + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer, long digest) { + // On receipt of a chunk, pass it upwards as a block. + listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer, digest); + } + @Override public void onFailure(int chunkIndex, Throwable e) { // On receipt of a failure, fail every block from chunkIndex onwards. 
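Usage note for the digest-aware callbacks above: the following is a minimal, illustrative sketch (not part of this patch) of a wrapping BlockFetchingListener that recomputes the CRC32 of a fetched block and compares it with the server-provided digest before forwarding the block. It relies on the three-argument onBlockFetchSuccess overload and the DigestUtils.getDigest(InputStream) helper added elsewhere in this diff; the wrapper class name is hypothetical, and re-reading the buffer assumes it can produce a fresh stream (e.g. a file-backed segment).

  import java.io.IOException;
  import java.io.InputStream;

  import org.apache.spark.network.buffer.ManagedBuffer;
  import org.apache.spark.network.shuffle.BlockFetchingListener;
  import org.apache.spark.network.util.DigestUtils;

  // Illustrative wrapper: verifies the server-provided digest before handing the block on.
  class DigestVerifyingListener implements BlockFetchingListener {
    private final BlockFetchingListener delegate;

    DigestVerifyingListener(BlockFetchingListener delegate) {
      this.delegate = delegate;
    }

    @Override
    public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
      // No digest was shipped with this block; forward it unchanged.
      delegate.onBlockFetchSuccess(blockId, data);
    }

    @Override
    public void onBlockFetchSuccess(String blockId, ManagedBuffer data, long digest) {
      // Assumes the buffer can produce a fresh stream for re-reading (e.g. a file segment).
      try (InputStream in = data.createInputStream()) {
        long recomputed = DigestUtils.getDigest(in);
        if (recomputed != digest) {
          delegate.onBlockFetchFailure(blockId, new IOException(
              "Digest mismatch for " + blockId + ": expected " + digest + ", got " + recomputed));
          return;
        }
      } catch (IOException e) {
        delegate.onBlockFetchFailure(blockId, e);
        return;
      }
      delegate.onBlockFetchSuccess(blockId, data, digest);
    }

    @Override
    public void onBlockFetchFailure(String blockId, Throwable exception) {
      delegate.onBlockFetchFailure(blockId, exception);
    }
  }
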
@@ -248,6 +254,14 @@ public void onComplete(String streamId) throws IOException { } } + @Override + public void onComplete(String streamId, long digest) throws IOException { + listener.onBlockFetchSuccess(blockIds[chunkIndex], channel.closeAndRead(), digest); + if (!downloadFileManager.registerTempFileToClean(targetFile)) { + targetFile.delete(); + } + } + @Override public void onFailure(String streamId, Throwable cause) throws IOException { channel.close(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java index 6bf3da94030d4..dba15e0076dd5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -189,6 +189,11 @@ private synchronized boolean shouldRetry(Throwable e) { private class RetryingBlockFetchListener implements BlockFetchingListener { @Override public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + onBlockFetchSuccess(blockId, data, -1L); + } + + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data, long digest) { // We will only forward this success message to our parent listener if this block request is // outstanding and we are still the active listener. boolean shouldForwardSuccess = false; @@ -201,7 +206,11 @@ public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { // Now actually invoke the parent listener, outside of the synchronized block. if (shouldForwardSuccess) { - listener.onBlockFetchSuccess(blockId, data); + if (digest < 0) { + listener.onBlockFetchSuccess(blockId, data); + } else { + listener.onBlockFetchSuccess(blockId, data, digest); + } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index b65aacfcc4b9e..1badf5d41012a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -23,6 +23,7 @@ import java.nio.ByteBuffer; import java.nio.LongBuffer; import java.nio.file.Files; +import java.util.Optional; /** * Keeps the index information for a particular map output @@ -31,17 +32,43 @@ public class ShuffleIndexInformation { /** offsets as long buffer */ private final LongBuffer offsets; + private final boolean hasDigest; + /** digests as long buffer */ + private final LongBuffer digests; private int size; public ShuffleIndexInformation(File indexFile) throws IOException { size = (int)indexFile.length(); - ByteBuffer buffer = ByteBuffer.allocate(size); - offsets = buffer.asLongBuffer(); + ByteBuffer offsetsBuffer = ByteBuffer.allocate(size); + offsets = offsetsBuffer.asLongBuffer(); try (DataInputStream dis = new DataInputStream(Files.newInputStream(indexFile.toPath()))) { - dis.readFully(buffer.array()); + dis.readFully(offsetsBuffer.array()); + } + + // This logic is from IndexShuffleBlockResolver, and the digest file name is from + // ShuffleIndexDigestBlockId. 
+ File digestFile = new File(indexFile.getAbsolutePath() + ".digest"); + if (digestFile.exists()) { + hasDigest = true; + size += digestFile.length(); + ByteBuffer digestsBuffer = ByteBuffer.allocate((int)digestFile.length()); + digests = digestsBuffer.asLongBuffer(); + try (DataInputStream digIs = new DataInputStream(Files.newInputStream(digestFile.toPath()))) { + digIs.readFully(digestsBuffer.array()); + } + } else { + hasDigest = false; + digests = null; } } + /** + * If this indexFile has digest + */ + public boolean isHasDigest() { + return hasDigest; + } + /** * Size of the index file * @return size @@ -63,6 +90,13 @@ public ShuffleIndexRecord getIndex(int reduceId) { public ShuffleIndexRecord getIndex(int startReduceId, int endReduceId) { long offset = offsets.get(startReduceId); long nextOffset = offsets.get(endReduceId); - return new ShuffleIndexRecord(offset, nextOffset - offset); + if (hasDigest && endReduceId - startReduceId == 1) { + return new ShuffleIndexRecord(offset, nextOffset - offset, + Optional.of(digests.get(startReduceId))); + } else if (!hasDigest) { + return new ShuffleIndexRecord(offset, nextOffset - offset, Optional.of(-1L)); + } else { + return new ShuffleIndexRecord(offset, nextOffset - offset, Optional.empty()); + } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java index 6a4fac150a6bd..f64871c2bea3c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java @@ -17,16 +17,20 @@ package org.apache.spark.network.shuffle; +import java.util.Optional; + /** * Contains offset and length of the shuffle block data. */ public class ShuffleIndexRecord { private final long offset; private final long length; + private final Optional digest; - public ShuffleIndexRecord(long offset, long length) { + public ShuffleIndexRecord(long offset, long length, Optional digest) { this.offset = offset; this.length = length; + this.digest = digest; } public long getOffset() { @@ -36,5 +40,9 @@ public long getOffset() { public long getLength() { return length; } + + public Optional getDigest() { + return digest; + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index a6529fd76188a..0c760770d1cd3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -112,7 +112,7 @@ public long[] commitAllPartitions() throws IOException { } cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); + blockResolver.writeIndexDigestFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); return partitionLengths; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java index c8b41992a8919..09b9fabc5a5e5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java @@ -50,6 +50,6 @@ public void transferMapSpillFile( File outputFile = blockResolver.getDataFile(shuffleId, mapId); File tempFile = Utils.tempFileWith(outputFile); Files.move(mapSpillFile.toPath(), tempFile.toPath()); - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); + blockResolver.writeIndexDigestFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4cda4b180d97d..ab90b8228334e 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1836,4 +1836,11 @@ package object config { .version("3.1.0") .booleanConf .createWithDefault(false) + + private[spark] val SHUFFLE_DIGEST_ENABLED = + ConfigBuilder("spark.shuffle.digest.enabled") + .internal() + .doc("Whether to verify the digest of data transmitted during shuffle.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index bc2a0fbc36d5b..dd6bea1a1a7ad 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -80,7 +80,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), readMetrics, - fetchContinuousBlocksInBatch).toCompletionIterator + fetchContinuousBlocksInBatch, + SparkEnv.get.conf.get(config.SHUFFLE_DIGEST_ENABLED)).toCompletionIterator val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index af2c82e771970..c805a986bafe9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -22,11 +22,12 @@ import java.nio.channels.Channels import java.nio.file.Files import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.NioBufferedFileInputStream -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{DigestFileSegmentManagedBuffer, FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExecutorDiskUtils +import 
org.apache.spark.network.util.{DigestUtils, LimitedInputStream} import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -52,6 +53,9 @@ private[spark] class IndexShuffleBlockResolver( private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + // The digest conf for shuffle block check + private final val digestEnable = conf.getBoolean(config.SHUFFLE_DIGEST_ENABLED.key, false); + private final val digestLength = DigestUtils.getDigestLength() def getDataFile(shuffleId: Int, mapId: Long): File = getDataFile(shuffleId, mapId, None) @@ -84,6 +88,22 @@ private[spark] class IndexShuffleBlockResolver( .getOrElse(blockManager.diskBlockManager.getFile(blockId)) } + /** + * Get the shuffle digest file. + * + * When the dirs parameter is None then use the disk manager's local directories. Otherwise, + * read from the specified directories. + */ + private def getDigestFile( + shuffleId: Int, + mapId: Long, + dirs: Option[Array[String]] = None): File = { + val blockId = ShuffleIndexDigestBlockId(shuffleId, mapId, NOOP_REDUCE_ID) + dirs + .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) + .getOrElse(blockManager.diskBlockManager.getFile(blockId)) + } + /** * Remove data file and index file that contain the output data from one map. */ @@ -101,6 +121,13 @@ private[spark] class IndexShuffleBlockResolver( logWarning(s"Error deleting index ${file.getPath()}") } } + + file = getDigestFile(shuffleId, mapId) + if (file.exists()) { + if (!file.delete()) { + logWarning(s"Error deleting digest ${file.getPath()}") + } + } } /** @@ -158,18 +185,43 @@ private[spark] class IndexShuffleBlockResolver( * * Note: the `lengths` will be updated to match the existing index file if use the existing ones. */ - def writeIndexFileAndCommit( + def writeIndexDigestFileAndCommit( shuffleId: Int, mapId: Long, lengths: Array[Long], dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) + var digestFile: File = null + var digestTmp: File = null try { val dataFile = getDataFile(shuffleId, mapId) // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure // the following check and rename are atomic. synchronized { + val digests = new Array[Long](lengths.length) + val dateIn = if (dataTmp != null && dataTmp.exists()) { + new FileInputStream(dataTmp) + } else { + null + } + Utils.tryWithSafeFinally { + if (digestEnable && dateIn != null) { + for (i <- (0 until lengths.length)) { + val length = lengths(i) + if (length == 0) { + digests(i) = -1L + } else { + digests(i) = DigestUtils.getDigest(new LimitedInputStream(dateIn, length)) + } + } + } + } { + if (dateIn != null) { + dateIn.close() + } + } + val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) if (existingLengths != null) { // Another attempt for the same task has already written our map outputs successfully, @@ -181,7 +233,7 @@ private[spark] class IndexShuffleBlockResolver( } else { // This is the first successful attempt in writing the map outputs for this task, // so override any existing index and data files with the ones we wrote. - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + var out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. 
var offset = 0L @@ -194,15 +246,34 @@ private[spark] class IndexShuffleBlockResolver( out.close() } + if (digestEnable) { + digestFile = getDigestFile(shuffleId, mapId) + digestTmp = Utils.tempFileWith(digestFile) + out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(digestTmp))) + Utils.tryWithSafeFinally { + digests.foreach { digest => + out.writeLong(digest) + } + } { + out.close() + } + } + if (indexFile.exists()) { indexFile.delete() } + if (digestEnable && digestFile.exists()) { + digestFile.delete() + } if (dataFile.exists()) { dataFile.delete() } if (!indexTmp.renameTo(indexFile)) { throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) } + if (digestEnable && !digestTmp.renameTo(digestFile)) { + throw new IOException("fail to rename file " + digestTmp + " to " + digestFile) + } if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) } @@ -212,6 +283,9 @@ private[spark] class IndexShuffleBlockResolver( if (indexTmp.exists() && !indexTmp.delete()) { logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}") } + if (digestEnable && digestTmp.exists() && !digestFile.delete()) { + logError(s"Failed to delete temporary digest file at ${digestTmp.getAbsolutePath}") + } } } @@ -249,11 +323,36 @@ private[spark] class IndexShuffleBlockResolver( throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + s"expected $expectedPosition but actual position was $actualPosition.") } - new FileSegmentManagedBuffer( - transportConf, - getDataFile(shuffleId, mapId, dirs), - startOffset, - endOffset - startOffset) + + if (digestEnable) { + val digestValue = if (endReduceId - startReduceId == 1) { + val digestFile = getDigestFile(shuffleId, mapId, dirs) + val digestChannel = Files.newByteChannel(digestFile.toPath) + channel.position(startReduceId * 8L) + val digestIn = new DataInputStream(Channels.newInputStream(digestChannel)) + try { + digestIn.readLong() + } finally { + digestIn.close() + } + } else { + DigestUtils.getDigest(getDataFile(shuffleId, mapId, dirs), startOffset, + endOffset - startOffset) + } + + new DigestFileSegmentManagedBuffer( + transportConf, + getDataFile(shuffleId, mapId, dirs), + startOffset, + endOffset - startOffset, + digestValue) + } else { + new FileSegmentManagedBuffer( + transportConf, + getDataFile(shuffleId, mapId, dirs), + startOffset, + endOffset - startOffset) + } } finally { in.close() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 68ed3aa5b062f..453dc5373ad49 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -78,6 +78,11 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } +@DeveloperApi +case class ShuffleIndexDigestBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { + override def name: String = ShuffleIndexBlockId(shuffleId, mapId, reduceId).name + ".digest" +} + @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) @@ -119,6 +124,7 @@ object BlockId { val SHUFFLE_BATCH = 
"shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r + val SHUFFLE_DIGEST = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).digest".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r @@ -137,6 +143,8 @@ object BlockId { ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) + case SHUFFLE_DIGEST(shuffleId, mapId, reduceId) => + ShuffleIndexDigestBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 5efbc0703f729..a5ee93f503b0f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -32,7 +32,7 @@ import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ -import org.apache.spark.network.util.TransportConf +import org.apache.spark.network.util.{DigestUtils, TransportConf} import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} @@ -80,7 +80,8 @@ final class ShuffleBlockFetcherIterator( detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, shuffleMetrics: ShuffleReadMetricsReporter, - doBatchFetch: Boolean) + doBatchFetch: Boolean, + digestEnabled: Boolean) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -210,7 +211,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, _, address, _, buf, _) => + case SuccessFetchResult(_, _, address, _, buf, _, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) if (buf.isInstanceOf[FileSegmentManagedBuffer]) { @@ -572,7 +573,8 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(fetchWaitTime) result match { - case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => + case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone, + digest) => if (address != blockManager.blockManagerId) { if (hostLocalBlocks.contains(blockId -> mapIndex)) { shuffleMetrics.incLocalBlocksFetched(1) @@ -611,7 +613,7 @@ final class ShuffleBlockFetcherIterator( throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) } - val in = try { + var in = try { buf.createInputStream() } catch { // The exception could only be throwed by local shuffle block @@ -626,16 +628,61 @@ final class ShuffleBlockFetcherIterator( buf.release() throwFetchFailedException(blockId, mapIndex, address, e) } + + // If shuffle digest enabled is true, check the block with checkSum. 
+ var failedOnDigestCheck = false + if (digestEnabled) { + if (digest >= 0) { + val digestToCheck = try { + DigestUtils.getDigest(in) + } catch { + case e: IOException => + logError("Error occurs when checking digest", e) + buf.release() + throwFetchFailedException(blockId, mapIndex, address, e) + } + failedOnDigestCheck = digest != digestToCheck + if (failedOnDigestCheck) { + buf.release() + val e = new CheckDigestFailedException(s"The digest to check $digestToCheck " + + s"of $blockId is not equal with origin $digest") + if (corruptedBlocks.contains(blockId)) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else { + logError("The digest of read data is not correct and fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null + in.close() + } + } + // If digest check passed, reset or recreate the inputStream + if (!failedOnDigestCheck) { + if (in.markSupported()) { + in.reset() + } else { + in.close() + in = buf.createInputStream() + } + } + } else { + logDebug(s"The digest for address: ${address.host} and blockID:" + + s"$blockId is null, local address is ${blockManager.blockManagerId.host}") + } + } try { - input = streamWrapper(blockId, in) - // If the stream is compressed or wrapped, then we optionally decompress/unwrap the - // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion - // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if - // the corruption is later, we'll still detect the corruption later in the stream. - streamCompressedOrEncrypted = !input.eq(in) - if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { - // TODO: manage the memory used here, and spill it into disk in case of OOM. - input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + if (!failedOnDigestCheck) { + input = streamWrapper(blockId, in) + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (!digestEnabled && streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { + // TODO: manage the memory used here, and spill it into disk in case of OOM. + input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + } } } catch { case e: IOException => @@ -969,6 +1016,7 @@ object ShuffleBlockFetcherIterator { * Size of remote block is used to calculate bytesInFlight. * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. + * @param digest Is the digest of the result, default is -1L. */ private[storage] case class SuccessFetchResult( blockId: BlockId, @@ -976,7 +1024,8 @@ object ShuffleBlockFetcherIterator { address: BlockManagerId, size: Long, buf: ManagedBuffer, - isNetworkReqDone: Boolean) extends FetchResult { + isNetworkReqDone: Boolean, + digest: Long = -1L) extends FetchResult { require(buf != null) require(size >= 0) } @@ -994,4 +1043,12 @@ object ShuffleBlockFetcherIterator { address: BlockManagerId, e: Throwable) extends FetchResult + + /** + * An exception that the origin digest is not equal with the fetchResult's digest. 
+ */ + private case class CheckDigestFailedException( + message: String, + cause: Throwable = null) + extends Exception(message, cause) } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index ee8e38c24b47f..cf64ce82e7417 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -152,11 +152,11 @@ public void setUp() throws IOException { doAnswer(renameTempAnswer) .when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class)); + .writeIndexDigestFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class)); doAnswer(renameTempAnswer) .when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null)); + .writeIndexDigestFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 92ed24408384f..24e6b65783731 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -535,6 +535,7 @@ class CleanerTester( blockManager.master.getMatchingBlockIds( _ match { case ShuffleBlockId(`shuffleId`, _, _) => true case ShuffleIndexBlockId(`shuffleId`, _, _) => true + case ShuffleIndexDigestBlockId(`shuffleId`, _, _) => true case _ => false }, askSlaves = true) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 9e39271bdf9ee..d9392361780a4 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -418,6 +418,22 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC manager.unregisterShuffle(0) } + + test("SPARK-27562: Test shuffle with checking digest of transmitted data") { + conf.set(config.SHUFFLE_DIGEST_ENABLED, true) + val sc = new SparkContext("local", "test", conf) + val numRecords = 10000 + + val wordCount = sc.parallelize(1 to numRecords, 4) + .map(key => (key, 1)) + .reduceByKey(_ + _) + .collect() + val count = wordCount.length + val sum = wordCount.map(value => value._1).sum + assert(count == numRecords) + assert(sum == (1 to numRecords).sum) + sc.stop() + } } /** diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index f8474022867f4..98c3e59c8e92f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -76,7 +76,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) - when(blockResolver.writeIndexFileAndCommit( + when(blockResolver.writeIndexDigestFileAndCommit( anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))) .thenAnswer { 
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
index 27bb06b4e0636..07c867a920e8c 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle.sort
 
-import java.io.{DataInputStream, File, FileInputStream, FileOutputStream}
+import java.io.{ByteArrayInputStream, DataInputStream, File, FileInputStream, FileOutputStream}
 
 import org.mockito.{Mock, MockitoAnnotations}
 import org.mockito.Answers.RETURNS_SMART_NULLS
@@ -27,6 +27,9 @@ import org.mockito.invocation.InvocationOnMock
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config
+import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer
+import org.apache.spark.network.util.DigestUtils
 import org.apache.spark.shuffle.IndexShuffleBlockResolver
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -71,7 +74,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
+    resolver.writeIndexDigestFileAndCommit(shuffleId, mapId, lengths, dataTmp)
 
     val indexFile = new File(tempDir.getAbsolutePath, idxName)
     val dataFile = resolver.getDataFile(shuffleId, mapId)
@@ -91,7 +94,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out2.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2)
+    resolver.writeIndexDigestFileAndCommit(shuffleId, mapId, lengths2, dataTmp2)
 
     assert(indexFile.length() === (lengths.length + 1) * 8)
     assert(lengths2.toSeq === lengths.toSeq)
@@ -130,7 +133,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out3.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3)
+    resolver.writeIndexDigestFileAndCommit(shuffleId, mapId, lengths3, dataTmp3)
     assert(indexFile.length() === (lengths3.length + 1) * 8)
     assert(lengths3.toSeq != lengths.toSeq)
     assert(dataFile.exists())
@@ -155,4 +158,23 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
       indexIn2.close()
     }
   }
+
+  test("SPARK-27562: check the digest when shuffle digest enabled is true") {
+    val confClone = conf.clone
+    confClone.set(config.SHUFFLE_DIGEST_ENABLED, true)
+    val resolver = new IndexShuffleBlockResolver(confClone, blockManager)
+    val lengths = Array[Long](10, 0, 20)
+    val dataTmp = File.createTempFile("shuffle", null, tempDir)
+    val out = new FileOutputStream(dataTmp)
+    Utils.tryWithSafeFinally {
+      out.write(new Array[Byte](30))
+    } {
+      out.close()
+    }
+    val digest = DigestUtils.getDigest(new ByteArrayInputStream(new Array[Byte](10)))
+    resolver.writeIndexDigestFileAndCommit(1, 2, lengths, dataTmp)
+    val managedBuffer = resolver.getBlockData(ShuffleBlockId(1, 2, 0))
+    assert(managedBuffer.isInstanceOf[DigestFileSegmentManagedBuffer])
+    assert(managedBuffer.asInstanceOf[DigestFileSegmentManagedBuffer].getDigest == digest)
+  }
 }
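In the new `IndexShuffleBlockResolverSuite` test, the expected digest is computed over ten zero bytes because the data file contains 30 zero bytes and partition 0 has length 10, so the segment served for `ShuffleBlockId(1, 2, 0)` must digest to the same value. A consumer could cross-check a served buffer in the same way (a sketch reusing the hypothetical `DigestSketch.getDigest` helper from the earlier sketch; the real check happens inside `ShuffleBlockFetcherIterator`):

```scala
import java.io.BufferedInputStream

import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer

object DigestCheckSketch {
  // Recompute the digest of the served segment and compare it with the one
  // carried by the buffer itself.
  def checkBlock(buffer: DigestFileSegmentManagedBuffer): Boolean = {
    val in = new BufferedInputStream(buffer.createInputStream())
    try {
      DigestSketch.getDigest(in) == buffer.getDigest
    } finally {
      in.close()
    }
  }
}
```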
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
index f92455912f510..d55c8abbf4311 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
@@ -74,7 +74,7 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
       .set("spark.app.id", "example.spark.app")
       .set("spark.shuffle.unsafe.file.output.buffer", "16k")
     when(blockResolver.getDataFile(anyInt, anyLong)).thenReturn(mergedOutputFile)
-    when(blockResolver.writeIndexFileAndCommit(
+    when(blockResolver.writeIndexDigestFileAndCommit(
       anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])))
       .thenAnswer { invocationOnMock =>
         partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]]
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 43917a5b83bb0..9c4993fac0c68 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -175,6 +175,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       metrics,
+      false,
       false)
 
     // 3 local blocks fetched in initialization
@@ -251,6 +252,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       metrics,
+      false,
       false)
     intercept[FetchFailedException] { iterator.next() }
   }
@@ -285,6 +287,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       metrics,
+      false,
       false)
 
     // After initialize() we'll have 2 FetchRequests and each is 1000 bytes. So only the
     // first FetchRequests can be sent, and the second one will hit maxBytesInFlight so
@@ -330,6 +333,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       metrics,
+      false,
       false)
 
     // After initialize(), we'll have 2 FetchRequests that one has 2 blocks inside and another one
     // has only one block. So only the first FetchRequest can be sent. The second FetchRequest will
@@ -412,7 +416,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       metrics,
-      true)
+      true,
+      false)
 
     // 3 local blocks batch fetched in initialization
     verify(blockManager, times(1)).getLocalBlockData(any())
@@ -472,7 +477,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       metrics,
-      true)
+      true,
+      false)
     var numResults = 0
     // After initialize(), there will be 6 FetchRequests. And each of the first 5 requests
@@ -529,7 +535,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       metrics,
-      true)
+      true,
+      false)
     var numResults = 0
     // After initialize(), there will be 2 FetchRequests. First one has 2 merged blocks and each
     // of them is merged from 2 shuffle blocks, second one has 1 merged block which is merged from
@@ -596,6 +603,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false,
       false)
 
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
@@ -666,6 +674,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false,
       false)
 
     // Continue only after the mock calls onBlockFetchFailure
@@ -756,6 +765,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       true,
       taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false,
       false)
 
     // Continue only after the mock calls onBlockFetchFailure
@@ -827,6 +837,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       true,
       taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false,
       false)
 
     // We'll get back the block which has corruption after maxBytesInFlight/3 because the other
@@ -892,6 +903,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       true,
       taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false,
       false)
     val (id, st) = iterator.next()
     // Check that the test setup is correct -- make sure we have a concatenated stream.
@@ -955,6 +967,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false,
       false)
 
     // Continue only after the mock calls onBlockFetchFailure
@@ -1017,6 +1030,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       detectCorrupt = true,
       false,
       taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false,
       false)
   }
@@ -1066,6 +1080,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true,
       false,
       taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false,
       false)
 
     // All blocks fetched return zero length and should trigger a receive-side error:
diff --git a/docs/configuration.md b/docs/configuration.md
index fce04b940594b..e3a50a9cc82a8 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -998,6 +998,13 @@ Apart from these, the following properties are also available, and may be useful
   <td>2.3.0</td>
 </tr>
+<tr>
+  <td><code>spark.shuffle.digest.enabled</code></td>
+  <td>false</td>
+  <td>
+    Whether to check the digest of transmitted shuffle data, so that corrupted blocks can be detected and fetched again.
+  </td>
+</tr>
 </table>
 
 ### Spark UI
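For reference, enabling the new property from application code only requires setting it on the `SparkConf` before the context is created (a spark-shell style sketch with an arbitrary example job); it can equally be passed to `spark-submit` as `--conf spark.shuffle.digest.enabled=true`:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Turn the flag on for this job; the default stays false, so existing jobs are unaffected.
val conf = new SparkConf()
  .setAppName("digest-checked-shuffle")
  .set("spark.shuffle.digest.enabled", "true")
val sc = new SparkContext(conf)

// Any shuffle-producing job will now have its fetched blocks digest-checked.
sc.parallelize(1 to 100000, 4).map(k => (k % 100, 1L)).reduceByKey(_ + _).count()
sc.stop()
```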