From 2138aa6d8225fb7c343c4a61dadc8c44e902e35b Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 7 Jan 2018 15:21:33 -0800
Subject: [PATCH 1/4] Add position checks to IndexShuffleBlockResolver.

---
 .../shuffle/IndexShuffleBlockResolver.scala   | 21 +++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 15540485170d0..266ee42e39cca 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -18,8 +18,8 @@ package org.apache.spark.shuffle

 import java.io._
-
-import com.google.common.io.ByteStreams
+import java.nio.channels.Channels
+import java.nio.file.Files

 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
@@ -196,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver(
     // find out the consolidated file, then the offset within that from our index
     val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

-    val in = new DataInputStream(new FileInputStream(indexFile))
+    // SPARK-22982: if this channel's position is advanced by another piece of code that is
+    // incorrectly using our file descriptor then this code will fetch the wrong offsets (which
+    // may cause a reducer to be sent a different reducer's data). The explicit position checks
+    // added here were a useful debugging aid during SPARK-22982 and may help prevent this class
+    // of issue from recurring in the future, which is why they are left here even though
+    // SPARK-22982 is fixed.
+    val channel = Files.newByteChannel(indexFile.toPath)
+    channel.position(blockId.reduceId * 8L)
+    val in = new DataInputStream(Channels.newInputStream(channel))
     try {
-      ByteStreams.skipFully(in, blockId.reduceId * 8)
       val offset = in.readLong()
       val nextOffset = in.readLong()
+      val actualPosition = channel.position()
+      val expectedPosition = blockId.reduceId * 8L + 16
+      if (actualPosition != expectedPosition) {
+        throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
+          s"expected $expectedPosition but actual position was $actualPosition.")
+      }
       new FileSegmentManagedBuffer(
         transportConf,
         getDataFile(blockId.shuffleId, blockId.mapId),

From 8e5ffa451f66dc64203dede68b9c0ac5fdc952cf Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 7 Jan 2018 15:22:44 -0800
Subject: [PATCH 2/4] Remove unsafe asynchronous close() call from
 FileDownloadChannel

---
 .../apache/spark/rpc/netty/NettyRpcEnv.scala | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index f951591e02a5c..cdf2a7d57a821 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -332,15 +332,17 @@ private[netty] class NettyRpcEnv(

     val pipe = Pipe.open()
     val source = new FileDownloadChannel(pipe.source())
+    var exceptionThrown = true
     try {
       val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
       val callback = new FileDownloadCallback(pipe.sink(), source, client)
       client.stream(parsedUri.getPath(), callback)
-    } catch {
-      case e: Exception =>
+      exceptionThrown = false
+    } finally {
+      if (exceptionThrown) {
         pipe.sink().close()
         source.close()
-        throw e
+      }
     }
     source
@@ -376,18 +378,13 @@ private[netty] class NettyRpcEnv(

     def setError(e: Throwable): Unit = {
       error = e
-      source.close()
     }

     override def read(dst: ByteBuffer): Int = {
       Try(source.read(dst)) match {
+        case _ if error != null => throw error
         case Success(bytesRead) => bytesRead
-        case Failure(readErr) =>
-          if (error != null) {
-            throw error
-          } else {
-            throw readErr
-          }
+        case Failure(readErr) => throw readErr
       }
     }

From c1278b80d1720d1e657db6ecf899ad25d2358a83 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 8 Jan 2018 19:36:51 -0800
Subject: [PATCH 3/4] Use tryWithSafeFinallyAndFailureCallbacks

---
 .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index cdf2a7d57a821..ad66d11d891ba 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -332,18 +332,14 @@ private[netty] class NettyRpcEnv(

     val pipe = Pipe.open()
     val source = new FileDownloadChannel(pipe.source())
-    var exceptionThrown = true
-    try {
+    Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
       val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
       val callback = new FileDownloadCallback(pipe.sink(), source, client)
       client.stream(parsedUri.getPath(), callback)
-      exceptionThrown = false
-    } finally {
-      if (exceptionThrown) {
-        pipe.sink().close()
-        source.close()
-      }
-    }
+    })(catchBlock = {
+      pipe.sink().close()
+      source.close()
+    })
     source
   }

From 9e5f20eef09fbd357c7cc8bb19eca7d4bf5f7170 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Tue, 9 Jan 2018 14:51:34 -0800
Subject: [PATCH 4/4] Add more comments.

---
 .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index ad66d11d891ba..a2936d6ad539c 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -368,16 +368,30 @@ private[netty] class NettyRpcEnv(
     fileDownloadFactory.createClient(host, port)
   }

-  private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
+  private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel {

     @volatile private var error: Throwable = _

     def setError(e: Throwable): Unit = {
+      // This setError callback is invoked by internal RPC threads to propagate remote
+      // exceptions to application-level threads that are reading from this channel. When an
+      // RPC error occurs, the RPC system calls setError() and then closes the
+      // Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing the pipe
+      // sink causes `source.read()` operations to return EOF, unblocking the application-level
+      // reading thread. Thus there is no need to actually call `source.close()` in this error
+      // callback and, in fact, calling it here would be dangerous because the close()
+      // would be asynchronous with respect to the read() call and could trigger race conditions
+      // that lead to data corruption. See the PR for SPARK-22982 for more details on this topic.
       error = e
     }

     override def read(dst: ByteBuffer): Int = {
       Try(source.read(dst)) match {
+        // See the documentation above in setError(): if an RPC error has occurred then setError()
+        // will be called to propagate the RPC error and then `source`'s corresponding
+        // Pipe.SinkChannel will be closed, unblocking this read. In that case, we want to
+        // propagate the remote RPC exception (and not any exception triggered by the pipe close,
+        // such as ClosedChannelException), hence this `error != null` check:
         case _ if error != null => throw error
         case Success(bytesRead) => bytesRead
         case Failure(readErr) => throw readErr
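---

Background sketches for reviewers. These are illustrative, standalone snippets, not part of the patches; object and helper names are hypothetical.

Patch 1 context: each map task's shuffle index file is a flat sequence of N + 1 longs, where entry r is the byte offset of reducer r's segment in the data file and entry r + 1 marks where the next segment begins. The sketch below shows that layout and why seeking to reduceId * 8L and reading two longs should leave the channel at exactly reduceId * 8L + 16.

import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream}
import java.nio.channels.Channels
import java.nio.file.Files

object IndexLayoutSketch {
  // Writes an index for segments of the given lengths: N + 1 cumulative
  // offsets, so entry r is where reducer r's data starts in the data file.
  def writeIndex(indexFile: File, segmentLengths: Seq[Long]): Unit = {
    val out = new DataOutputStream(new FileOutputStream(indexFile))
    try {
      var offset = 0L
      out.writeLong(offset)
      for (len <- segmentLengths) {
        offset += len
        out.writeLong(offset)
      }
    } finally {
      out.close()
    }
  }

  // Reads the [offset, nextOffset) range for one reducer, mirroring the
  // patched resolver: seek the channel explicitly (rather than skipping on
  // a stream whose descriptor might be shared), read two longs, and the
  // position must then be reduceId * 8L + 16.
  def readSegment(indexFile: File, reduceId: Int): (Long, Long) = {
    val channel = Files.newByteChannel(indexFile.toPath)
    val in = new DataInputStream(Channels.newInputStream(channel))
    try {
      channel.position(reduceId * 8L)
      val offset = in.readLong()
      val nextOffset = in.readLong()
      assert(channel.position() == reduceId * 8L + 16)
      (offset, nextOffset)
    } finally {
      in.close()
    }
  }
}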
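Patch 3 context: Utils.tryWithSafeFinallyAndFailureCallbacks replaces the manual exceptionThrown flag. A minimal sketch of the semantics the call site relies on, assuming only that the helper runs the cleanup callback on failure without letting a secondary cleanup exception mask the original one (the real helper in org.apache.spark.util.Utils also accepts a finallyBlock):

// Simplified sketch only; not the actual Utils implementation.
def tryWithFailureCallback[T](block: => T)(catchBlock: => Unit): T = {
  try {
    block
  } catch {
    case t: Throwable =>
      try {
        catchBlock
      } catch {
        case cleanupErr: Throwable =>
          // Record the cleanup failure, but keep the original exception primary.
          t.addSuppressed(cleanupErr)
      }
      throw t
  }
}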
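Patch 4 context: the new comments rest on a property of java.nio.channels.Pipe: closing the sink end unblocks a read() on the source end with EOF (-1) rather than an exception. The standalone sketch below (not Spark code) demonstrates that behavior, which is why setError() only needs to record the error and let the RPC layer close the sink.

import java.nio.ByteBuffer
import java.nio.channels.Pipe

object PipeEofSketch {
  def main(args: Array[String]): Unit = {
    val pipe = Pipe.open()
    val reader = new Thread(() => {
      val buf = ByteBuffer.allocate(64)
      // Blocks until data arrives or the sink is closed; the sink-side
      // close() surfaces here as EOF (-1), not as an exception.
      var n = pipe.source().read(buf)
      while (n != -1) {
        buf.clear()
        n = pipe.source().read(buf)
      }
      println("reader unblocked with EOF after sink.close()")
    })
    reader.start()
    pipe.sink().write(ByteBuffer.wrap("payload".getBytes("UTF-8")))
    pipe.sink().close() // propagates EOF to the blocked reader
    reader.join()
  }
}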